1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package chaosdaemon
17
18 import (
19 "context"
20 "fmt"
21 "strings"
22
23 "github.com/golang/protobuf/ptypes/empty"
24 "github.com/pkg/errors"
25
26 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
27 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
28 "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/util"
29 )
30
// iptablesCmd is the executable invoked for every iptables operation issued
// from this file.
const iptablesCmd = "iptables"

// iptablesChainAlreadyExistErr is the message searched for in the output of
// `iptables -N` to recognize that the chain was already present.
const iptablesChainAlreadyExistErr = "iptables: Chain already exists."
36
37 func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
38 log := s.getLoggerFromContext(ctx)
39 log.Info("Set iptables chains", "request", req)
40
41 pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
42 if err != nil {
43 log.Error(err, "error while getting PID")
44 return nil, err
45 }
46
47 iptables := buildIptablesClient(ctx, req.EnterNS, pid)
48 err = iptables.initializeEnv()
49 if err != nil {
50 log.Error(err, "error while initializing iptables")
51 return nil, err
52 }
53
54 err = iptables.setIptablesChains(req.Chains)
55 if err != nil {
56 log.Error(err, "error while setting iptables chains")
57 return nil, err
58 }
59
60 return &empty.Empty{}, nil
61 }
62
type (
	// iptablesClient runs iptables commands on behalf of one request.
	iptablesClient struct {
		// ctx is attached to every iptables process the client spawns.
		ctx context.Context
		// enterNS selects whether commands run inside the network namespace
		// of pid.
		enterNS bool
		// pid identifies the process whose network namespace is entered when
		// enterNS is true.
		pid uint32
	}

	// iptablesChain is a chain name together with the rules to write into it.
	iptablesChain struct {
		// Name is the iptables chain name.
		Name string
		// Rules holds complete rule specifications, each passed to
		// ensureRule when the chain is (re)written.
		Rules []string
	}
)

// buildIptablesClient returns an iptablesClient carrying the given context,
// namespace switch and PID.
func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
	return iptablesClient{
		ctx:     ctx,
		enterNS: enterNS,
		pid:     pid,
	}
}
81
82 func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
83 for _, chain := range chains {
84 err := iptables.setIptablesChain(chain)
85 if err != nil {
86 return err
87 }
88 }
89
90 return nil
91 }
92
93 func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
94 var matchPart string
95 var interfaceMatcher string
96 if chain.Direction == pb.Chain_INPUT {
97 matchPart = "src,dst"
98 interfaceMatcher = "-i"
99 } else if chain.Direction == pb.Chain_OUTPUT {
100 matchPart = "dst,dst"
101 interfaceMatcher = "-o"
102 } else {
103 return errors.Errorf("unknown chain direction %d", chain.Direction)
104 }
105
106 if chain.Device == "" {
107 chain.Device = defaultDevice
108 }
109
110 protocolAndPort := ""
111 if len(chain.Protocol) > 0 {
112 protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
113
114 if len(chain.SourcePorts) > 0 {
115 if strings.Contains(chain.SourcePorts, ",") {
116 protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
117 } else {
118 protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
119 }
120 }
121
122 if len(chain.DestinationPorts) > 0 {
123 if strings.Contains(chain.DestinationPorts, ",") {
124 protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
125 } else {
126 protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
127 }
128 }
129
130 if len(chain.TcpFlags) > 0 {
131 protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
132 }
133 }
134
135 rules := []string{}
136
137 if len(chain.Ipsets) == 0 {
138 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s", chain.Name, interfaceMatcher, chain.Device, chain.Target, protocolAndPort)))
139 }
140
141 for _, ipset := range chain.Ipsets {
142 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
143 chain.Name, interfaceMatcher, chain.Device, ipset, matchPart, chain.Target, protocolAndPort)))
144 }
145 err := iptables.createNewChain(&iptablesChain{
146 Name: chain.Name,
147 Rules: rules,
148 })
149 if err != nil {
150 return err
151 }
152
153 if chain.Direction == pb.Chain_INPUT {
154 err := iptables.ensureRule(&iptablesChain{
155 Name: "CHAOS-INPUT",
156 }, "-A CHAOS-INPUT -j "+chain.Name)
157 if err != nil {
158 return err
159 }
160 } else if chain.Direction == pb.Chain_OUTPUT {
161 iptables.ensureRule(&iptablesChain{
162 Name: "CHAOS-OUTPUT",
163 }, "-A CHAOS-OUTPUT -j "+chain.Name)
164 if err != nil {
165 return err
166 }
167 } else {
168 return errors.Errorf("unknown direction %d", chain.Direction)
169 }
170 return nil
171 }
172
173 func (iptables *iptablesClient) initializeEnv() error {
174 for _, direction := range []string{"INPUT", "OUTPUT"} {
175 chainName := "CHAOS-" + direction
176
177 err := iptables.createNewChain(&iptablesChain{
178 Name: chainName,
179 Rules: []string{},
180 })
181 if err != nil {
182 return err
183 }
184
185 iptables.ensureRule(&iptablesChain{
186 Name: direction,
187 Rules: []string{},
188 }, "-A "+direction+" -j "+chainName)
189 }
190
191 return nil
192 }
193
194
195 func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
196 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
197 if iptables.enterNS {
198 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
199 }
200 cmd := processBuilder.Build(iptables.ctx)
201 out, err := cmd.CombinedOutput()
202
203 if (err == nil && len(out) == 0) ||
204 (err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
205
206 return iptables.deleteAndWriteRules(chain)
207 }
208
209 return util.EncodeOutputToError(out, err)
210 }
211
212
213
214 func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
215
216
217 err := iptables.flushIptablesChain(chain)
218 if err != nil {
219 return err
220 }
221
222 for _, rule := range chain.Rules {
223 err := iptables.ensureRule(chain, rule)
224 if err != nil {
225 return err
226 }
227 }
228
229 return nil
230 }
231
232 func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
233 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
234 if iptables.enterNS {
235 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
236 }
237 cmd := processBuilder.Build(iptables.ctx)
238 out, err := cmd.CombinedOutput()
239 if err != nil {
240 return util.EncodeOutputToError(out, err)
241 }
242
243 if strings.Contains(string(out), rule) {
244
245 return nil
246 }
247
248
249 processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
250 if iptables.enterNS {
251 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
252 }
253 cmd = processBuilder.Build(iptables.ctx)
254 out, err = cmd.CombinedOutput()
255 if err != nil {
256 return util.EncodeOutputToError(out, err)
257 }
258
259 return nil
260 }
261
262 func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
263 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
264 if iptables.enterNS {
265 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
266 }
267 cmd := processBuilder.Build(iptables.ctx)
268 out, err := cmd.CombinedOutput()
269 if err != nil {
270 return util.EncodeOutputToError(out, err)
271 }
272
273 return nil
274 }
275