...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/iptables_server.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon

     1  // Copyright 2021 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package chaosdaemon
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"github.com/golang/protobuf/ptypes/empty"
    24  
    25  	"github.com/chaos-mesh/chaos-mesh/pkg/bpm"
    26  	pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
    27  )
    28  
    29  const (
    30  	iptablesCmd = "iptables"
    31  
    32  	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
    33  )
    34  
    35  func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
    36  	log := s.getLoggerFromContext(ctx)
    37  	log.Info("Set iptables chains", "request", req)
    38  
    39  	pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
    40  	if err != nil {
    41  		log.Error(err, "error while getting PID")
    42  		return nil, err
    43  	}
    44  
    45  	iptables := buildIptablesClient(ctx, req.EnterNS, pid)
    46  	err = iptables.initializeEnv()
    47  	if err != nil {
    48  		log.Error(err, "error while initializing iptables")
    49  		return nil, err
    50  	}
    51  
    52  	err = iptables.setIptablesChains(req.Chains)
    53  	if err != nil {
    54  		log.Error(err, "error while setting iptables chains")
    55  		return nil, err
    56  	}
    57  
    58  	return &empty.Empty{}, nil
    59  }
    60  
    61  type iptablesClient struct {
    62  	ctx     context.Context
    63  	enterNS bool
    64  	pid     uint32
    65  }
    66  
    67  type iptablesChain struct {
    68  	Name  string
    69  	Rules []string
    70  }
    71  
    72  func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
    73  	return iptablesClient{
    74  		ctx,
    75  		enterNS,
    76  		pid,
    77  	}
    78  }
    79  
    80  func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
    81  	for _, chain := range chains {
    82  		err := iptables.setIptablesChain(chain)
    83  		if err != nil {
    84  			return err
    85  		}
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
    92  	var matchPart string
    93  	var interfaceMatcher string
    94  	if chain.Direction == pb.Chain_INPUT {
    95  		matchPart = "src"
    96  		interfaceMatcher = "-i"
    97  	} else if chain.Direction == pb.Chain_OUTPUT {
    98  		matchPart = "dst"
    99  		interfaceMatcher = "-o"
   100  	} else {
   101  		return fmt.Errorf("unknown chain direction %d", chain.Direction)
   102  	}
   103  
   104  	if chain.Device == "" {
   105  		chain.Device = defaultDevice
   106  	}
   107  
   108  	protocolAndPort := ""
   109  	if len(chain.Protocol) > 0 {
   110  		protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
   111  
   112  		if len(chain.SourcePorts) > 0 {
   113  			if strings.Contains(chain.SourcePorts, ",") {
   114  				protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
   115  			} else {
   116  				protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
   117  			}
   118  		}
   119  
   120  		if len(chain.DestinationPorts) > 0 {
   121  			if strings.Contains(chain.DestinationPorts, ",") {
   122  				protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
   123  			} else {
   124  				protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
   125  			}
   126  		}
   127  
   128  		if len(chain.TcpFlags) > 0 {
   129  			protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
   130  		}
   131  	}
   132  
   133  	rules := []string{}
   134  
   135  	if len(chain.Ipsets) == 0 {
   136  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s", chain.Name, interfaceMatcher, chain.Device, chain.Target, protocolAndPort)))
   137  	}
   138  
   139  	for _, ipset := range chain.Ipsets {
   140  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
   141  			chain.Name, interfaceMatcher, chain.Device, ipset, matchPart, chain.Target, protocolAndPort)))
   142  	}
   143  	err := iptables.createNewChain(&iptablesChain{
   144  		Name:  chain.Name,
   145  		Rules: rules,
   146  	})
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	if chain.Direction == pb.Chain_INPUT {
   152  		err := iptables.ensureRule(&iptablesChain{
   153  			Name: "CHAOS-INPUT",
   154  		}, "-A CHAOS-INPUT -j "+chain.Name)
   155  		if err != nil {
   156  			return err
   157  		}
   158  	} else if chain.Direction == pb.Chain_OUTPUT {
   159  		iptables.ensureRule(&iptablesChain{
   160  			Name: "CHAOS-OUTPUT",
   161  		}, "-A CHAOS-OUTPUT -j "+chain.Name)
   162  		if err != nil {
   163  			return err
   164  		}
   165  	} else {
   166  		return fmt.Errorf("unknown direction %d", chain.Direction)
   167  	}
   168  	return nil
   169  }
   170  
   171  func (iptables *iptablesClient) initializeEnv() error {
   172  	for _, direction := range []string{"INPUT", "OUTPUT"} {
   173  		chainName := "CHAOS-" + direction
   174  
   175  		err := iptables.createNewChain(&iptablesChain{
   176  			Name:  chainName,
   177  			Rules: []string{},
   178  		})
   179  		if err != nil {
   180  			return err
   181  		}
   182  
   183  		iptables.ensureRule(&iptablesChain{
   184  			Name:  direction,
   185  			Rules: []string{},
   186  		}, "-A "+direction+" -j "+chainName)
   187  	}
   188  
   189  	return nil
   190  }
   191  
   192  // createNewChain will cover existing chain
   193  func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
   194  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
   195  	if iptables.enterNS {
   196  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   197  	}
   198  	cmd := processBuilder.Build(iptables.ctx)
   199  	out, err := cmd.CombinedOutput()
   200  
   201  	if (err == nil && len(out) == 0) ||
   202  		(err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
   203  		// Successfully create a new chain
   204  		return iptables.deleteAndWriteRules(chain)
   205  	}
   206  
   207  	return encodeOutputToError(out, err)
   208  }
   209  
   210  // deleteAndWriteRules will remove all existing function in the chain
   211  // and replace with the new settings
   212  func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
   213  
   214  	// This chain should already exist
   215  	err := iptables.flushIptablesChain(chain)
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	for _, rule := range chain.Rules {
   221  		err := iptables.ensureRule(chain, rule)
   222  		if err != nil {
   223  			return err
   224  		}
   225  	}
   226  
   227  	return nil
   228  }
   229  
   230  func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
   231  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
   232  	if iptables.enterNS {
   233  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   234  	}
   235  	cmd := processBuilder.Build(iptables.ctx)
   236  	out, err := cmd.CombinedOutput()
   237  	if err != nil {
   238  		return encodeOutputToError(out, err)
   239  	}
   240  
   241  	if strings.Contains(string(out), rule) {
   242  		// The required rule already exist in chain
   243  		return nil
   244  	}
   245  
   246  	// TODO: lock on every container but not on chaos-daemon's `/run/xtables.lock`
   247  	processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
   248  	if iptables.enterNS {
   249  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   250  	}
   251  	cmd = processBuilder.Build(iptables.ctx)
   252  	out, err = cmd.CombinedOutput()
   253  	if err != nil {
   254  		return encodeOutputToError(out, err)
   255  	}
   256  
   257  	return nil
   258  }
   259  
   260  func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
   261  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
   262  	if iptables.enterNS {
   263  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   264  	}
   265  	cmd := processBuilder.Build(iptables.ctx)
   266  	out, err := cmd.CombinedOutput()
   267  	if err != nil {
   268  		return encodeOutputToError(out, err)
   269  	}
   270  
   271  	return nil
   272  }
   273