...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/iptables_server.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon

     1  // Copyright 2021 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package chaosdaemon
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"github.com/golang/protobuf/ptypes/empty"
    24  	"github.com/pkg/errors"
    25  
    26  	"github.com/chaos-mesh/chaos-mesh/pkg/bpm"
    27  	pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
    28  	"github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/util"
    29  )
    30  
    31  const (
    32  	iptablesCmd = "iptables"
    33  
    34  	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
    35  )
    36  
    37  func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
    38  	log := s.getLoggerFromContext(ctx)
    39  	log.Info("Set iptables chains", "request", req)
    40  
    41  	pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
    42  	if err != nil {
    43  		log.Error(err, "error while getting PID")
    44  		return nil, err
    45  	}
    46  
    47  	iptables := buildIptablesClient(ctx, req.EnterNS, pid)
    48  	err = iptables.initializeEnv()
    49  	if err != nil {
    50  		log.Error(err, "error while initializing iptables")
    51  		return nil, err
    52  	}
    53  
    54  	err = iptables.setIptablesChains(req.Chains)
    55  	if err != nil {
    56  		log.Error(err, "error while setting iptables chains")
    57  		return nil, err
    58  	}
    59  
    60  	return &empty.Empty{}, nil
    61  }
    62  
    63  type iptablesClient struct {
    64  	ctx     context.Context
    65  	enterNS bool
    66  	pid     uint32
    67  }
    68  
    69  type iptablesChain struct {
    70  	Name  string
    71  	Rules []string
    72  }
    73  
    74  func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
    75  	return iptablesClient{
    76  		ctx,
    77  		enterNS,
    78  		pid,
    79  	}
    80  }
    81  
    82  func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
    83  	for _, chain := range chains {
    84  		err := iptables.setIptablesChain(chain)
    85  		if err != nil {
    86  			return err
    87  		}
    88  	}
    89  
    90  	return nil
    91  }
    92  
    93  func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
    94  	var matchPart string
    95  	var interfaceMatcher string
    96  	if chain.Direction == pb.Chain_INPUT {
    97  		matchPart = "src,dst"
    98  		interfaceMatcher = "-i"
    99  	} else if chain.Direction == pb.Chain_OUTPUT {
   100  		matchPart = "dst,dst"
   101  		interfaceMatcher = "-o"
   102  	} else {
   103  		return errors.Errorf("unknown chain direction %d", chain.Direction)
   104  	}
   105  
   106  	if chain.Device == "" {
   107  		chain.Device = defaultDevice
   108  	}
   109  
   110  	protocolAndPort := ""
   111  	if len(chain.Protocol) > 0 {
   112  		protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
   113  
   114  		if len(chain.SourcePorts) > 0 {
   115  			if strings.Contains(chain.SourcePorts, ",") {
   116  				protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
   117  			} else {
   118  				protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
   119  			}
   120  		}
   121  
   122  		if len(chain.DestinationPorts) > 0 {
   123  			if strings.Contains(chain.DestinationPorts, ",") {
   124  				protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
   125  			} else {
   126  				protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
   127  			}
   128  		}
   129  
   130  		if len(chain.TcpFlags) > 0 {
   131  			protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
   132  		}
   133  	}
   134  
   135  	rules := []string{}
   136  
   137  	if len(chain.Ipsets) == 0 {
   138  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s", chain.Name, interfaceMatcher, chain.Device, chain.Target, protocolAndPort)))
   139  	}
   140  
   141  	for _, ipset := range chain.Ipsets {
   142  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
   143  			chain.Name, interfaceMatcher, chain.Device, ipset, matchPart, chain.Target, protocolAndPort)))
   144  	}
   145  	err := iptables.createNewChain(&iptablesChain{
   146  		Name:  chain.Name,
   147  		Rules: rules,
   148  	})
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	if chain.Direction == pb.Chain_INPUT {
   154  		err := iptables.ensureRule(&iptablesChain{
   155  			Name: "CHAOS-INPUT",
   156  		}, "-A CHAOS-INPUT -j "+chain.Name)
   157  		if err != nil {
   158  			return err
   159  		}
   160  	} else if chain.Direction == pb.Chain_OUTPUT {
   161  		iptables.ensureRule(&iptablesChain{
   162  			Name: "CHAOS-OUTPUT",
   163  		}, "-A CHAOS-OUTPUT -j "+chain.Name)
   164  		if err != nil {
   165  			return err
   166  		}
   167  	} else {
   168  		return errors.Errorf("unknown direction %d", chain.Direction)
   169  	}
   170  	return nil
   171  }
   172  
   173  func (iptables *iptablesClient) initializeEnv() error {
   174  	for _, direction := range []string{"INPUT", "OUTPUT"} {
   175  		chainName := "CHAOS-" + direction
   176  
   177  		err := iptables.createNewChain(&iptablesChain{
   178  			Name:  chainName,
   179  			Rules: []string{},
   180  		})
   181  		if err != nil {
   182  			return err
   183  		}
   184  
   185  		iptables.ensureRule(&iptablesChain{
   186  			Name:  direction,
   187  			Rules: []string{},
   188  		}, "-A "+direction+" -j "+chainName)
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  // createNewChain will cover existing chain
   195  func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
   196  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
   197  	if iptables.enterNS {
   198  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   199  	}
   200  	cmd := processBuilder.Build(iptables.ctx)
   201  	out, err := cmd.CombinedOutput()
   202  
   203  	if (err == nil && len(out) == 0) ||
   204  		(err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
   205  		// Successfully create a new chain
   206  		return iptables.deleteAndWriteRules(chain)
   207  	}
   208  
   209  	return util.EncodeOutputToError(out, err)
   210  }
   211  
   212  // deleteAndWriteRules will remove all existing function in the chain
   213  // and replace with the new settings
   214  func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
   215  
   216  	// This chain should already exist
   217  	err := iptables.flushIptablesChain(chain)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	for _, rule := range chain.Rules {
   223  		err := iptables.ensureRule(chain, rule)
   224  		if err != nil {
   225  			return err
   226  		}
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
   233  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
   234  	if iptables.enterNS {
   235  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   236  	}
   237  	cmd := processBuilder.Build(iptables.ctx)
   238  	out, err := cmd.CombinedOutput()
   239  	if err != nil {
   240  		return util.EncodeOutputToError(out, err)
   241  	}
   242  
   243  	if strings.Contains(string(out), rule) {
   244  		// The required rule already exist in chain
   245  		return nil
   246  	}
   247  
   248  	// TODO: lock on every container but not on chaos-daemon's `/run/xtables.lock`
   249  	processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
   250  	if iptables.enterNS {
   251  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   252  	}
   253  	cmd = processBuilder.Build(iptables.ctx)
   254  	out, err = cmd.CombinedOutput()
   255  	if err != nil {
   256  		return util.EncodeOutputToError(out, err)
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
   263  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
   264  	if iptables.enterNS {
   265  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   266  	}
   267  	cmd := processBuilder.Build(iptables.ctx)
   268  	out, err := cmd.CombinedOutput()
   269  	if err != nil {
   270  		return util.EncodeOutputToError(out, err)
   271  	}
   272  
   273  	return nil
   274  }
   275