1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package chaosdaemon
17
18 import (
19 "context"
20 "encoding/json"
21 "fmt"
22 "strings"
23
24 "github.com/go-logr/logr"
25
26 "google.golang.org/grpc/codes"
27 "google.golang.org/grpc/status"
28
29 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
30 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
31
32 "github.com/golang/protobuf/ptypes/empty"
33 )
34
// Fragments of tc output that flush looks for when `tc qdisc del` fails: a
// failure whose output contains either of them is not reported as an error.
const (
	ruleNotExistLowerVersion = "RTNETLINK answers: No such file or directory"
	ruleNotExist             = "Cannot delete qdisc with handle of zero."
)

// defaultDevice is the device assigned to a tc rule that does not name one.
const defaultDevice = "eth0"
41
42 func generateQdiscArgs(action string, qdisc *pb.Qdisc) ([]string, error) {
43 if qdisc == nil {
44 return nil, fmt.Errorf("qdisc is required")
45 }
46
47 if qdisc.Type == "" {
48 return nil, fmt.Errorf("qdisc.Type is required")
49 }
50
51 args := []string{"qdisc", action, "dev", "eth0"}
52
53 if qdisc.Parent == nil {
54 args = append(args, "root")
55 } else if qdisc.Parent.Major == 1 && qdisc.Parent.Minor == 0 {
56 args = append(args, "root")
57 } else {
58 args = append(args, "parent", fmt.Sprintf("%d:%d", qdisc.Parent.Major, qdisc.Parent.Minor))
59 }
60
61 if qdisc.Handle == nil {
62 args = append(args, "handle", fmt.Sprintf("%d:%d", 1, 0))
63 } else {
64 args = append(args, "handle", fmt.Sprintf("%d:%d", qdisc.Handle.Major, qdisc.Handle.Minor))
65 }
66
67 args = append(args, qdisc.Type)
68
69 if qdisc.Args != nil {
70 args = append(args, qdisc.Args...)
71 }
72
73 return args, nil
74 }
75
76 func getAllInterfaces(ctx context.Context, log logr.Logger, pid uint32) ([]string, error) {
77 ipOutput, err := bpm.DefaultProcessBuilder("ip", "-j", "addr", "show").SetNS(pid, bpm.NetNS).Build(ctx).CombinedOutput()
78 if err != nil {
79 return []string{}, err
80 }
81 var data []map[string]interface{}
82
83 err = json.Unmarshal(ipOutput, &data)
84 if err != nil {
85 return []string{}, err
86 }
87
88 var ifaces []string
89 for _, iface := range data {
90 name, ok := iface["ifname"]
91 if !ok {
92 return []string{}, fmt.Errorf("fail to read ifname from ip -j addr show")
93 }
94
95 ifaces = append(ifaces, name.(string))
96 }
97
98 log.Info("get interfaces from ip command", "ifaces", ifaces)
99 return ifaces, nil
100 }
101
102 func (s *DaemonServer) SetTcs(ctx context.Context, in *pb.TcsRequest) (*empty.Empty, error) {
103 log := s.getLoggerFromContext(ctx)
104 log.Info("handling tc request", "tcs", in)
105
106 pid, err := s.crClient.GetPidFromContainerID(ctx, in.ContainerId)
107 if err != nil {
108 return nil, status.Errorf(codes.Internal, "get pid from containerID error: %v", err)
109 }
110
111 tcCli := buildTcClient(ctx, log, in.EnterNS, pid)
112
113 ifaces, err := getAllInterfaces(ctx, log, pid)
114 if err != nil {
115 log.Error(err, "error while getting interfaces")
116 return nil, err
117 }
118 for _, iface := range ifaces {
119 err = tcCli.flush(iface)
120 if err != nil {
121 log.Error(err, "fail to flush tc rules on device", "device", iface)
122 }
123 }
124 if err != nil {
125 return &empty.Empty{}, err
126 }
127
128 for device, rules := range s.groupRulesAccordingToDevices(in.Tcs) {
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153 globalTc := []*pb.Tc{}
154 filterTc := make(map[string][]*pb.Tc)
155
156 for _, tc := range rules {
157 filter := abstractTcFilter(tc)
158 if len(filter) > 0 {
159 filterTc[filter] = append(filterTc[filter], tc)
160 continue
161 }
162 globalTc = append(globalTc, tc)
163 }
164
165 if len(globalTc) > 0 {
166 if err := s.setGlobalTcs(tcCli, globalTc, device); err != nil {
167 log.Error(err, "error while setting global tc")
168 return &empty.Empty{}, err
169 }
170 }
171
172 if len(filterTc) > 0 {
173 iptablesCli := buildIptablesClient(ctx, in.EnterNS, pid)
174 if err := s.setFilterTcs(tcCli, iptablesCli, filterTc, device, len(globalTc)); err != nil {
175 log.Error(err, "error while setting filter tc")
176 return &empty.Empty{}, err
177 }
178 }
179 }
180
181 return &empty.Empty{}, nil
182 }
183
184 func (s *DaemonServer) groupRulesAccordingToDevices(tcs []*pb.Tc) map[string][]*pb.Tc {
185 rules := make(map[string][]*pb.Tc)
186 for _, tc := range tcs {
187 if tc.Device == "" {
188 tc.Device = defaultDevice
189 }
190 rules[tc.Device] = append(rules[tc.Device], tc)
191 }
192 return rules
193 }
194
195 func (s *DaemonServer) setGlobalTcs(cli tcClient, tcs []*pb.Tc, device string) error {
196 for index, tc := range tcs {
197 parentArg := "root"
198 if index > 0 {
199 parentArg = fmt.Sprintf("parent %d:", index)
200 }
201
202 handleArg := fmt.Sprintf("handle %d:", index+1)
203
204 err := cli.addTc(device, parentArg, handleArg, tc)
205 if err != nil {
206 s.rootLogger.Error(err, "error while adding tc")
207 return err
208 }
209 }
210
211 return nil
212 }
213
214 func (s *DaemonServer) setFilterTcs(
215 tcCli tcClient,
216 iptablesCli iptablesClient,
217 filterTc map[string][]*pb.Tc,
218 device string,
219 baseIndex int,
220 ) error {
221 parent := baseIndex
222 band := 3 + len(filterTc)
223 if err := tcCli.addPrio(device, parent, band); err != nil {
224 s.rootLogger.Error(err, "error while adding prio")
225 return err
226 }
227
228 parent++
229 index := 0
230 currentHandler := parent + 3
231
232
233
234
235 chains := []*pb.Chain{}
236 for _, tcs := range filterTc {
237 for i, tc := range tcs {
238 parentArg := fmt.Sprintf("parent %d:%d", parent, index+4)
239 if i > 0 {
240 parentArg = fmt.Sprintf("parent %d:", currentHandler)
241 }
242
243 currentHandler++
244 handleArg := fmt.Sprintf("handle %d:", currentHandler)
245
246 err := tcCli.addTc(device, parentArg, handleArg, tc)
247 if err != nil {
248 s.rootLogger.Error(err, "error while adding tc")
249 return err
250 }
251 }
252
253 ch := &pb.Chain{
254 Name: fmt.Sprintf("TC-TABLES-%d", index),
255 Direction: pb.Chain_OUTPUT,
256 Target: fmt.Sprintf("CLASSIFY --set-class %d:%d", parent, index+4),
257 Device: device,
258 }
259
260 tc := tcs[0]
261 if len(tc.Ipset) > 0 {
262 ch.Ipsets = []string{tc.Ipset}
263 }
264
265 ch.Protocol = tc.Protocol
266 ch.SourcePorts = tc.SourcePort
267 ch.DestinationPorts = tc.EgressPort
268
269 chains = append(chains, ch)
270
271 index++
272 }
273 if err := iptablesCli.setIptablesChains(chains); err != nil {
274 s.rootLogger.Error(err, "error while setting iptables")
275 return err
276 }
277
278 return nil
279 }
280
281 type tcClient struct {
282 ctx context.Context
283 log logr.Logger
284 enterNS bool
285 pid uint32
286 }
287
288 func buildTcClient(ctx context.Context, log logr.Logger, enterNS bool, pid uint32) tcClient {
289 return tcClient{
290 ctx,
291 log,
292 enterNS,
293 pid,
294 }
295 }
296
297 func (c *tcClient) flush(device string) error {
298 processBuilder := bpm.DefaultProcessBuilder("tc", "qdisc", "del", "dev", device, "root").SetContext(c.ctx)
299 if c.enterNS {
300 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
301 }
302 cmd := processBuilder.Build(c.ctx)
303 output, err := cmd.CombinedOutput()
304 if err != nil {
305 if (!strings.Contains(string(output), ruleNotExistLowerVersion)) && (!strings.Contains(string(output), ruleNotExist)) {
306 return encodeOutputToError(output, err)
307 }
308 }
309 return nil
310 }
311
312 func (c *tcClient) addTc(device string, parentArg string, handleArg string, tc *pb.Tc) error {
313 c.log.Info("add tc", "tc", tc)
314
315 if tc.Type == pb.Tc_BANDWIDTH {
316
317 if tc.Tbf == nil {
318 return fmt.Errorf("tbf is nil while type is BANDWIDTH")
319 }
320 err := c.addTbf(device, parentArg, handleArg, tc.Tbf)
321 if err != nil {
322 return err
323 }
324
325 } else if tc.Type == pb.Tc_NETEM {
326
327 if tc.Netem == nil {
328 return fmt.Errorf("netem is nil while type is NETEM")
329 }
330 err := c.addNetem(device, parentArg, handleArg, tc.Netem)
331 if err != nil {
332 return err
333 }
334
335 } else {
336 return fmt.Errorf("unknown tc qdisc type")
337 }
338
339 return nil
340 }
341
342 func (c *tcClient) addPrio(device string, parent int, band int) error {
343 c.log.Info("adding prio", "parent", parent)
344
345 parentArg := "root"
346 if parent > 0 {
347 parentArg = fmt.Sprintf("parent %d:", parent)
348 }
349 args := fmt.Sprintf("qdisc add dev %s %s handle %d: prio bands %d priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1", device, parentArg, parent+1, band)
350
351 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
352 if c.enterNS {
353 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
354 }
355 cmd := processBuilder.Build(c.ctx)
356 output, err := cmd.CombinedOutput()
357 if err != nil {
358 return encodeOutputToError(output, err)
359 }
360
361 for index := 1; index <= 3; index++ {
362 args := fmt.Sprintf("qdisc add dev %s parent %d:%d handle %d: sfq", device, parent+1, index, parent+1+index)
363
364 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
365 if c.enterNS {
366 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
367 }
368 cmd := processBuilder.Build(c.ctx)
369 output, err := cmd.CombinedOutput()
370 if err != nil {
371 return encodeOutputToError(output, err)
372 }
373 }
374
375 return nil
376 }
377
378 func (c *tcClient) addNetem(device string, parent string, handle string, netem *pb.Netem) error {
379 c.log.Info("adding netem", "device", device, "parent", parent, "handle", handle)
380
381 args := fmt.Sprintf("qdisc add dev %s %s %s netem %s", device, parent, handle, convertNetemToArgs(netem))
382 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
383 if c.enterNS {
384 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
385 }
386 cmd := processBuilder.Build(c.ctx)
387 output, err := cmd.CombinedOutput()
388 if err != nil {
389 return encodeOutputToError(output, err)
390 }
391 return nil
392 }
393
394 func (c *tcClient) addTbf(device string, parent string, handle string, tbf *pb.Tbf) error {
395 c.log.Info("adding tbf", "device", device, "parent", parent, "handle", handle)
396
397 args := fmt.Sprintf("qdisc add dev %s %s %s tbf %s", device, parent, handle, convertTbfToArgs(tbf))
398 processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
399 if c.enterNS {
400 processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
401 }
402 cmd := processBuilder.Build(c.ctx)
403 output, err := cmd.CombinedOutput()
404 if err != nil {
405 return encodeOutputToError(output, err)
406 }
407 return nil
408 }
409
410 func convertNetemToArgs(netem *pb.Netem) string {
411 args := ""
412 if netem.Time > 0 {
413 args = fmt.Sprintf("delay %d", netem.Time)
414 if netem.Jitter > 0 {
415 args = fmt.Sprintf("%s %d", args, netem.Jitter)
416
417 if netem.DelayCorr > 0 {
418 args = fmt.Sprintf("%s %f", args, netem.DelayCorr)
419 }
420 }
421
422
423 if netem.Reorder > 0 {
424 args = fmt.Sprintf("%s reorder %f", args, netem.Reorder)
425 if netem.ReorderCorr > 0 {
426 args = fmt.Sprintf("%s %f", args, netem.ReorderCorr)
427 }
428
429 if netem.Gap > 0 {
430 args = fmt.Sprintf("%s gap %d", args, netem.Gap)
431 }
432 }
433 }
434
435 if netem.Limit > 0 {
436 args = fmt.Sprintf("%s limit %d", args, netem.Limit)
437 }
438
439 if netem.Loss > 0 {
440 args = fmt.Sprintf("%s loss %f", args, netem.Loss)
441 if netem.LossCorr > 0 {
442 args = fmt.Sprintf("%s %f", args, netem.LossCorr)
443 }
444 }
445
446 if netem.Duplicate > 0 {
447 args = fmt.Sprintf("%s duplicate %f", args, netem.Duplicate)
448 if netem.DuplicateCorr > 0 {
449 args = fmt.Sprintf("%s %f", args, netem.DuplicateCorr)
450 }
451 }
452
453 if netem.Corrupt > 0 {
454 args = fmt.Sprintf("%s corrupt %f", args, netem.Corrupt)
455 if netem.CorruptCorr > 0 {
456 args = fmt.Sprintf("%s %f", args, netem.CorruptCorr)
457 }
458 }
459
460 trimedArgs := []string{}
461
462 for _, part := range strings.Split(args, " ") {
463 if len(part) > 0 {
464 trimedArgs = append(trimedArgs, part)
465 }
466 }
467
468 return strings.Join(trimedArgs, " ")
469 }
470
471 func convertTbfToArgs(tbf *pb.Tbf) string {
472 args := fmt.Sprintf("rate %d burst %d", tbf.Rate, tbf.Buffer)
473 if tbf.Limit > 0 {
474 args = fmt.Sprintf("%s limit %d", args, tbf.Limit)
475 }
476 if tbf.PeakRate > 0 {
477 args = fmt.Sprintf("%s peakrate %d mtu %d", args, tbf.PeakRate, tbf.MinBurst)
478 }
479
480 return args
481 }
482
483 func abstractTcFilter(tc *pb.Tc) string {
484 filter := tc.Ipset
485
486 if len(tc.Protocol) > 0 {
487 filter += "-" + tc.Protocol
488 }
489
490 if len(tc.EgressPort) > 0 {
491 filter += "-" + tc.EgressPort
492 }
493
494 if len(tc.SourcePort) > 0 {
495 filter += "-" + tc.EgressPort
496 }
497
498 return filter
499 }
500