1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package bpm
17
18 import (
19 "context"
20 "fmt"
21 "io"
22 "os"
23 "os/exec"
24 "sync"
25 "syscall"
26
27 "github.com/shirou/gopsutil/process"
28 ctrl "sigs.k8s.io/controller-runtime"
29 )
30
// log is the package logger; its entries carry the name
// "background-process-manager".
var log = ctrl.Log.WithName("background-process-manager")

// NsType names a Linux namespace kind using the same string that appears as
// the entry name under /proc/<pid>/ns (see GetNsPath).
type NsType string

// Namespace kinds for use with SetNS and GetNsPath.
const (
	// MountNS is the mount namespace.
	MountNS NsType = "mnt"

	// IpcNS is the IPC namespace.
	IpcNS NsType = "ipc"
	// NetNS is the network namespace.
	NetNS NsType = "net"
	// PidNS is the PID namespace.
	PidNS NsType = "pid"
)

// nsArgMap maps each NsType defined above to a one-letter code.
// NOTE(review): presumably the flag letter passed to the nsexec helper when
// entering that namespace; the consumer is not visible in this chunk — confirm.
var nsArgMap = map[NsType]string{
	MountNS: "m",

	IpcNS: "i",
	NetNS: "n",
	PidNS: "p",
}

const (
	// Absolute paths of helper binaries. Presumably used when a
	// ProcessBuilder is turned into a command (not visible here) — confirm.
	pausePath  = "/usr/local/bin/pause"
	nsexecPath = "/usr/local/bin/nsexec"

	// DefaultProcPrefix is the procfs mount point that GetNsPath builds
	// namespace paths under.
	DefaultProcPrefix = "/proc"
)
63
64
// ProcessPair identifies a process by its pid together with its create time
// (as reported by gopsutil's Process.CreateTime). Pairing the two lets the
// manager tell a tracked process apart from an unrelated process that later
// reuses the same pid. It is the key of the manager's deathSig and stdio maps.
type ProcessPair struct {
	Pid        int
	CreateTime int64
}

// Stdio holds those standard streams of a managed process that implement
// io.ReadWriteCloser; streams that do not are left nil. The embedded Locker
// guards the streams, and the manager holds it while closing them after the
// process has exited.
type Stdio struct {
	sync.Locker
	Stdin, Stdout, Stderr io.ReadWriteCloser
}
75
76
// BackgroundProcessManager tracks processes started through StartProcess.
// Every field is a pointer to a sync.Map, so copies of a manager value share
// the same underlying state.
type BackgroundProcessManager struct {
	deathSig    *sync.Map // ProcessPair -> chan bool, signaled once the process has exited
	identifiers *sync.Map // identifier string -> *sync.Mutex serializing same-identifier processes
	stdio       *sync.Map // ProcessPair -> *Stdio of the running process
}

