...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/tc_server.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon

     1  // Copyright 2021 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package chaosdaemon
    17  
    18  import (
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"net"
    23  	"strings"
    24  
    25  	"github.com/go-logr/logr"
    26  	"github.com/golang/protobuf/ptypes/empty"
    27  	"github.com/pkg/errors"
    28  	"google.golang.org/grpc/codes"
    29  	"google.golang.org/grpc/status"
    30  
    31  	"github.com/chaos-mesh/chaos-mesh/pkg/bpm"
    32  	pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
    33  	"github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/util"
    34  )
    35  
    36  const (
    37  	ruleNotExist             = "Cannot delete qdisc with handle of zero."
    38  	ruleNotExistLowerVersion = "RTNETLINK answers: No such file or directory"
    39  
    40  	defaultDevice = "eth0"
    41  )
    42  
    43  func generateQdiscArgs(action string, qdisc *pb.Qdisc) ([]string, error) {
    44  	if qdisc == nil {
    45  		return nil, errors.New("qdisc is required")
    46  	}
    47  
    48  	if qdisc.Type == "" {
    49  		return nil, errors.New("qdisc.Type is required")
    50  	}
    51  
    52  	args := []string{"qdisc", action, "dev", "eth0"}
    53  
    54  	if qdisc.Parent == nil {
    55  		args = append(args, "root")
    56  	} else if qdisc.Parent.Major == 1 && qdisc.Parent.Minor == 0 {
    57  		args = append(args, "root")
    58  	} else {
    59  		args = append(args, "parent", fmt.Sprintf("%d:%d", qdisc.Parent.Major, qdisc.Parent.Minor))
    60  	}
    61  
    62  	if qdisc.Handle == nil {
    63  		args = append(args, "handle", fmt.Sprintf("%d:%d", 1, 0))
    64  	} else {
    65  		args = append(args, "handle", fmt.Sprintf("%d:%d", qdisc.Handle.Major, qdisc.Handle.Minor))
    66  	}
    67  
    68  	args = append(args, qdisc.Type)
    69  
    70  	if qdisc.Args != nil {
    71  		args = append(args, qdisc.Args...)
    72  	}
    73  
    74  	return args, nil
    75  }
    76  
    77  func getAllInterfaces(ctx context.Context, log logr.Logger, pid uint32, enterNS bool) ([]string, error) {
    78  	var ifaces []string
    79  	if enterNS {
    80  		ipOutput, err := bpm.DefaultProcessBuilder("ip", "-j", "addr", "show").SetNS(pid, bpm.NetNS).SetContext(ctx).Build(ctx).CombinedOutput()
    81  		if err != nil {
    82  			return []string{}, err
    83  		}
    84  		var data []map[string]interface{}
    85  
    86  		err = json.Unmarshal(ipOutput, &data)
    87  		if err != nil {
    88  			return []string{}, err
    89  		}
    90  		for _, iface := range data {
    91  			name, ok := iface["ifname"]
    92  			if !ok {
    93  				return []string{}, errors.New("fail to read ifname from ip -j addr show")
    94  			}
    95  			ifaces = append(ifaces, name.(string))
    96  		}
    97  		log.Info("get interfaces from ip command", "ifaces", ifaces)
    98  	} else {
    99  		interfaces, err := net.Interfaces()
   100  		if err != nil {
   101  			return []string{}, errors.New("fail to read ifname from net.Interfaces()")
   102  		}
   103  		for _, iface := range interfaces {
   104  			ifaces = append(ifaces, iface.Name)
   105  		}
   106  		log.Info("get interfaces from net.Interfaces()", "ifaces", ifaces)
   107  	}
   108  
   109  	return ifaces, nil
   110  }
   111  
   112  func (s *DaemonServer) SetTcs(ctx context.Context, in *pb.TcsRequest) (*empty.Empty, error) {
   113  	log := s.getLoggerFromContext(ctx)
   114  	log.Info("handling tc request", "tcs", in)
   115  
   116  	pid, err := s.crClient.GetPidFromContainerID(ctx, in.ContainerId)
   117  	if err != nil {
   118  		return nil, status.Errorf(codes.Internal, "get pid from containerID error: %v", err)
   119  	}
   120  
   121  	tcCli := buildTcClient(ctx, log, in.EnterNS, pid)
   122  
   123  	ifaces, err := getAllInterfaces(ctx, log, pid, in.EnterNS)
   124  	if err != nil {
   125  		log.Error(err, "error while getting interfaces")
   126  		return nil, err
   127  	}
   128  	for _, iface := range ifaces {
   129  		err = tcCli.flush(iface)
   130  		if err != nil {
   131  			log.Error(err, "fail to flush tc rules on device", "device", iface)
   132  		}
   133  	}
   134  	if err != nil {
   135  		return &empty.Empty{}, err
   136  	}
   137  
   138  	for device, rules := range s.groupRulesAccordingToDevices(in.Tcs) {
   139  		// tc rules are split into two different kinds according to whether it has filter.
   140  		// all tc rules without filter are called `globalTc` and the tc rules with filter will be called `filterTc`.
   141  		// the `globalTc` rules will be piped one by one from root, and the last `globalTc` will be connected with a PRIO
   142  		// qdisc, which has `3 + len(filterTc)` bands. Then the 4.. bands will be connected to `filterTc` and a filter will
   143  		// be setuped to flow packet from PRIO qdisc to it.
   144  
   145  		// for example, four tc rules:
   146  		// - NETEM: 50ms latency without filter
   147  		// - NETEM: 100ms latency without filter
   148  		// - NETEM: 50ms latency with filter ipset A
   149  		// - NETEM: 100ms latency with filter ipset B
   150  		// will generate tc rules:
   151  		//	tc qdisc del dev eth0 root
   152  		//  tc qdisc add dev eth0 root handle 1: netem delay 50000
   153  		//  tc qdisc add dev eth0 parent 1: handle 2: netem delay 100000
   154  		//  tc qdisc add dev eth0 parent 2: handle 3: prio bands 5 priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1
   155  		//  tc qdisc add dev eth0 parent 3:1 handle 4: sfq
   156  		//  tc qdisc add dev eth0 parent 3:2 handle 5: sfq
   157  		//  tc qdisc add dev eth0 parent 3:3 handle 6: sfq
   158  		//  tc qdisc add dev eth0 parent 3:4 handle 7: netem delay 50000
   159  		//  iptables -A TC-TABLES-0 -o eth0 -m set --match-set A dst -j CLASSIFY --set-class 3:4 -w 5
   160  		//  tc qdisc add dev eth0 parent 3:5 handle 8: netem delay 100000
   161  		//  iptables -A TC-TABLES-1 -o eth0 -m set --match-set B dst -j CLASSIFY --set-class 3:5 -w 5
   162  
   163  		globalTc := []*pb.Tc{}
   164  		filterTc := make(map[string][]*pb.Tc)
   165  
   166  		for _, tc := range rules {
   167  			filter := abstractTcFilter(tc)
   168  			if len(filter) > 0 {
   169  				filterTc[filter] = append(filterTc[filter], tc)
   170  				continue
   171  			}
   172  			globalTc = append(globalTc, tc)
   173  		}
   174  
   175  		if len(globalTc) > 0 {
   176  			if err := s.setGlobalTcs(log, tcCli, globalTc, device); err != nil {
   177  				log.Error(err, "error while setting global tc")
   178  				return &empty.Empty{}, err
   179  			}
   180  		}
   181  
   182  		if len(filterTc) > 0 {
   183  			iptablesCli := buildIptablesClient(ctx, in.EnterNS, pid)
   184  			if err := s.setFilterTcs(log, tcCli, iptablesCli, filterTc, device, len(globalTc)); err != nil {
   185  				log.Error(err, "error while setting filter tc")
   186  				return &empty.Empty{}, err
   187  			}
   188  		}
   189  	}
   190  
   191  	return &empty.Empty{}, nil
   192  }
   193  
   194  func (s *DaemonServer) groupRulesAccordingToDevices(tcs []*pb.Tc) map[string][]*pb.Tc {
   195  	rules := make(map[string][]*pb.Tc)
   196  	for _, tc := range tcs {
   197  		if tc.Device == "" {
   198  			tc.Device = defaultDevice
   199  		}
   200  		rules[tc.Device] = append(rules[tc.Device], tc)
   201  	}
   202  	return rules
   203  }
   204  
   205  func (s *DaemonServer) setGlobalTcs(log logr.Logger, cli tcClient, tcs []*pb.Tc, device string) error {
   206  	for index, tc := range tcs {
   207  		parentArg := "root"
   208  		if index > 0 {
   209  			parentArg = fmt.Sprintf("parent %d:", index)
   210  		}
   211  
   212  		handleArg := fmt.Sprintf("handle %d:", index+1)
   213  
   214  		err := cli.addTc(device, parentArg, handleArg, tc)
   215  		if err != nil {
   216  			log.Error(err, "error while adding tc")
   217  			return err
   218  		}
   219  	}
   220  
   221  	return nil
   222  }
   223  
   224  func (s *DaemonServer) setFilterTcs(
   225  	log logr.Logger,
   226  	tcCli tcClient,
   227  	iptablesCli iptablesClient,
   228  	filterTc map[string][]*pb.Tc,
   229  	device string,
   230  	baseIndex int,
   231  ) error {
   232  	parent := baseIndex
   233  	band := 3 + len(filterTc) // 3 handlers for normal sfq on prio qdisc
   234  	if err := tcCli.addPrio(device, parent, band); err != nil {
   235  		log.Error(err, "error while adding prio")
   236  		return err
   237  	}
   238  
   239  	parent++
   240  	index := 0
   241  	currentHandler := parent + 3 // 3 handlers for sfq on prio qdisc
   242  
   243  	// iptables chain has been initialized by previous grpc request to set iptables
   244  	// and iptables rules are recovered by previous call too, so there is no need
   245  	// to remove these rules here
   246  	chains := []*pb.Chain{}
   247  	for _, tcs := range filterTc {
   248  		for i, tc := range tcs {
   249  			parentArg := fmt.Sprintf("parent %d:%d", parent, index+4)
   250  			if i > 0 {
   251  				parentArg = fmt.Sprintf("parent %d:", currentHandler)
   252  			}
   253  
   254  			currentHandler++
   255  			handleArg := fmt.Sprintf("handle %d:", currentHandler)
   256  
   257  			err := tcCli.addTc(device, parentArg, handleArg, tc)
   258  			if err != nil {
   259  				log.Error(err, "error while adding tc")
   260  				return err
   261  			}
   262  		}
   263  
   264  		ch := &pb.Chain{
   265  			Name:      fmt.Sprintf("TC-TABLES-%d", index),
   266  			Direction: pb.Chain_OUTPUT,
   267  			Target:    fmt.Sprintf("CLASSIFY --set-class %d:%d", parent, index+4),
   268  			Device:    device,
   269  		}
   270  
   271  		tc := tcs[0]
   272  		if len(tc.Ipset) > 0 {
   273  			ch.Ipsets = []string{tc.Ipset}
   274  		}
   275  
   276  		ch.Protocol = tc.Protocol
   277  		ch.SourcePorts = tc.SourcePort
   278  		ch.DestinationPorts = tc.EgressPort
   279  
   280  		chains = append(chains, ch)
   281  
   282  		index++
   283  	}
   284  	if err := iptablesCli.setIptablesChains(chains); err != nil {
   285  		log.Error(err, "error while setting iptables")
   286  		return err
   287  	}
   288  
   289  	return nil
   290  }
   291  
   292  type tcClient struct {
   293  	ctx     context.Context
   294  	log     logr.Logger
   295  	enterNS bool
   296  	pid     uint32
   297  }
   298  
   299  func buildTcClient(ctx context.Context, log logr.Logger, enterNS bool, pid uint32) tcClient {
   300  	return tcClient{
   301  		ctx,
   302  		log,
   303  		enterNS,
   304  		pid,
   305  	}
   306  }
   307  
   308  func (c *tcClient) flush(device string) error {
   309  	processBuilder := bpm.DefaultProcessBuilder("tc", "qdisc", "del", "dev", device, "root").SetContext(c.ctx)
   310  	if c.enterNS {
   311  		processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
   312  	}
   313  	cmd := processBuilder.Build(c.ctx)
   314  	output, err := cmd.CombinedOutput()
   315  	if err != nil {
   316  		if (!strings.Contains(string(output), ruleNotExistLowerVersion)) && (!strings.Contains(string(output), ruleNotExist)) {
   317  			return util.EncodeOutputToError(output, err)
   318  		}
   319  	}
   320  	return nil
   321  }
   322  
   323  func (c *tcClient) addTc(device string, parentArg string, handleArg string, tc *pb.Tc) error {
   324  	c.log.Info("add tc", "tc", tc)
   325  
   326  	if tc.Type == pb.Tc_BANDWIDTH {
   327  
   328  		if tc.Tbf == nil {
   329  			return errors.New("tbf is nil while type is BANDWIDTH")
   330  		}
   331  		err := c.addTbf(device, parentArg, handleArg, tc.Tbf)
   332  		if err != nil {
   333  			return err
   334  		}
   335  
   336  	} else if tc.Type == pb.Tc_NETEM {
   337  
   338  		if tc.Netem == nil {
   339  			return errors.New("netem is nil while type is NETEM")
   340  		}
   341  		err := c.addNetem(device, parentArg, handleArg, tc.Netem)
   342  		if err != nil {
   343  			return err
   344  		}
   345  
   346  	} else {
   347  		return errors.New("unknown tc qdisc type")
   348  	}
   349  
   350  	return nil
   351  }
   352  
   353  func (c *tcClient) addPrio(device string, parent int, band int) error {
   354  	c.log.Info("adding prio", "parent", parent)
   355  
   356  	parentArg := "root"
   357  	if parent > 0 {
   358  		parentArg = fmt.Sprintf("parent %d:", parent)
   359  	}
   360  	args := fmt.Sprintf("qdisc add dev %s %s handle %d: prio bands %d priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1", device, parentArg, parent+1, band)
   361  
   362  	processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
   363  	if c.enterNS {
   364  		processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
   365  	}
   366  	cmd := processBuilder.Build(c.ctx)
   367  	output, err := cmd.CombinedOutput()
   368  	if err != nil {
   369  		return util.EncodeOutputToError(output, err)
   370  	}
   371  
   372  	for index := 1; index <= 3; index++ {
   373  		args := fmt.Sprintf("qdisc add dev %s parent %d:%d handle %d: sfq", device, parent+1, index, parent+1+index)
   374  
   375  		processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
   376  		if c.enterNS {
   377  			processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
   378  		}
   379  		cmd := processBuilder.Build(c.ctx)
   380  		output, err := cmd.CombinedOutput()
   381  		if err != nil {
   382  			return util.EncodeOutputToError(output, err)
   383  		}
   384  	}
   385  
   386  	return nil
   387  }
   388  
   389  func (c *tcClient) addNetem(device string, parent string, handle string, netem *pb.Netem) error {
   390  	c.log.Info("adding netem", "device", device, "parent", parent, "handle", handle)
   391  
   392  	args := fmt.Sprintf("qdisc add dev %s %s %s netem %s", device, parent, handle, convertNetemToArgs(netem))
   393  	processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
   394  	if c.enterNS {
   395  		processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
   396  	}
   397  	cmd := processBuilder.Build(c.ctx)
   398  	output, err := cmd.CombinedOutput()
   399  	if err != nil {
   400  		return util.EncodeOutputToError(output, err)
   401  	}
   402  	return nil
   403  }
   404  
   405  func (c *tcClient) addTbf(device string, parent string, handle string, tbf *pb.Tbf) error {
   406  	c.log.Info("adding tbf", "device", device, "parent", parent, "handle", handle)
   407  
   408  	args := fmt.Sprintf("qdisc add dev %s %s %s tbf %s", device, parent, handle, convertTbfToArgs(tbf))
   409  	processBuilder := bpm.DefaultProcessBuilder("tc", strings.Split(args, " ")...).SetContext(c.ctx)
   410  	if c.enterNS {
   411  		processBuilder = processBuilder.SetNS(c.pid, bpm.NetNS)
   412  	}
   413  	cmd := processBuilder.Build(c.ctx)
   414  	output, err := cmd.CombinedOutput()
   415  	if err != nil {
   416  		return util.EncodeOutputToError(output, err)
   417  	}
   418  	return nil
   419  }
   420  
   421  func convertNetemToArgs(netem *pb.Netem) string {
   422  	args := ""
   423  	if netem.Time > 0 {
   424  		args = fmt.Sprintf("delay %d", netem.Time)
   425  		if netem.Jitter > 0 {
   426  			args = fmt.Sprintf("%s %d", args, netem.Jitter)
   427  
   428  			if netem.DelayCorr > 0 {
   429  				args = fmt.Sprintf("%s %f", args, netem.DelayCorr)
   430  			}
   431  		}
   432  
   433  		// reordering not possible without specifying some delay
   434  		if netem.Reorder > 0 {
   435  			args = fmt.Sprintf("%s reorder %f", args, netem.Reorder)
   436  			if netem.ReorderCorr > 0 {
   437  				args = fmt.Sprintf("%s %f", args, netem.ReorderCorr)
   438  			}
   439  
   440  			if netem.Gap > 0 {
   441  				args = fmt.Sprintf("%s gap %d", args, netem.Gap)
   442  			}
   443  		}
   444  	}
   445  
   446  	if netem.Limit > 0 {
   447  		args = fmt.Sprintf("%s limit %d", args, netem.Limit)
   448  	}
   449  
   450  	if netem.Loss > 0 {
   451  		args = fmt.Sprintf("%s loss %f", args, netem.Loss)
   452  		if netem.LossCorr > 0 {
   453  			args = fmt.Sprintf("%s %f", args, netem.LossCorr)
   454  		}
   455  	}
   456  
   457  	if netem.Duplicate > 0 {
   458  		args = fmt.Sprintf("%s duplicate %f", args, netem.Duplicate)
   459  		if netem.DuplicateCorr > 0 {
   460  			args = fmt.Sprintf("%s %f", args, netem.DuplicateCorr)
   461  		}
   462  	}
   463  
   464  	if netem.Corrupt > 0 {
   465  		args = fmt.Sprintf("%s corrupt %f", args, netem.Corrupt)
   466  		if netem.CorruptCorr > 0 {
   467  			args = fmt.Sprintf("%s %f", args, netem.CorruptCorr)
   468  		}
   469  	}
   470  
   471  	if len(netem.Rate) > 0 {
   472  		args = fmt.Sprintf("%s rate %s", args, netem.Rate)
   473  	}
   474  
   475  	trimedArgs := []string{}
   476  
   477  	for _, part := range strings.Split(args, " ") {
   478  		if len(part) > 0 {
   479  			trimedArgs = append(trimedArgs, part)
   480  		}
   481  	}
   482  
   483  	return strings.Join(trimedArgs, " ")
   484  }
   485  
   486  func convertTbfToArgs(tbf *pb.Tbf) string {
   487  	args := fmt.Sprintf("rate %s burst %d", tbf.Rate, tbf.Buffer)
   488  	if tbf.Limit > 0 {
   489  		args = fmt.Sprintf("%s limit %d", args, tbf.Limit)
   490  	}
   491  	if tbf.PeakRate > 0 {
   492  		args = fmt.Sprintf("%s peakrate %d mtu %d", args, tbf.PeakRate, tbf.MinBurst)
   493  	}
   494  
   495  	return args
   496  }
   497  
   498  func abstractTcFilter(tc *pb.Tc) string {
   499  	filter := tc.Ipset
   500  
   501  	if len(tc.Protocol) > 0 {
   502  		filter += "-" + tc.Protocol
   503  	}
   504  
   505  	if len(tc.EgressPort) > 0 {
   506  		filter += "-" + tc.EgressPort
   507  	}
   508  
   509  	if len(tc.SourcePort) > 0 {
   510  		filter += "-" + tc.EgressPort
   511  	}
   512  
   513  	return filter
   514  }
   515