1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package bpm
17
18 import (
19 "context"
20 "fmt"
21 "io"
22 "os/exec"
23 "sync"
24 "syscall"
25
26 "github.com/go-logr/logr"
27 "github.com/google/uuid"
28 "github.com/pkg/errors"
29 "github.com/prometheus/client_golang/prometheus"
30 "github.com/shirou/gopsutil/process"
31
32 "github.com/chaos-mesh/chaos-mesh/pkg/log"
33 )
34
// NsType names a kind of Linux namespace. Its string value is used verbatim
// as the last path element of a namespace file (see GetNsPath).
type NsType string
36
// Namespace kinds known to this package.
const (
	// MountNS is the mount namespace ("mnt").
	MountNS NsType = "mnt"

	// IpcNS is the IPC namespace ("ipc").
	IpcNS NsType = "ipc"
	// NetNS is the network namespace ("net").
	NetNS NsType = "net"
	// PidNS is the PID namespace ("pid").
	PidNS NsType = "pid"
)
47
// nsArgMap maps each namespace kind to a one-letter abbreviation.
// NOTE(review): not referenced in this part of the file; presumably these
// letters are the namespace flags passed to the nsexec helper — confirm.
var nsArgMap = map[NsType]string{
	MountNS: "m",

	IpcNS: "i",
	NetNS: "n",
	PidNS: "p",
}
58
const (
	// pausePath and nsexecPath are absolute paths of helper binaries.
	// NOTE(review): neither is referenced in this part of the file;
	// presumably they are used when a command is built with pause or
	// namespace options — confirm.
	pausePath  = "/usr/local/bin/pause"
	nsexecPath = "/usr/local/bin/nsexec"

	// DefaultProcPrefix is the procfs root that GetNsPath prepends to
	// namespace paths.
	DefaultProcPrefix = "/proc"
)
65
66
67
68
69
70
71
72
73
// ProcessPair identifies one process by its pid together with its creation
// time. The manager uses it as the key of its pid-pair-to-uid index.
type ProcessPair struct {
	// Pid is the operating-system process id.
	Pid int
	// CreateTime is the creation time as reported by gopsutil's
	// Process.CreateTime for that pid.
	CreateTime int64
}
78
// Process describes a command that has been started by startProcess,
// together with the handles needed to talk to it and to observe its
// termination.
type Process struct {
	// Uid is a random UUID assigned when the process is started.
	Uid string

	// Pair is the pid and creation time of the started process.
	Pair ProcessPair

	// Cmd is the command this process was started from.
	Cmd *ManagedCommand
	// Pipes holds the stdin/stdout pipes created before the command started.
	Pipes Pipes

	// ctx is cancelled by calling stopped; Stopped exposes its Done channel.
	ctx     context.Context
	stopped context.CancelFunc
}
92
93
// Pipes bundles the parent-side ends of a started command's standard streams.
type Pipes struct {
	// Stdin is the pipe returned by the command's StdinPipe.
	Stdin io.WriteCloser
	// Stdout is the pipe returned by the command's StdoutPipe.
	Stdout io.ReadCloser
}
98
99
// BackgroundProcessManager starts commands in the background, keeps track of
// them while they run and cleans up its bookkeeping once they exit.
type BackgroundProcessManager struct {
	// deathChannel carries the uid of every process whose Wait has returned;
	// a goroutine started by StartBackgroundProcessManager drains it.
	deathChannel chan string

	// wg is incremented once per started process and decremented once the
	// process's uid has been handled from deathChannel; Shutdown waits on it.
	wg *sync.WaitGroup

	// identifiers holds the identifiers of commands currently considered
	// running (identifier string -> true).
	identifiers *sync.Map

	// processes maps a process uid (string) to its *Process.
	processes *sync.Map

	// pidPairToUid maps a ProcessPair to the uid of the corresponding process.
	pidPairToUid *sync.Map

	// rootLogger is the base logger; per-call loggers are derived from it.
	rootLogger logr.Logger

	// metricsCollector is nil when no prometheus registry was supplied.
	metricsCollector *metricsCollector
}
121
122 func startProcess(cmd *ManagedCommand) (*Process, error) {
123 stdin, err := cmd.StdinPipe()
124 if err != nil {
125 return nil, errors.Wrap(err, "create stdin pipe")
126 }
127
128 stdout, err := cmd.StdoutPipe()
129 if err != nil {
130 return nil, errors.Wrap(err, "create stdout pipe")
131 }
132
133 err = cmd.Start()
134 if err != nil {
135 return nil, errors.Wrapf(err, "start command `%s`", cmd.String())
136 }
137
138 newProcess := &Process{
139 Uid: uuid.NewString(),
140 Cmd: cmd,
141 Pipes: Pipes{Stdin: stdin, Stdout: stdout},
142 }
143
144 newProcess.ctx, newProcess.stopped = context.WithCancel(context.Background())
145
146
147
148 pid := cmd.Process.Pid
149 proc, err := process.NewProcess(int32(cmd.Process.Pid))
150 if err != nil {
151 return nil, errors.Wrapf(err, "get process state for pid %d", pid)
152 }
153
154 ct, err := proc.CreateTime()
155 if err != nil {
156 return nil, errors.Wrapf(err, "get process create time for pid %d", pid)
157 }
158
159 newProcess.Pair = ProcessPair{
160 Pid: int(proc.Pid),
161 CreateTime: ct,
162 }
163 return newProcess, nil
164 }
165
166 func (p *Process) Stopped() <-chan struct{} {
167 return p.ctx.Done()
168 }
169
170
171 func StartBackgroundProcessManager(registry prometheus.Registerer, rootLogger logr.Logger) *BackgroundProcessManager {
172 backgroundProcessManager := &BackgroundProcessManager{
173 deathChannel: make(chan string, 1),
174 wg: &sync.WaitGroup{},
175 identifiers: &sync.Map{},
176 processes: &sync.Map{},
177 pidPairToUid: &sync.Map{},
178 rootLogger: rootLogger.WithName("background-process-manager"),
179 metricsCollector: nil,
180 }
181
182 go func() {
183
184 for uid := range backgroundProcessManager.deathChannel {
185 process, loaded := backgroundProcessManager.processes.LoadAndDelete(uid)
186 if loaded {
187 proc := process.(*Process)
188 backgroundProcessManager.pidPairToUid.Delete(proc.Pair)
189 if proc.Cmd.Identifier != nil {
190 backgroundProcessManager.identifiers.Delete(*proc.Cmd.Identifier)
191 }
192 proc.stopped()
193 }
194 backgroundProcessManager.wg.Done()
195 }
196 }()
197
198 if registry != nil {
199 backgroundProcessManager.metricsCollector = newMetricsCollector(backgroundProcessManager, registry)
200 }
201
202 return backgroundProcessManager
203 }
204
205 func (m *BackgroundProcessManager) recycle(uid string) {
206 m.deathChannel <- uid
207 }
208
209
210 func (m *BackgroundProcessManager) StartProcess(ctx context.Context, cmd *ManagedCommand) (*Process, error) {
211 log := m.getLoggerFromContext(ctx)
212 if cmd.Identifier != nil {
213 _, loaded := m.identifiers.LoadOrStore(*cmd.Identifier, true)
214 if loaded {
215 return nil, errors.Errorf("process with identifier %s is running", *cmd.Identifier)
216 }
217 }
218
219 process, err := startProcess(cmd)
220 if err != nil {
221 return nil, err
222 }
223
224 m.processes.Store(process.Uid, process)
225 m.pidPairToUid.Store(process.Pair, process.Uid)
226
227
228 if m.metricsCollector != nil {
229 m.metricsCollector.bpmControlledProcessTotal.Inc()
230 }
231
232 m.wg.Add(1)
233 log = log.WithValues("uid", process.Uid, "pid", process.Pair.Pid)
234
235 go func() {
236 err := cmd.Wait()
237 if err != nil {
238 if exitErr, ok := err.(*exec.ExitError); ok {
239 status := exitErr.Sys().(syscall.WaitStatus)
240 if status.Signaled() && status.Signal() == syscall.SIGTERM {
241 log.Info("process stopped with SIGTERM signal")
242 }
243 } else {
244 log.Error(err, "process exited accidentally")
245 }
246 }
247 log.Info("process stopped")
248 m.recycle(process.Uid)
249 }()
250
251 return process, nil
252 }
253
254 func (m *BackgroundProcessManager) Shutdown(ctx context.Context) {
255 log := m.getLoggerFromContext(ctx)
256
257 m.processes.Range(func(_, value interface{}) bool {
258 process := value.(*Process)
259 log := log.WithValues("uid", process.Uid, "pid", process.Pair.Pid)
260 if err := process.Cmd.Process.Signal(syscall.SIGTERM); err != nil {
261 log.Error(err, "send SIGTERM to process")
262 return true
263 }
264 return true
265 })
266 m.wg.Wait()
267 close(m.deathChannel)
268 }
269
270 func (m *BackgroundProcessManager) GetUID(pair ProcessPair) (string, bool) {
271 if uid, loaded := m.pidPairToUid.Load(pair); loaded {
272 return uid.(string), true
273 }
274 return "", false
275 }
276
277 func (m *BackgroundProcessManager) getProc(uid string) (*Process, bool) {
278 if proc, loaded := m.processes.Load(uid); loaded {
279 return proc.(*Process), true
280 }
281 return nil, false
282 }
283
284 func (m *BackgroundProcessManager) GetPipes(uid string) (Pipes, bool) {
285 proc, ok := m.getProc(uid)
286 if !ok {
287 return Pipes{}, false
288 }
289 return proc.Pipes, true
290 }
291
292
293 func (m *BackgroundProcessManager) KillBackgroundProcess(ctx context.Context, uid string) error {
294 log := m.getLoggerFromContext(ctx)
295
296 log = log.WithValues("uid", uid)
297
298 proc, loaded := m.getProc(uid)
299 if !loaded {
300 return errors.Errorf("failed to find process with uid %s", uid)
301 }
302
303 if err := proc.Cmd.Process.Signal(syscall.SIGTERM); err != nil {
304 return errors.Wrap(err, "send SIGTERM to process")
305 }
306
307 select {
308 case <-proc.Stopped():
309 log.Info("Successfully killed process")
310 case <-ctx.Done():
311 if err := ctx.Err(); err != nil {
312 return errors.Wrap(err, "context closed")
313 }
314 }
315 return nil
316 }
317
318
319 func (m *BackgroundProcessManager) GetIdentifiers() []string {
320 var identifiers []string
321 m.identifiers.Range(func(key, value interface{}) bool {
322 identifiers = append(identifiers, key.(string))
323 return true
324 })
325
326 return identifiers
327 }
328
329 func (m *BackgroundProcessManager) getLoggerFromContext(ctx context.Context) logr.Logger {
330 return log.EnrichLoggerWithContext(ctx, m.rootLogger)
331 }
332
333
334 func DefaultProcessBuilder(cmd string, args ...string) *CommandBuilder {
335 return &CommandBuilder{
336 cmd: cmd,
337 args: args,
338 nsOptions: []nsOption{},
339 pause: false,
340 identifier: nil,
341 ctx: context.Background(),
342 }
343 }
344
345
// CommandBuilder collects the settings of a command through chained setter
// calls. Create one with DefaultProcessBuilder.
// NOTE(review): the code that turns a builder into a command is not visible
// in this part of the file; field notes below describe only how the fields
// are set.
type CommandBuilder struct {
	// cmd and args are the program and arguments given to DefaultProcessBuilder.
	cmd  string
	args []string
	// env holds "KEY=VALUE" entries appended by SetEnv.
	env []string

	// nsOptions holds the namespace options appended by SetNS and SetNSOpt.
	nsOptions []nsOption

	// pause is set by EnablePause, localMnt by EnableLocalMnt.
	pause    bool
	localMnt bool

	// identifier is set by SetIdentifier; nil means no identifier.
	identifier *string
	// stdin, stdout and stderr are set by SetStdin, SetStdout and SetStderr.
	stdin  io.ReadWriteCloser
	stdout io.ReadWriteCloser
	stderr io.ReadWriteCloser

	// oomScoreAdj is set by SetOOMScoreAdj.
	oomScoreAdj int

	// ctx defaults to context.Background() and is replaced by SetContext.
	ctx context.Context
}
367
368
369 func GetNsPath(pid uint32, typ NsType) string {
370 return fmt.Sprintf("%s/%d/ns/%s", DefaultProcPrefix, pid, string(typ))
371 }
372
373
374 func (b *CommandBuilder) SetEnv(key, value string) *CommandBuilder {
375 b.env = append(b.env, fmt.Sprintf("%s=%s", key, value))
376 return b
377 }
378
379
380 func (b *CommandBuilder) SetNS(pid uint32, typ NsType) *CommandBuilder {
381 return b.SetNSOpt([]nsOption{{
382 Typ: typ,
383 Path: GetNsPath(pid, typ),
384 }})
385 }
386
387
388 func (b *CommandBuilder) SetNSOpt(options []nsOption) *CommandBuilder {
389 b.nsOptions = append(b.nsOptions, options...)
390
391 return b
392 }
393
394
395
396
397
398 func (b *CommandBuilder) SetIdentifier(id string) *CommandBuilder {
399 b.identifier = &id
400
401 return b
402 }
403
404
405 func (b *CommandBuilder) EnablePause() *CommandBuilder {
406 b.pause = true
407
408 return b
409 }
410
411 func (b *CommandBuilder) EnableLocalMnt() *CommandBuilder {
412 b.localMnt = true
413
414 return b
415 }
416
417
418 func (b *CommandBuilder) SetContext(ctx context.Context) *CommandBuilder {
419 b.ctx = ctx
420
421 return b
422 }
423
424
425 func (b *CommandBuilder) SetStdin(stdin io.ReadWriteCloser) *CommandBuilder {
426 b.stdin = stdin
427
428 return b
429 }
430
431
432 func (b *CommandBuilder) SetStdout(stdout io.ReadWriteCloser) *CommandBuilder {
433 b.stdout = stdout
434
435 return b
436 }
437
438
439 func (b *CommandBuilder) SetStderr(stderr io.ReadWriteCloser) *CommandBuilder {
440 b.stderr = stderr
441
442 return b
443 }
444
445
446
447 func (b *CommandBuilder) SetOOMScoreAdj(scoreAdj int) *CommandBuilder {
448 b.oomScoreAdj = scoreAdj
449 return b
450 }
451
452 func (b *CommandBuilder) getLoggerFromContext(ctx context.Context) logr.Logger {
453
454
455 logger := log.L().WithName("background-process-manager.process-builder")
456 return log.EnrichLoggerWithContext(ctx, logger)
457 }
458
// nsOption names one namespace a command should use: the namespace kind and
// the path of the namespace file.
type nsOption struct {
	// Typ is the namespace kind.
	Typ NsType
	// Path is the namespace file path, e.g. as produced by GetNsPath.
	Path string
}
463
464
// ManagedCommand is an exec.Cmd that can be started through the
// BackgroundProcessManager, optionally tagged with an identifier.
type ManagedCommand struct {
	*exec.Cmd

	// Identifier, when non-nil, must be unique among running processes:
	// StartProcess rejects a command whose identifier is already registered,
	// and the identifier is released when the process is recycled.
	Identifier *string
}
472