1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package chaosdaemon
17
18 import (
19 "context"
20 "fmt"
21 "strings"
22
23 "github.com/golang/protobuf/ptypes/empty"
24
25 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
26 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
27 )
28
const (
	// iptablesCmd is the executable spawned for every iptables operation in
	// this file.
	iptablesCmd = "iptables"

	// iptablesChainAlreadyExistErr is the substring createNewChain looks for in
	// the combined output of a failed `iptables -N`; when present, the existing
	// chain is reused (flushed and rewritten) instead of reporting an error.
	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
)
34
35 func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
36 log.Info("Set iptables chains", "request", req)
37
38 pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
39 if err != nil {
40 log.Error(err, "error while getting PID")
41 return nil, err
42 }
43
44 iptables := buildIptablesClient(ctx, req.EnterNS, pid)
45 err = iptables.initializeEnv()
46 if err != nil {
47 log.Error(err, "error while initializing iptables")
48 return nil, err
49 }
50
51 err = iptables.setIptablesChains(req.Chains)
52 if err != nil {
53 log.Error(err, "error while setting iptables chains")
54 return nil, err
55 }
56
57 return &empty.Empty{}, nil
58 }
59
// iptablesClient carries the settings shared by every iptables command issued
// while serving a single request.
//
// Field order matters: buildIptablesClient fills this struct positionally.
type iptablesClient struct {
	// ctx is attached to each spawned iptables process via SetContext.
	ctx context.Context
	// enterNS selects whether commands run inside the network namespace of pid.
	enterNS bool
	// pid identifies the process whose network namespace is entered when
	// enterNS is true.
	pid uint32
}
65
// iptablesChain names a chain and lists the rules to write into it.
type iptablesChain struct {
	// Name is the chain passed to `iptables -N`, `-F` and `-S`.
	Name string
	// Rules holds complete iptables argument strings, one rule per entry
	// (for example "-A NAME -j DROP").
	Rules []string
}
70
71 func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
72 return iptablesClient{
73 ctx,
74 enterNS,
75 pid,
76 }
77 }
78
79 func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
80 for _, chain := range chains {
81 err := iptables.setIptablesChain(chain)
82 if err != nil {
83 return err
84 }
85 }
86
87 return nil
88 }
89
90 func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
91 var matchPart string
92 var interfaceMatcher string
93 if chain.Direction == pb.Chain_INPUT {
94 matchPart = "src"
95 interfaceMatcher = "-i"
96 } else if chain.Direction == pb.Chain_OUTPUT {
97 matchPart = "dst"
98 interfaceMatcher = "-o"
99 } else {
100 return fmt.Errorf("unknown chain direction %d", chain.Direction)
101 }
102
103 if chain.Device == "" {
104 chain.Device = defaultDevice
105 }
106
107 protocolAndPort := ""
108 if len(chain.Protocol) > 0 {
109 protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
110
111 if len(chain.SourcePorts) > 0 {
112 if strings.Contains(chain.SourcePorts, ",") {
113 protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
114 } else {
115 protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
116 }
117 }
118
119 if len(chain.DestinationPorts) > 0 {
120 if strings.Contains(chain.DestinationPorts, ",") {
121 protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
122 } else {
123 protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
124 }
125 }
126
127 if len(chain.TcpFlags) > 0 {
128 protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
129 }
130 }
131
132 rules := []string{}
133
134 if len(chain.Ipsets) == 0 {
135 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s", chain.Name, interfaceMatcher, chain.Device, chain.Target, protocolAndPort)))
136 }
137
138 for _, ipset := range chain.Ipsets {
139 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
140 chain.Name, interfaceMatcher, chain.Device, ipset, matchPart, chain.Target, protocolAndPort)))
141 }
142 err := iptables.createNewChain(&iptablesChain{
143 Name: chain.Name,
144 Rules: rules,
145 })
146 if err != nil {
147 return err
148 }
149
150 if chain.Direction == pb.Chain_INPUT {
151 err := iptables.ensureRule(&iptablesChain{
152 Name: "CHAOS-INPUT",
153 }, "-A CHAOS-INPUT -j "+chain.Name)
154 if err != nil {
155 return err
156 }
157 } else if chain.Direction == pb.Chain_OUTPUT {
158 iptables.ensureRule(&iptablesChain{
159 Name: "CHAOS-OUTPUT",
160 }, "-A CHAOS-OUTPUT -j "+chain.Name)
161 if err != nil {
162 return err
163 }
164 } else {
165 return fmt.Errorf("unknown direction %d", chain.Direction)
166 }
167 return nil
168 }
169
170 func (iptables *iptablesClient) initializeEnv() error {
171 for _, direction := range []string{"INPUT", "OUTPUT"} {
172 chainName := "CHAOS-" + direction
173
174 err := iptables.createNewChain(&iptablesChain{
175 Name: chainName,
176 Rules: []string{},
177 })
178 if err != nil {
179 return err
180 }
181
182 iptables.ensureRule(&iptablesChain{
183 Name: direction,
184 Rules: []string{},
185 }, "-A "+direction+" -j "+chainName)
186 }
187
188 return nil
189 }
190
191
192 func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
193 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
194 if iptables.enterNS {
195 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
196 }
197 cmd := processBuilder.Build()
198 out, err := cmd.CombinedOutput()
199
200 if (err == nil && len(out) == 0) ||
201 (err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
202
203 return iptables.deleteAndWriteRules(chain)
204 }
205
206 return encodeOutputToError(out, err)
207 }
208
209
210
211 func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
212
213
214 err := iptables.flushIptablesChain(chain)
215 if err != nil {
216 return err
217 }
218
219 for _, rule := range chain.Rules {
220 err := iptables.ensureRule(chain, rule)
221 if err != nil {
222 return err
223 }
224 }
225
226 return nil
227 }
228
229 func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
230 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
231 if iptables.enterNS {
232 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
233 }
234 cmd := processBuilder.Build()
235 out, err := cmd.CombinedOutput()
236 if err != nil {
237 return encodeOutputToError(out, err)
238 }
239
240 if strings.Contains(string(out), rule) {
241
242 return nil
243 }
244
245
246 processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
247 if iptables.enterNS {
248 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
249 }
250 cmd = processBuilder.Build()
251 out, err = cmd.CombinedOutput()
252 if err != nil {
253 return encodeOutputToError(out, err)
254 }
255
256 return nil
257 }
258
259 func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
260 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
261 if iptables.enterNS {
262 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
263 }
264 cmd := processBuilder.Build()
265 out, err := cmd.CombinedOutput()
266 if err != nil {
267 return encodeOutputToError(out, err)
268 }
269
270 return nil
271 }
272