1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package chaosdaemon
17
18 import (
19 "context"
20 "encoding/json"
21 "fmt"
22 "net"
23 "strings"
24
25 "github.com/go-logr/logr"
26 "github.com/golang/protobuf/ptypes/empty"
27 "github.com/pkg/errors"
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/status"
30
31 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
32 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
33 "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/util"
34 )
35
const (
	// ruleNotExist and ruleNotExistLowerVersion are substrings searched for in
	// the output of a failed `tc qdisc del`; flush treats a failure whose
	// output contains either of them as success.
	ruleNotExist             = "Cannot delete qdisc with handle of zero."
	ruleNotExistLowerVersion = "RTNETLINK answers: No such file or directory"

	// defaultDevice is the device assigned to a rule whose Device field is empty.
	defaultDevice = "eth0"
)
42
43 func generateQdiscArgs(action string, qdisc *pb.Qdisc) ([]string, error) {
44 if qdisc == nil {
45 return nil, errors.New("qdisc is required")
46 }
47
48 if qdisc.Type == "" {
49 return nil, errors.New("qdisc.Type is required")
50 }
51
52 args := []string{"qdisc", action, "dev", "eth0"}
53
54 if qdisc.Parent == nil {
55 args = append(args, "root")
56 } else if qdisc.Parent.Major == 1 && qdisc.Parent.Minor == 0 {
57 args = append(args, "root")
58 } else {
59 args = append(args, "parent", fmt.Sprintf("%d:%d", qdisc.Parent.Major, qdisc.Parent.Minor))
60 }
61
62 if qdisc.Handle == nil {
63 args = append(args, "handle", fmt.Sprintf("%d:%d", 1, 0))
64 } else {
65 args = append(args, "handle", fmt.Sprintf("%d:%d", qdisc.Handle.Major, qdisc.Handle.Minor))
66 }
67
68 args = append(args, qdisc.Type)
69
70 if qdisc.Args != nil {
71 args = append(args, qdisc.Args...)
72 }
73
74 return args, nil
75 }
76
77 func getAllInterfaces(ctx context.Context, log logr.Logger, pid uint32, enterNS bool) ([]string, error) {
78 var ifaces []string
79 if enterNS {
80 ipOutput, err := bpm.DefaultProcessBuilder("ip", "-j", "addr", "show").SetNS(pid, bpm.NetNS).SetContext(ctx).Build(ctx).CombinedOutput()
81 if err != nil {
82 return []string{}, err
83 }
84 var data []map[string]interface{}
85
86 err = json.Unmarshal(ipOutput, &data)
87 if err != nil {
88 return []string{}, err
89 }
90 for _, iface := range data {
91 name, ok := iface["ifname"]
92 if !ok {
93 return []string{}, errors.New("fail to read ifname from ip -j addr show")
94 }
95 ifaces = append(ifaces, name.(string))
96 }
97 log.Info("get interfaces from ip command", "ifaces", ifaces)
98 } else {
99 interfaces, err := net.Interfaces()
100 if err != nil {
101 return []string{}, errors.New("fail to read ifname from net.Interfaces()")
102 }
103 for _, iface := range interfaces {
104 ifaces = append(ifaces, iface.Name)
105 }
106 log.Info("get interfaces from net.Interfaces()", "ifaces", ifaces)
107 }
108
109 return ifaces, nil
110 }
111
112 func (s *DaemonServer) SetTcs(ctx context.Context, in *pb.TcsRequest) (*empty.Empty, error) {
113 log := s.getLoggerFromContext(ctx)
114 log.Info("handling tc request", "tcs", in)
115
116 pid, err := s.crClient.GetPidFromContainerID(ctx, in.ContainerId)
117 if err != nil {
118 return nil, status.Errorf(codes.Internal, "get pid from containerID error: %v", err)
119 }
120
121 tcCli := buildTcClient(ctx, log, in.EnterNS, pid)
122
123 ifaces, err := getAllInterfaces(ctx, log, pid, in.EnterNS)
124 if err != nil {
125 log.Error(err, "error while getting interfaces")
126 return nil, err
127 }
128 for _, iface := range ifaces {
129 err = tcCli.flush(iface)
130 if err != nil {
131 log.Error(err, "fail to flush tc rules on device", "device", iface)
132 }
133 }
134 if err != nil {
135 return &empty.Empty{}, err
136 }
137
138 for device, rules := range s.groupRulesAccordingToDevices(in.Tcs) {
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163 globalTc := []*pb.Tc{}
164 filterTc := make(map[string][]*pb.Tc)
165
166 for _, tc := range rules {
167 filter := abstractTcFilter(tc)
168 if len(filter) > 0 {
169 filterTc[filter] = append(filterTc[filter], tc)
170 continue
171 }
172 globalTc = append(globalTc, tc)
173 }
174
175 if len(globalTc) > 0 {
176 if err := s.setGlobalTcs(log, tcCli, globalTc, device); err != nil {
177 log.Error(err, "error while setting global tc")
178 return &empty.Empty{}, err
179 }
180 }
181
182 if len(filterTc) > 0 {
183 iptablesCli := buildIptablesClient(ctx, in.EnterNS, pid)
184 if err := s.setFilterTcs(log, tcCli, iptablesCli, filterTc, device, len(globalTc)); err != nil {
185 log.Error(err, "error while setting filter tc")
186 return &empty.Empty{}, err
187 }
188 }
189 }
190
191 return &empty.Empty{}, nil
192 }
193
194 func (s *DaemonServer) groupRulesAccordingToDevices(tcs []*pb.Tc) map[string][]*pb.Tc {
195 rules := make(map[string][]*pb.Tc)
196 for _, tc := range tcs {
197 if tc.Device == "" {
198 tc.Device = defaultDevice
199 }
200 rules[tc.Device] = append(rules[tc.Device], tc)
201 }
202 return rules
203 }
204
205 func (s *DaemonServer) setGlobalTcs(log logr.Logger, cli tcClient, tcs []*pb.Tc, device string) error {
206 for index, tc := range tcs {
207 parentArg := "root"
208 if index > 0 {
209 parentArg = fmt.Sprintf("parent %d:", index)
210 }
211
212 handleArg := fmt.Sprintf("handle %d:", index+1)
213
214 err := cli.addTc(device, parentArg, handleArg, tc)
215 if err != nil {
216 log.Error(err, "error while adding tc")
217 return err
218 }
219 }
220
221 return nil
222 }
223
224 func (s *DaemonServer) setFilterTcs(
225 log logr.Logger,
226 tcCli tcClient,
227 iptablesCli iptablesClient,
228 filterTc map[string][]*pb.Tc,
229 device string,
230 baseIndex int,
231 ) error {
232 parent := baseIndex
233 band := 3 + len(filterTc)
234 if err := tcCli.addPrio(device, parent, band); err != nil {
235 log.Error(err, "error while adding prio")
236 return err
237 }
238
239 parent++
240 index := 0
241 currentHandler := parent + 3
242
243
244
245
246 chains := []*pb.Chain{}
247 for _, tcs := range filterTc {
248 for i, tc := range tcs {
249 parentArg := fmt.Sprintf("parent %d:%d", parent, index+4)
250 if i > 0 {
251 parentArg = fmt.Sprintf("parent %d:", currentHandler)
252 }
253
254 currentHandler++
255 handleArg := fmt.Sprintf("handle %d:", currentHandler)
256
257 err := tcCli.addTc(device, parentArg, handleArg, tc)
258 if err != nil {
259 log.Error(err, "error while adding tc")
260 return err
261 }
262 }
263
264 ch := &pb.Chain{
265 Name: fmt.Sprintf("TC-TABLES-%d", index),
266 Direction: pb.Chain_OUTPUT,
267 Target: fmt.Sprintf("CLASSIFY --set-class %d:%d", parent, index+4),
268 Device: device,
269 }
270
271 tc := tcs[0]
272 if len(tc.Ipset) > 0 {
273 ch.Ipsets = []string{tc.Ipset}
274 }
275
276 ch.Protocol = tc.Protocol
277 ch.SourcePorts = tc.SourcePort
278 ch.DestinationPorts = tc.EgressPort
279
280 chains = append(chains, ch)
281
282 index++
283 }
284 if err := iptablesCli.setIptablesChains(chains); err != nil {
285 log.Error(err, "error while setting iptables")
286 return err
287 }
288
289 return nil
290 }
291
// tcClient carries what its methods need to run `tc` commands: a context, a
// logger, and the network-namespace selection.
type tcClient struct {
	ctx     context.Context // passed to every process builder the methods create
	log     logr.Logger     // receives the informational messages of the methods
	enterNS bool            // when true, commands are built with SetNS(pid, bpm.NetNS)
	pid     uint32          // pid handed to SetNS when enterNS is true
}
298
299 func buildTcClient(ctx context.Context, log logr.Logger, enterNS bool, pid uint32) tcClient {
300 return tcClient{
301 ctx,
302 log,
303 enterNS,
304 pid,
305 }
306 }
307
308 func (c *tcClient) flush(device string) error {
309 processBuilder := bpm.DefaultProcessBuilder("tc", "qdisc", "del", "dev", device, "root").SetContext(c.ctx)
310 if c.enterNS {
311 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
312 }
313 cmd := processBuilder.Build(c.ctx)
314 output, err := cmd.CombinedOutput()
315 if err != nil {
316 if (!strings.Contains(string(output), ruleNotExistLowerVersion)) && (!strings.Contains(string(output), ruleNotExist)) {
317 return util.EncodeOutputToError(output, err)
318 }
319 }
320 return nil
321 }
322
323 func (c *tcClient) addTc(device string, parentArg string, handleArg string, tc *pb.Tc) error {
324 c.log.Info("add tc", "tc", tc)
325
326 if tc.Type == pb.Tc_BANDWIDTH {
327
328 if tc.Tbf == nil {
329 return errors.New("tbf is nil while type is BANDWIDTH")
330 }
331 err := c.addTbf(device, parentArg, handleArg, tc.Tbf)
332 if err != nil {
333 return err
334 }
335
336 } else if tc.Type == pb.Tc_NETEM {
337
338 if tc.Netem == nil {
339 return errors.New("netem is nil while type is NETEM")
340 }
341 err := c.addNetem(device, parentArg, handleArg, tc.Netem)
342 if err != nil {
343 return err
344 }
345
346 } else {
347 return errors.New("unknown tc qdisc type")
348 }
349
350 return nil
351 }
352
353 func (c *tcClient) addPrio(device string, parent int, band int) error {
354 c.log.Info("adding prio", "parent", parent)
355
356 parentArg := "root"
357 if parent > 0 {
358 parentArg = fmt.Sprintf("parent %d:", parent)
359 }
360 args := fmt.Sprintf("qdisc add dev %s %s handle %d: prio bands %d priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1", device, parentArg, parent+1, band)
361
362 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
363 if c.enterNS {
364 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
365 }
366 cmd := processBuilder.Build(c.ctx)
367 output, err := cmd.CombinedOutput()
368 if err != nil {
369 return util.EncodeOutputToError(output, err)
370 }
371
372 for index := 1; index <= 3; index++ {
373 args := fmt.Sprintf("qdisc add dev %s parent %d:%d handle %d: sfq", device, parent+1, index, parent+1+index)
374
375 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
376 if c.enterNS {
377 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
378 }
379 cmd := processBuilder.Build(c.ctx)
380 output, err := cmd.CombinedOutput()
381 if err != nil {
382 return util.EncodeOutputToError(output, err)
383 }
384 }
385
386 return nil
387 }
388
389 func (c *tcClient) addNetem(device string, parent string, handle string, netem *pb.Netem) error {
390 c.log.Info("adding netem", "device", device, "parent", parent, "handle", handle)
391
392 args := fmt.Sprintf("qdisc add dev %s %s %s netem %s", device, parent, handle, convertNetemToArgs(netem))
393 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
394 if c.enterNS {
395 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
396 }
397 cmd := processBuilder.Build(c.ctx)
398 output, err := cmd.CombinedOutput()
399 if err != nil {
400 return util.EncodeOutputToError(output, err)
401 }
402 return nil
403 }
404
405 func (c *tcClient) addTbf(device string, parent string, handle string, tbf *pb.Tbf) error {
406 c.log.Info("adding tbf", "device", device, "parent", parent, "handle", handle)
407
408 args := fmt.Sprintf("qdisc add dev %s %s %s tbf %s", device, parent, handle, convertTbfToArgs(tbf))
409 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
410 if c.enterNS {
411 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
412 }
413 cmd := processBuilder.Build(c.ctx)
414 output, err := cmd.CombinedOutput()
415 if err != nil {
416 return util.EncodeOutputToError(output, err)
417 }
418 return nil
419 }
420
421 func convertNetemToArgs(netem *pb.Netem) string {
422 args := ""
423 if netem.Time > 0 {
424 args = fmt.Sprintf("delay %d", netem.Time)
425 if netem.Jitter > 0 {
426 args = fmt.Sprintf("%s %d", args, netem.Jitter)
427
428 if netem.DelayCorr > 0 {
429 args = fmt.Sprintf("%s %f", args, netem.DelayCorr)
430 }
431 }
432
433
434 if netem.Reorder > 0 {
435 args = fmt.Sprintf("%s reorder %f", args, netem.Reorder)
436 if netem.ReorderCorr > 0 {
437 args = fmt.Sprintf("%s %f", args, netem.ReorderCorr)
438 }
439
440 if netem.Gap > 0 {
441 args = fmt.Sprintf("%s gap %d", args, netem.Gap)
442 }
443 }
444 }
445
446 if netem.Limit > 0 {
447 args = fmt.Sprintf("%s limit %d", args, netem.Limit)
448 }
449
450 if netem.Loss > 0 {
451 args = fmt.Sprintf("%s loss %f", args, netem.Loss)
452 if netem.LossCorr > 0 {
453 args = fmt.Sprintf("%s %f", args, netem.LossCorr)
454 }
455 }
456
457 if netem.Duplicate > 0 {
458 args = fmt.Sprintf("%s duplicate %f", args, netem.Duplicate)
459 if netem.DuplicateCorr > 0 {
460 args = fmt.Sprintf("%s %f", args, netem.DuplicateCorr)
461 }
462 }
463
464 if netem.Corrupt > 0 {
465 args = fmt.Sprintf("%s corrupt %f", args, netem.Corrupt)
466 if netem.CorruptCorr > 0 {
467 args = fmt.Sprintf("%s %f", args, netem.CorruptCorr)
468 }
469 }
470
471 if len(netem.Rate) > 0 {
472 args = fmt.Sprintf("%s rate %s", args, netem.Rate)
473 }
474
475 trimedArgs := []string{}
476
477 for _, part := range strings.Split(args, " ") {
478 if len(part) > 0 {
479 trimedArgs = append(trimedArgs, part)
480 }
481 }
482
483 return strings.Join(trimedArgs, " ")
484 }
485
486 func convertTbfToArgs(tbf *pb.Tbf) string {
487 args := fmt.Sprintf("rate %s burst %d", tbf.Rate, tbf.Buffer)
488 if tbf.Limit > 0 {
489 args = fmt.Sprintf("%s limit %d", args, tbf.Limit)
490 }
491 if tbf.PeakRate > 0 {
492 args = fmt.Sprintf("%s peakrate %d mtu %d", args, tbf.PeakRate, tbf.MinBurst)
493 }
494
495 return args
496 }
497
498 func abstractTcFilter(tc *pb.Tc) string {
499 filter := tc.Ipset
500
501 if len(tc.Protocol) > 0 {
502 filter += "-" + tc.Protocol
503 }
504
505 if len(tc.EgressPort) > 0 {
506 filter += "-" + tc.EgressPort
507 }
508
509 if len(tc.SourcePort) > 0 {
510 filter += "-" + tc.EgressPort
511 }
512
513 return filter
514 }
515