...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/iptables_server.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon

     1  // Copyright 2020 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package chaosdaemon
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/golang/protobuf/ptypes/empty"
    22  
    23  	"github.com/chaos-mesh/chaos-mesh/pkg/bpm"
    24  	pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
    25  )
    26  
const (
	// iptablesCmd is the executable name handed to the process builder for
	// every iptables invocation in this file.
	iptablesCmd = "iptables"

	// iptablesChainAlreadyExistErr is the text createNewChain searches for in
	// the output of a failed `iptables -N`; when it is present the failure is
	// treated as "chain already exists" rather than as an error.
	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
)
    32  
    33  func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
    34  	log.Info("Set iptables chains", "request", req)
    35  
    36  	pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
    37  	if err != nil {
    38  		log.Error(err, "error while getting PID")
    39  		return nil, err
    40  	}
    41  
    42  	iptables := buildIptablesClient(ctx, req.EnterNS, pid)
    43  	err = iptables.initializeEnv()
    44  	if err != nil {
    45  		log.Error(err, "error while initializing iptables")
    46  		return nil, err
    47  	}
    48  
    49  	err = iptables.setIptablesChains(req.Chains)
    50  	if err != nil {
    51  		log.Error(err, "error while setting iptables chains")
    52  		return nil, err
    53  	}
    54  
    55  	return &empty.Empty{}, nil
    56  }
    57  
    58  type iptablesClient struct {
    59  	ctx     context.Context
    60  	enterNS bool
    61  	pid     uint32
    62  }
    63  
    64  type iptablesChain struct {
    65  	Name  string
    66  	Rules []string
    67  }
    68  
    69  func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
    70  	return iptablesClient{
    71  		ctx,
    72  		enterNS,
    73  		pid,
    74  	}
    75  }
    76  
    77  func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
    78  	for _, chain := range chains {
    79  		err := iptables.setIptablesChain(chain)
    80  		if err != nil {
    81  			return err
    82  		}
    83  	}
    84  
    85  	return nil
    86  }
    87  
    88  func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
    89  	var matchPart string
    90  	if chain.Direction == pb.Chain_INPUT {
    91  		matchPart = "src"
    92  	} else if chain.Direction == pb.Chain_OUTPUT {
    93  		matchPart = "dst"
    94  	} else {
    95  		return fmt.Errorf("unknown chain direction %d", chain.Direction)
    96  	}
    97  
    98  	protocolAndPort := ""
    99  	if len(chain.Protocol) > 0 {
   100  		protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
   101  
   102  		if len(chain.SourcePorts) > 0 {
   103  			if strings.Contains(chain.SourcePorts, ",") {
   104  				protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
   105  			} else {
   106  				protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
   107  			}
   108  		}
   109  
   110  		if len(chain.DestinationPorts) > 0 {
   111  			if strings.Contains(chain.DestinationPorts, ",") {
   112  				protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
   113  			} else {
   114  				protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
   115  			}
   116  		}
   117  
   118  		if len(chain.TcpFlags) > 0 {
   119  			protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
   120  		}
   121  	}
   122  
   123  	rules := []string{}
   124  
   125  	if len(chain.Ipsets) == 0 {
   126  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s -j %s -w 5 %s", chain.Name, chain.Target, protocolAndPort)))
   127  	}
   128  
   129  	for _, ipset := range chain.Ipsets {
   130  		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s -m set --match-set %s %s -j %s -w 5 %s",
   131  			chain.Name, ipset, matchPart, chain.Target, protocolAndPort)))
   132  	}
   133  	err := iptables.createNewChain(&iptablesChain{
   134  		Name:  chain.Name,
   135  		Rules: rules,
   136  	})
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	if chain.Direction == pb.Chain_INPUT {
   142  		err := iptables.ensureRule(&iptablesChain{
   143  			Name: "CHAOS-INPUT",
   144  		}, "-A CHAOS-INPUT -j "+chain.Name)
   145  		if err != nil {
   146  			return err
   147  		}
   148  	} else if chain.Direction == pb.Chain_OUTPUT {
   149  		iptables.ensureRule(&iptablesChain{
   150  			Name: "CHAOS-OUTPUT",
   151  		}, "-A CHAOS-OUTPUT -j "+chain.Name)
   152  		if err != nil {
   153  			return err
   154  		}
   155  	} else {
   156  		return fmt.Errorf("unknown direction %d", chain.Direction)
   157  	}
   158  	return nil
   159  }
   160  
   161  func (iptables *iptablesClient) initializeEnv() error {
   162  	for _, direction := range []string{"INPUT", "OUTPUT"} {
   163  		chainName := "CHAOS-" + direction
   164  
   165  		err := iptables.createNewChain(&iptablesChain{
   166  			Name:  chainName,
   167  			Rules: []string{},
   168  		})
   169  		if err != nil {
   170  			return err
   171  		}
   172  
   173  		iptables.ensureRule(&iptablesChain{
   174  			Name:  direction,
   175  			Rules: []string{},
   176  		}, "-A "+direction+" -j "+chainName)
   177  	}
   178  
   179  	return nil
   180  }
   181  
   182  // createNewChain will cover existing chain
   183  func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
   184  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
   185  	if iptables.enterNS {
   186  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   187  	}
   188  	cmd := processBuilder.Build()
   189  	out, err := cmd.CombinedOutput()
   190  
   191  	if (err == nil && len(out) == 0) ||
   192  		(err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
   193  		// Successfully create a new chain
   194  		return iptables.deleteAndWriteRules(chain)
   195  	}
   196  
   197  	return encodeOutputToError(out, err)
   198  }
   199  
   200  // deleteAndWriteRules will remove all existing function in the chain
   201  // and replace with the new settings
   202  func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
   203  
   204  	// This chain should already exist
   205  	err := iptables.flushIptablesChain(chain)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	for _, rule := range chain.Rules {
   211  		err := iptables.ensureRule(chain, rule)
   212  		if err != nil {
   213  			return err
   214  		}
   215  	}
   216  
   217  	return nil
   218  }
   219  
   220  func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
   221  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
   222  	if iptables.enterNS {
   223  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   224  	}
   225  	cmd := processBuilder.Build()
   226  	out, err := cmd.CombinedOutput()
   227  	if err != nil {
   228  		return encodeOutputToError(out, err)
   229  	}
   230  
   231  	if strings.Contains(string(out), rule) {
   232  		// The required rule already exist in chain
   233  		return nil
   234  	}
   235  
   236  	// TODO: lock on every container but not on chaos-daemon's `/run/xtables.lock`
   237  	processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
   238  	if iptables.enterNS {
   239  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   240  	}
   241  	cmd = processBuilder.Build()
   242  	out, err = cmd.CombinedOutput()
   243  	if err != nil {
   244  		return encodeOutputToError(out, err)
   245  	}
   246  
   247  	return nil
   248  }
   249  
   250  func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
   251  	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
   252  	if iptables.enterNS {
   253  		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
   254  	}
   255  	cmd := processBuilder.Build()
   256  	out, err := cmd.CombinedOutput()
   257  	if err != nil {
   258  		return encodeOutputToError(out, err)
   259  	}
   260  
   261  	return nil
   262  }
   263