...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/iptables_server.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon

     1  // Copyright 2021 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package chaosdaemon
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"github.com/golang/protobuf/ptypes/empty"
    24  
    25  	"github.com/chaos-mesh/chaos-mesh/pkg/bpm"
    26  	pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
    27  )
    28  
    29  const (
    30  	iptablesCmd = "iptables"
    31  
    32  	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
    33  )
    34  
    35  func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
    36  	log.Info("Set iptables chains", "request", req)
    37  
    38  	pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
    39  	if err != nil {
    40  		log.Error(err, "error while getting PID")
    41  		return nil, err
    42  	}
    43  
    44  	iptables := buildIptablesClient(ctx, req.EnterNS, pid)
    45  	err = iptables.initializeEnv()
    46  	if err != nil {
    47  		log.Error(err, "error while initializing iptables")
    48  		return nil, err
    49  	}
    50  
    51  	err = iptables.setIptablesChains(req.Chains)
    52  	if err != nil {
    53  		log.Error(err, "error while setting iptables chains")
    54  		return nil, err
    55  	}
    56  
    57  	return &empty.Empty{}, nil
    58  }
    59  
    60  type iptablesClient struct {
    61  	ctx     context.Context
    62  	enterNS bool
    63  	pid     uint32
    64  }
    65  
    66  type iptablesChain struct {
    67  	Name  string
    68  	Rules []string
    69  }
    70  
    71  func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
    72  	return iptablesClient{
    73  		ctx,
    74  		enterNS,
    75  		pid,
    76  	}
    77  }
    78  
    79  func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
    80  	for _, chain := range chains {
    81  		err := iptables.setIptablesChain(chain)
    82  		if err != nil {
    83  			return err
    84  		}
    85  	}
    86  
    87  	return nil
    88  }
    89  
    90  func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
    91  	var matchPart string
    92  	var interfaceMatcher string
    93  	if chain.Direction == pb.Chain_INPUT {
    94  		matchPart = "src"
    95  		interfaceMatcher = "-i"
    96  	} else if chain.Direction == pb.Chain_OUTPUT {
    97  		matchPart = "dst"
    98  		interfaceMatcher = "-o"
    99  	} else {
   100  		return fmt.Errorf("unknown chain direction %d", chain.Direction)
   101  	}
   102  
   103  	if chain.Device == "" {
   104  		chain.Device = defaultDevice
   105  	}
   106  
   107  	protocolAndPort := ""
   108  	if len(chain.Protocol) > 0 {
   109  		protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
   110  
   111  		if len(chain.SourcePorts) > 0 {
   112  			if strings.Contains(chain.SourcePorts, ",") {
   113  				protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
   114  			} else {
   115  				protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
   116  			}
   117  		}
   118  
   119  		if len(chain.DestinationPorts) > 0 {
   120  			if strings.Contains(chain.DestinationPorts, ",") {
   121  				protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
   122  			} else {
   123  				protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
   124  			}
   125  		}
   126  
   127  		if len(chain.TcpFlags) > 0 {
   128  			protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
   129  		}
   130  	}
   131  
   132  	rules := []string{}
   133  
   134  	if len(chain.Ipsets) == 0 {
   135  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s", chain.Name, interfaceMatcher, chain.Device, chain.Target, protocolAndPort)))
   136  	}
   137  
   138  	for _, ipset := range chain.Ipsets {
   139  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
   140  			chain.Name, interfaceMatcher, chain.Device, ipset, matchPart, chain.Target, protocolAndPort)))
   141  	}
   142  	err := iptables.createNewChain(&iptablesChain{
   143  		Name:  chain.Name,
   144  		Rules: rules,
   145  	})
   146  	if err != nil {
   147  		return err
   148  	}
   149  
   150  	if chain.Direction == pb.Chain_INPUT {
   151  		err := iptables.ensureRule(&iptablesChain{
   152  			Name: "CHAOS-INPUT",
   153  		}, "-A CHAOS-INPUT -j "+chain.Name)
   154  		if err != nil {
   155  			return err
   156  		}
   157  	} else if chain.Direction == pb.Chain_OUTPUT {
   158  		iptables.ensureRule(&iptablesChain{
   159  			Name: "CHAOS-OUTPUT",
   160  		}, "-A CHAOS-OUTPUT -j "+chain.Name)
   161  		if err != nil {
   162  			return err
   163  		}
   164  	} else {
   165  		return fmt.Errorf("unknown direction %d", chain.Direction)
   166  	}
   167  	return nil
   168  }
   169  
   170  func (iptables *iptablesClient) initializeEnv() error {
   171  	for _, direction := range []string{"INPUT", "OUTPUT"} {
   172  		chainName := "CHAOS-" + direction
   173  
   174  		err := iptables.createNewChain(&iptablesChain{
   175  			Name:  chainName,
   176  			Rules: []string{},
   177  		})
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		iptables.ensureRule(&iptablesChain{
   183  			Name:  direction,
   184  			Rules: []string{},
   185  		}, "-A "+direction+" -j "+chainName)
   186  	}
   187  
   188  	return nil
   189  }
   190  
   191  // createNewChain will cover existing chain
   192  func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
   193  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
   194  	if iptables.enterNS {
   195  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   196  	}
   197  	cmd := processBuilder.Build()
   198  	out, err := cmd.CombinedOutput()
   199  
   200  	if (err == nil && len(out) == 0) ||
   201  		(err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
   202  		// Successfully create a new chain
   203  		return iptables.deleteAndWriteRules(chain)
   204  	}
   205  
   206  	return encodeOutputToError(out, err)
   207  }
   208  
   209  // deleteAndWriteRules will remove all existing function in the chain
   210  // and replace with the new settings
   211  func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
   212  
   213  	// This chain should already exist
   214  	err := iptables.flushIptablesChain(chain)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	for _, rule := range chain.Rules {
   220  		err := iptables.ensureRule(chain, rule)
   221  		if err != nil {
   222  			return err
   223  		}
   224  	}
   225  
   226  	return nil
   227  }
   228  
   229  func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
   230  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
   231  	if iptables.enterNS {
   232  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   233  	}
   234  	cmd := processBuilder.Build()
   235  	out, err := cmd.CombinedOutput()
   236  	if err != nil {
   237  		return encodeOutputToError(out, err)
   238  	}
   239  
   240  	if strings.Contains(string(out), rule) {
   241  		// The required rule already exist in chain
   242  		return nil
   243  	}
   244  
   245  	// TODO: lock on every container but not on chaos-daemon's `/run/xtables.lock`
   246  	processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
   247  	if iptables.enterNS {
   248  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   249  	}
   250  	cmd = processBuilder.Build()
   251  	out, err = cmd.CombinedOutput()
   252  	if err != nil {
   253  		return encodeOutputToError(out, err)
   254  	}
   255  
   256  	return nil
   257  }
   258  
   259  func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
   260  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
   261  	if iptables.enterNS {
   262  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   263  	}
   264  	cmd := processBuilder.Build()
   265  	out, err := cmd.CombinedOutput()
   266  	if err != nil {
   267  		return encodeOutputToError(out, err)
   268  	}
   269  
   270  	return nil
   271  }
   272