1
2
3
4
5
6
7
8
9
10
11
12
13
14 package chaosdaemon
15
16 import (
17 "context"
18 "fmt"
19 "strings"
20
21 "github.com/golang/protobuf/ptypes/empty"
22
23 "github.com/chaos-mesh/chaos-mesh/pkg/bpm"
24 pb "github.com/chaos-mesh/chaos-mesh/pkg/chaosdaemon/pb"
25 )
26
const (
	// iptablesCmd is the binary invoked for every iptables operation in this file.
	iptablesCmd = "iptables"

	// iptablesChainAlreadyExistErr is searched for in the combined output of
	// `iptables -N`; createNewChain treats its presence as "chain exists"
	// rather than as a failure.
	iptablesChainAlreadyExistErr = "iptables: Chain already exists."
)
32
33 func (s *DaemonServer) SetIptablesChains(ctx context.Context, req *pb.IptablesChainsRequest) (*empty.Empty, error) {
34 log.Info("Set iptables chains", "request", req)
35
36 pid, err := s.crClient.GetPidFromContainerID(ctx, req.ContainerId)
37 if err != nil {
38 log.Error(err, "error while getting PID")
39 return nil, err
40 }
41
42 iptables := buildIptablesClient(ctx, req.EnterNS, pid)
43 err = iptables.initializeEnv()
44 if err != nil {
45 log.Error(err, "error while initializing iptables")
46 return nil, err
47 }
48
49 err = iptables.setIptablesChains(req.Chains)
50 if err != nil {
51 log.Error(err, "error while setting iptables chains")
52 return nil, err
53 }
54
55 return &empty.Empty{}, nil
56 }
57
// iptablesClient runs iptables commands on behalf of a single request.
//
// Field order is significant: instances are built with positional literals.
type iptablesClient struct {
	// ctx is attached to every iptables process the client starts.
	ctx context.Context
	// enterNS selects whether commands run inside pid's network namespace.
	enterNS bool
	// pid identifies the process whose network namespace is entered.
	pid uint32
}
63
// iptablesChain names an iptables chain together with the rules it should hold.
type iptablesChain struct {
	// Name is the chain name passed to iptables (-N, -F, -S).
	Name string
	// Rules are full iptables argument strings such as "-A NAME -j TARGET".
	Rules []string
}
68
69 func buildIptablesClient(ctx context.Context, enterNS bool, pid uint32) iptablesClient {
70 return iptablesClient{
71 ctx,
72 enterNS,
73 pid,
74 }
75 }
76
77 func (iptables *iptablesClient) setIptablesChains(chains []*pb.Chain) error {
78 for _, chain := range chains {
79 err := iptables.setIptablesChain(chain)
80 if err != nil {
81 return err
82 }
83 }
84
85 return nil
86 }
87
88 func (iptables *iptablesClient) setIptablesChain(chain *pb.Chain) error {
89 var matchPart string
90 if chain.Direction == pb.Chain_INPUT {
91 matchPart = "src"
92 } else if chain.Direction == pb.Chain_OUTPUT {
93 matchPart = "dst"
94 } else {
95 return fmt.Errorf("unknown chain direction %d", chain.Direction)
96 }
97
98 protocolAndPort := ""
99 if len(chain.Protocol) > 0 {
100 protocolAndPort += fmt.Sprintf("--protocol %s", chain.Protocol)
101
102 if len(chain.SourcePorts) > 0 {
103 if strings.Contains(chain.SourcePorts, ",") {
104 protocolAndPort += fmt.Sprintf(" -m multiport --source-ports %s", chain.SourcePorts)
105 } else {
106 protocolAndPort += fmt.Sprintf(" --source-port %s", chain.SourcePorts)
107 }
108 }
109
110 if len(chain.DestinationPorts) > 0 {
111 if strings.Contains(chain.DestinationPorts, ",") {
112 protocolAndPort += fmt.Sprintf(" -m multiport --destination-ports %s", chain.DestinationPorts)
113 } else {
114 protocolAndPort += fmt.Sprintf(" --destination-port %s", chain.DestinationPorts)
115 }
116 }
117
118 if len(chain.TcpFlags) > 0 {
119 protocolAndPort += fmt.Sprintf(" --tcp-flags %s", chain.TcpFlags)
120 }
121 }
122
123 rules := []string{}
124
125 if len(chain.Ipsets) == 0 {
126 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s -j %s -w 5 %s", chain.Name, chain.Target, protocolAndPort)))
127 }
128
129 for _, ipset := range chain.Ipsets {
130 rules = append(rules, strings.TrimSpace(fmt.Sprintf("-A %s -m set --match-set %s %s -j %s -w 5 %s",
131 chain.Name, ipset, matchPart, chain.Target, protocolAndPort)))
132 }
133 err := iptables.createNewChain(&iptablesChain{
134 Name: chain.Name,
135 Rules: rules,
136 })
137 if err != nil {
138 return err
139 }
140
141 if chain.Direction == pb.Chain_INPUT {
142 err := iptables.ensureRule(&iptablesChain{
143 Name: "CHAOS-INPUT",
144 }, "-A CHAOS-INPUT -j "+chain.Name)
145 if err != nil {
146 return err
147 }
148 } else if chain.Direction == pb.Chain_OUTPUT {
149 iptables.ensureRule(&iptablesChain{
150 Name: "CHAOS-OUTPUT",
151 }, "-A CHAOS-OUTPUT -j "+chain.Name)
152 if err != nil {
153 return err
154 }
155 } else {
156 return fmt.Errorf("unknown direction %d", chain.Direction)
157 }
158 return nil
159 }
160
161 func (iptables *iptablesClient) initializeEnv() error {
162 for _, direction := range []string{"INPUT", "OUTPUT"} {
163 chainName := "CHAOS-" + direction
164
165 err := iptables.createNewChain(&iptablesChain{
166 Name: chainName,
167 Rules: []string{},
168 })
169 if err != nil {
170 return err
171 }
172
173 iptables.ensureRule(&iptablesChain{
174 Name: direction,
175 Rules: []string{},
176 }, "-A "+direction+" -j "+chainName)
177 }
178
179 return nil
180 }
181
182
183 func (iptables *iptablesClient) createNewChain(chain *iptablesChain) error {
184 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-N", chain.Name).SetContext(iptables.ctx)
185 if iptables.enterNS {
186 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
187 }
188 cmd := processBuilder.Build()
189 out, err := cmd.CombinedOutput()
190
191 if (err == nil && len(out) == 0) ||
192 (err != nil && strings.Contains(string(out), iptablesChainAlreadyExistErr)) {
193
194 return iptables.deleteAndWriteRules(chain)
195 }
196
197 return encodeOutputToError(out, err)
198 }
199
200
201
202 func (iptables *iptablesClient) deleteAndWriteRules(chain *iptablesChain) error {
203
204
205 err := iptables.flushIptablesChain(chain)
206 if err != nil {
207 return err
208 }
209
210 for _, rule := range chain.Rules {
211 err := iptables.ensureRule(chain, rule)
212 if err != nil {
213 return err
214 }
215 }
216
217 return nil
218 }
219
220 func (iptables *iptablesClient) ensureRule(chain *iptablesChain, rule string) error {
221 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-S", chain.Name).SetContext(iptables.ctx)
222 if iptables.enterNS {
223 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
224 }
225 cmd := processBuilder.Build()
226 out, err := cmd.CombinedOutput()
227 if err != nil {
228 return encodeOutputToError(out, err)
229 }
230
231 if strings.Contains(string(out), rule) {
232
233 return nil
234 }
235
236
237 processBuilder = bpm.DefaultProcessBuilder(iptablesCmd, strings.Split("-w "+rule, " ")...).SetContext(iptables.ctx)
238 if iptables.enterNS {
239 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
240 }
241 cmd = processBuilder.Build()
242 out, err = cmd.CombinedOutput()
243 if err != nil {
244 return encodeOutputToError(out, err)
245 }
246
247 return nil
248 }
249
250 func (iptables *iptablesClient) flushIptablesChain(chain *iptablesChain) error {
251 processBuilder := bpm.DefaultProcessBuilder(iptablesCmd, "-w", "-F", chain.Name).SetContext(iptables.ctx)
252 if iptables.enterNS {
253 processBuilder = processBuilder.SetNS(iptables.pid, bpm.NetNS)
254 }
255 cmd := processBuilder.Build()
256 out, err := cmd.CombinedOutput()
257 if err != nil {
258 return encodeOutputToError(out, err)
259 }
260
261 return nil
262 }
263