// NewBackgroundProcessManager returns a manager whose three tracking maps are
// freshly allocated and independent of any other manager's.
func NewBackgroundProcessManager() BackgroundProcessManager {
	var manager BackgroundProcessManager
	manager.deathSig = new(sync.Map)
	manager.identifiers = new(sync.Map)
	manager.stdio = new(sync.Map)
	return manager
}
91
92
93 func (m *BackgroundProcessManager) StartProcess(cmd *ManagedProcess) (*process.Process, error) {
94 var identifierLock *sync.Mutex
95 if cmd.Identifier != nil {
96 lock, _ := m.identifiers.LoadOrStore(*cmd.Identifier, &sync.Mutex{})
97
98 identifierLock = lock.(*sync.Mutex)
99
100 identifierLock.Lock()
101 }
102
103 err := cmd.Start()
104 if err != nil {
105 log.Error(err, "fail to start process")
106 return nil, err
107 }
108
109 pid := cmd.Process.Pid
110 procState, err := process.NewProcess(int32(cmd.Process.Pid))
111 if err != nil {
112 return nil, err
113 }
114 ct, err := procState.CreateTime()
115 if err != nil {
116 return nil, err
117 }
118
119 pair := ProcessPair{
120 Pid: pid,
121 CreateTime: ct,
122 }
123
124 channel, _ := m.deathSig.LoadOrStore(pair, make(chan bool, 1))
125 deathChannel := channel.(chan bool)
126
127 stdio := &Stdio{Locker: &sync.Mutex{}}
128 if cmd.Stdin != nil {
129 if stdin, ok := cmd.Stdin.(io.ReadWriteCloser); ok {
130 stdio.Stdin = stdin
131 }
132 }
133
134 if cmd.Stdout != nil {
135 if stdout, ok := cmd.Stdout.(io.ReadWriteCloser); ok {
136 stdio.Stdout = stdout
137 }
138 }
139
140 if cmd.Stderr != nil {
141 if stderr, ok := cmd.Stderr.(io.ReadWriteCloser); ok {
142 stdio.Stderr = stderr
143 }
144 }
145
146 m.stdio.Store(pair, stdio)
147
148 log := log.WithValues("pid", pid)
149
150 go func() {
151 err := cmd.Wait()
152 if err != nil {
153 if exitErr, ok := err.(*exec.ExitError); ok {
154 status := exitErr.Sys().(syscall.WaitStatus)
155 if status.Signaled() && status.Signal() == syscall.SIGTERM {
156 log.Info("process stopped with SIGTERM signal")
157 }
158 } else {
159 log.Error(err, "process exited accidentally")
160 }
161 }
162
163 log.Info("process stopped")
164
165 deathChannel <- true
166 m.deathSig.Delete(pair)
167 if io, loaded := m.stdio.LoadAndDelete(pair); loaded {
168 if stdio, ok := io.(*Stdio); ok {
169 stdio.Lock()
170 if stdio.Stdin != nil {
171 if err = stdio.Stdin.Close(); err != nil {
172 log.Error(err, "stdin fails to be closed")
173 }
174 }
175 if stdio.Stdout != nil {
176 if err = stdio.Stdout.Close(); err != nil {
177 log.Error(err, "stdout fails to be closed")
178 }
179 }
180 if stdio.Stderr != nil {
181 if err = stdio.Stderr.Close(); err != nil {
182 log.Error(err, "stderr fails to be closed")
183 }
184 }
185 stdio.Unlock()
186 }
187 }
188
189 if identifierLock != nil {
190 identifierLock.Unlock()
191 m.identifiers.Delete(*cmd.Identifier)
192 }
193 }()
194
195 return procState, nil
196 }
197
198
199 func (m *BackgroundProcessManager) KillBackgroundProcess(ctx context.Context, pid int, startTime int64) error {
200 log := log.WithValues("pid", pid)
201
202 p, err := os.FindProcess(int(pid))
203 if err != nil {
204 log.Error(err, "unreachable path. `os.FindProcess` will never return an error on unix")
205 return err
206 }
207
208 procState, err := process.NewProcess(int32(pid))
209 if err != nil {
210
211 return nil
212 }
213 ct, err := procState.CreateTime()
214 if err != nil {
215 log.Error(err, "fail to read create time")
216
217 return nil
218 }
219
220
221
222 if startTime-ct > 1000 || ct-startTime > 1000 {
223 log.Info("process has already been killed", "startTime", ct, "expectedStartTime", startTime)
224
225 return nil
226 }
227
228 ppid, err := procState.Ppid()
229 if err != nil {
230 log.Error(err, "fail to read parent id")
231
232 return nil
233 }
234 if ppid != int32(os.Getpid()) {
235 log.Info("process has already been killed", "ppid", ppid)
236
237 return nil
238 }
239
240 err = p.Signal(syscall.SIGTERM)
241
242 if err != nil && err.Error() != "os: process already finished" {
243 log.Error(err, "error while killing process")
244 return err
245 }
246
247 pair := ProcessPair{
248 Pid: pid,
249 CreateTime: startTime,
250 }
251 channel, ok := m.deathSig.Load(pair)
252 if ok {
253 deathChannel := channel.(chan bool)
254 select {
255 case <-deathChannel:
256 case <-ctx.Done():
257 return ctx.Err()
258 }
259 }
260
261 log.Info("Successfully killed process")
262 return nil
263 }
264
265 func (m *BackgroundProcessManager) Stdio(pid int, startTime int64) *Stdio {
266 log := log.WithValues("pid", pid)
267
268 procState, err := process.NewProcess(int32(pid))
269 if err != nil {
270 log.Info("fail to get process information", "pid", pid)
271
272 return nil
273 }
274 ct, err := procState.CreateTime()
275 if err != nil {
276 log.Error(err, "fail to read create time")
277
278 return nil
279 }
280
281
282
283 if startTime-ct > 1000 || ct-startTime > 1000 {
284 log.Info("process has exited", "startTime", ct, "expectedStartTime", startTime)
285
286 return nil
287 }
288
289 pair := ProcessPair{
290 Pid: pid,
291 CreateTime: startTime,
292 }
293
294 io, ok := m.stdio.Load(pair)
295 if !ok {
296 log.Info("fail to load with pair", "pair", pair)
297
298 return nil
299 }
300
301 return io.(*Stdio)
302 }
303
304
305 func DefaultProcessBuilder(cmd string, args ...string) *ProcessBuilder {
306 return &ProcessBuilder{
307 cmd: cmd,
308 args: args,
309 nsOptions: []nsOption{},
310 pause: false,
311 identifier: nil,
312 ctx: context.Background(),
313 }
314 }
315
316
// ProcessBuilder accumulates the settings for a process to be constructed:
//   - command and arguments;
//   - extra environment entries;
//   - namespace options;
//   - pause and local-mount flags;
//   - an optional identifier;
//   - stdio streams and a context.
//
// Create one with DefaultProcessBuilder and configure it with the chained
// Set*/Enable* methods. The step that turns it into a process is not in this
// chunk.
type ProcessBuilder struct {
	cmd  string
	args []string
	env  []string // "key=value" entries appended by SetEnv

	nsOptions []nsOption // appended by SetNS / SetNSOpt

	pause    bool // set by EnablePause
	localMnt bool // set by EnableLocalMnt

	identifier *string // set by SetIdentifier; nil means no identifier
	stdin      io.ReadWriteCloser
	stdout     io.ReadWriteCloser
	stderr     io.ReadWriteCloser

	ctx context.Context // context.Background() unless SetContext is called
}
334
335
336 func GetNsPath(pid uint32, typ NsType) string {
337 return fmt.Sprintf("%s/%d/ns/%s", DefaultProcPrefix, pid, string(typ))
338 }
339
340
341 func (b *ProcessBuilder) SetEnv(key, value string) *ProcessBuilder {
342 b.env = append(b.env, fmt.Sprintf("%s=%s", key, value))
343 return b
344 }
345
346
347 func (b *ProcessBuilder) SetNS(pid uint32, typ NsType) *ProcessBuilder {
348 return b.SetNSOpt([]nsOption{{
349 Typ: typ,
350 Path: GetNsPath(pid, typ),
351 }})
352 }
353
354
355 func (b *ProcessBuilder) SetNSOpt(options []nsOption) *ProcessBuilder {
356 b.nsOptions = append(b.nsOptions, options...)
357
358 return b
359 }
360
361
362 func (b *ProcessBuilder) SetIdentifier(id string) *ProcessBuilder {
363 b.identifier = &id
364
365 return b
366 }
367
368
369 func (b *ProcessBuilder) EnablePause() *ProcessBuilder {
370 b.pause = true
371
372 return b
373 }
374
375 func (b *ProcessBuilder) EnableLocalMnt() *ProcessBuilder {
376 b.localMnt = true
377
378 return b
379 }
380
381
382 func (b *ProcessBuilder) SetContext(ctx context.Context) *ProcessBuilder {
383 b.ctx = ctx
384
385 return b
386 }
387
388
389 func (b *ProcessBuilder) SetStdin(stdin io.ReadWriteCloser) *ProcessBuilder {
390 b.stdin = stdin
391
392 return b
393 }
394
395
396 func (b *ProcessBuilder) SetStdout(stdout io.ReadWriteCloser) *ProcessBuilder {
397 b.stdout = stdout
398
399 return b
400 }
401
402
403 func (b *ProcessBuilder) SetStderr(stderr io.ReadWriteCloser) *ProcessBuilder {
404 b.stderr = stderr
405
406 return b
407 }
408
// nsOption pairs a namespace kind with the filesystem path of that namespace
// (SetNS fills Path via GetNsPath). NOTE(review): presumably the built
// process enters the namespace at Path; that step is not in this chunk.
type nsOption struct {
	Typ  NsType
	Path string
}
413
414
// ManagedProcess is an exec.Cmd together with an optional Identifier, as
// accepted by BackgroundProcessManager.StartProcess.
type ManagedProcess struct {
	*exec.Cmd

	// Identifier, when non-nil, makes StartProcess serialize this process
	// with others sharing the same identifier: it holds a per-identifier
	// mutex from start until the process exits.
	Identifier *string
}
422