1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package chaosdaemon
17
18 import (
19 "context"
20 "fmt"
21 "strings"
22
23 "github.com/golang/protobuf/ptypes/empty"
24
25 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
26 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
27 )
28
// iptablesCmd is the executable name used for every iptables invocation in this file.
const iptablesCmd = "iptables"

// iptablesChainAlreadyExistErr is the substring createNewChain looks for in the
// output of a failed `iptables -N` to recognise that the chain already exists,
// in which case the failure is treated as non-fatal.
const iptablesChainAlreadyExistErr = "iptables: Chain already exists."
34
// SetIptablesChains installs the chains listed in req.Chains. The PID of
// req.ContainerId is resolved first; when req.EnterNS is set, every iptables
// command runs inside that PID's network namespace. The CHAOS-INPUT and
// CHAOS-OUTPUT parent chains are (re)initialized before the requested chains
// are written. Each failure is logged and returned unchanged.
func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
	logger := s.getLoggerFromContext(ctx)
	logger.Info("Set iptables chains", "request", req)

	pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
	if err != nil {
		logger.Error(err, "error while getting PID")
		return nil, err
	}

	client := buildIptablesClient(ctx, req.EnterNS, pid)

	if err = client.initializeEnv(); err != nil {
		logger.Error(err, "error while initializing iptables")
		return nil, err
	}

	if err = client.setIptablesChains(req.Chains); err != nil {
		logger.Error(err, "error while setting iptables chains")
		return nil, err
	}

	return &empty.Empty{}, nil
}
60
// iptablesClient carries what is needed to spawn iptables processes for a
// single SetIptablesChains request.
// NOTE(review): storing a context.Context in a struct is discouraged by Go
// conventions; kept as-is because the positional literal in
// buildIptablesClient depends on this field layout.
type iptablesClient struct {
	// ctx is handed to every spawned iptables process (SetContext / Build).
	ctx context.Context
	// enterNS selects whether commands run inside the network namespace of pid.
	enterNS bool
	// pid is the process whose network namespace is entered when enterNS is true.
	pid uint32
}
66
// iptablesChain names an iptables chain together with the rules that should
// be written into it.
type iptablesChain struct {
	// Name is the chain name passed to `iptables -N`, `-F` and `-S`.
	Name string
	// Rules are full argument strings of the form "-A <chain> ..."; they are
	// appended in order after the chain has been flushed. Callers that only
	// need the chain name leave this empty.
	Rules []string
}
71
// buildIptablesClient returns an iptablesClient whose commands are bound to
// ctx and, when enterNS is true, executed in the network namespace of pid.
func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
	client := iptablesClient{
		ctx:     ctx,
		enterNS: enterNS,
		pid:     pid,
	}
	return client
}
79
80 func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
81 for _, chain := range chains {
82 err := iptables.setIptablesChain(chain)
83 if err != nil {
84 return err
85 }
86 }
87
88 return nil
89 }
90
// setIptablesChain materializes one chaos chain:
//
//  1. It builds the rule list from the chain's direction, device,
//     protocol/port/tcp-flag matchers and ipsets. There is one rule per
//     ipset, or a single rule when no ipset is given.
//  2. It (re)creates chain.Name containing exactly those rules.
//  3. It ensures the matching parent chain (CHAOS-INPUT or CHAOS-OUTPUT)
//     jumps to chain.Name.
//
// An error is returned for an unknown direction or when any iptables
// invocation fails. The caller's chain is not modified.
func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
	var matchPart, interfaceMatcher, parentChain string
	switch chain.Direction {
	case pb.Chain_INPUT:
		// Incoming traffic: match the packet source against the ipset and
		// filter on the input interface.
		matchPart, interfaceMatcher, parentChain = "src", "-i", "CHAOS-INPUT"
	case pb.Chain_OUTPUT:
		// Outgoing traffic: match the destination and the output interface.
		matchPart, interfaceMatcher, parentChain = "dst", "-o", "CHAOS-OUTPUT"
	default:
		return fmt.Errorf("unknown chain direction %d", chain.Direction)
	}

	// Resolve the device locally instead of writing the default back into the
	// caller's request.
	device := chain.Device
	if device == "" {
		device = defaultDevice
	}

	// Port and tcp-flag matchers are only valid together with --protocol, so
	// they are emitted only when a protocol is set.
	protocolAndPort := ""
	if len(chain.Protocol) > 0 {
		protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)

		if len(chain.SourcePorts) > 0 {
			// A comma-separated port list needs the multiport match.
			if strings.Contains(chain.SourcePorts, ",") {
				protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
			} else {
				protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
			}
		}

		if len(chain.DestinationPorts) > 0 {
			if strings.Contains(chain.DestinationPorts, ",") {
				protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
			} else {
				protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
			}
		}

		if len(chain.TcpFlags) > 0 {
			protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
		}
	}

	var rules []string
	if len(chain.Ipsets) == 0 {
		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -j %s -w 5 %s",
			chain.Name, interfaceMatcher, device, chain.Target, protocolAndPort)))
	}
	for _, ipset := range chain.Ipsets {
		rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s %s %s -m set --match-set %s %s -j %s -w 5 %s",
			chain.Name, interfaceMatcher, device, ipset, matchPart, chain.Target, protocolAndPort)))
	}

	if err := iptables.createNewChain(&iptablesChain{
		Name:  chain.Name,
		Rules: rules,
	}); err != nil {
		return err
	}

	// Hook the chain into its parent. The error is propagated for both
	// directions; previously the OUTPUT branch discarded it.
	return iptables.ensureRule(&iptablesChain{
		Name: parentChain,
	}, "-A "+parentChain+" -j "+chain.Name)
}
170
171 func (iptables *iptablesClient) initializeEnv() error {
172 for _, direction := range []string{"INPUT", "OUTPUT"} {
173 chainName := "CHAOS-" + direction
174
175 err := iptables.createNewChain(&iptablesChain{
176 Name: chainName,
177 Rules: []string{},
178 })
179 if err != nil {
180 return err
181 }
182
183 iptables.ensureRule(&iptablesChain{
184 Name: direction,
185 Rules: []string{},
186 }, "-A "+direction+" -j "+chainName)
187 }
188
189 return nil
190 }
191
192
// createNewChain creates chain.Name with `iptables -w -N`, then replaces the
// chain's contents with chain.Rules via deleteAndWriteRules.
//
// An "already exists" failure (detected through iptablesChainAlreadyExistErr
// in the combined output) is tolerated, and the existing chain is reused. Any
// other failure is returned through encodeOutputToError.
func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
	if iptables.enterNS {
		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
	}
	cmd := processBuilder.Build(iptables.ctx)
	out, err := cmd.CombinedOutput()

	// Success is decided by the exit status alone. Any text printed on a
	// successful run (e.g. a warning on stderr, which CombinedOutput captures)
	// must not cause the rules to be skipped.
	if err == nil || strings.Contains(string(out), iptablesChainAlreadyExistErr) {
		return iptables.deleteAndWriteRules(chain)
	}

	return encodeOutputToError(out, err)
}
209
210
211
212 func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
213
214
215 err := iptables.flushIptablesChain(chain)
216 if err != nil {
217 return err
218 }
219
220 for _, rule := range chain.Rules {
221 err := iptables.ensureRule(chain, rule)
222 if err != nil {
223 return err
224 }
225 }
226
227 return nil
228 }
229
// ensureRule appends rule, a full "-A <chain> ..." argument string, unless
// `iptables -w -S <chain.Name>` already lists an identical rule line.
//
// Existing rules are matched line by line after whitespace normalization,
// not by substring. With a substring search, "-A CHAOS-INPUT -j foo" would be
// taken as present when only "-A CHAOS-INPUT -j foobar" exists, and the jump
// would never be installed.
func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
	processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
	if iptables.enterNS {
		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
	}
	cmd := processBuilder.Build(iptables.ctx)
	out, err := cmd.CombinedOutput()
	if err != nil {
		return encodeOutputToError(out, err)
	}

	if iptablesOutputHasRule(string(out), rule) {
		return nil
	}

	// strings.Fields (rather than splitting on a single space) keeps repeated
	// whitespace in rule from turning into empty iptables arguments.
	args := append([]string{"-w"}, strings.Fields(rule)...)
	processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, args...).SetContext(iptables.ctx)
	if iptables.enterNS {
		processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
	}
	cmd = processBuilder.Build(iptables.ctx)
	out, err = cmd.CombinedOutput()
	if err != nil {
		return encodeOutputToError(out, err)
	}

	return nil
}

// iptablesOutputHasRule reports whether rule appears as a complete line of
// listing (the output of `iptables -S`). Both sides have their runs of
// whitespace collapsed before comparison.
func iptablesOutputHasRule(listing, rule string) bool {
	want := strings.Join(strings.Fields(rule), " ")
	for _, line := range strings.Split(listing, "\n") {
		if strings.Join(strings.Fields(line), " ") == want {
			return true
		}
	}
	return false
}
259
260 func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
261 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
262 if iptables.enterNS {
263 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
264 }
265 cmd := processBuilder.Build(iptables.ctx)
266 out, err := cmd.CombinedOutput()
267 if err != nil {
268 return encodeOutputToError(out, err)
269 }
270
271 return nil
272 }
273