1
2
3
4
5
6
7
8
9
10
11
12
13
14 package bpm
15
16 import (
17 "context"
18 "fmt"
19 "io"
20 "os"
21 "os/exec"
22 "sync"
23 "syscall"
24
25 "github.com/shirou/gopsutil/process"
26 ctrl "sigs.k8s.io/controller-runtime"
27 )
28
// log is the package-level logger; every entry it emits is tagged with the
// name "background-process-manager".
var log = ctrl.Log.WithName("background-process-manager")
30
// NsType names a Linux namespace kind. Its value is the entry name under
// /proc/<pid>/ns/ (see GetNsPath).
type NsType string

// MountNS is the mount namespace.
const MountNS NsType = "mnt"

// IpcNS is the IPC namespace.
const IpcNS NsType = "ipc"

// NetNS is the network namespace.
const NetNS NsType = "net"

// PidNS is the PID namespace.
const PidNS NsType = "pid"

// nsArgMap maps every supported namespace kind to a single-letter flag name.
// NOTE(review): presumably these letters are the flags understood by the
// nsexec helper — confirm at the call site that builds its argument list.
var nsArgMap = map[NsType]string{
	PidNS:   "p",
	NetNS:   "n",
	IpcNS:   "i",
	MountNS: "m",
}
54
// pausePath is the absolute path of the pause helper binary.
const pausePath = "/usr/local/bin/pause"

// nsexecPath is the absolute path of the nsexec helper binary.
const nsexecPath = "/usr/local/bin/nsexec"

// DefaultProcPrefix is the procfs root that GetNsPath builds namespace paths
// under.
const DefaultProcPrefix = "/proc"
61
62
// ProcessPair is the key under which BackgroundProcessManager tracks a
// process: its pid together with its create time. Including the create time
// means a different process that later reuses the same pid gets a different key.
type ProcessPair struct {
	Pid        int
	CreateTime int64
}

// Stdio holds the standard streams of a managed process. A stream is nil when
// the corresponding exec.Cmd stream was not an io.ReadWriteCloser.
// The embedded Locker guards the streams: StartProcess's reaper holds it while
// closing them after the process exits (callers presumably should hold it too
// while using them).
type Stdio struct {
	sync.Locker

	Stdin  io.ReadWriteCloser
	Stdout io.ReadWriteCloser
	Stderr io.ReadWriteCloser
}

// BackgroundProcessManager starts processes and tracks them until they exit.
//   - deathSig:    ProcessPair -> chan bool, signalled when the process exits.
//   - identifiers: identifier string -> *sync.Mutex serializing processes
//     that share an identifier.
//   - stdio:       ProcessPair -> *Stdio.
type BackgroundProcessManager struct {
	deathSig    *sync.Map
	identifiers *sync.Map
	stdio       *sync.Map
}

// NewBackgroundProcessManager returns a manager whose three bookkeeping maps
// are allocated and empty.
func NewBackgroundProcessManager() BackgroundProcessManager {
	var manager BackgroundProcessManager
	manager.deathSig = new(sync.Map)
	manager.identifiers = new(sync.Map)
	manager.stdio = new(sync.Map)
	return manager
}
89
90
91 func (m *BackgroundProcessManager) StartProcess(cmd *ManagedProcess) (*process.Process, error) {
92 var identifierLock *sync.Mutex
93 if cmd.Identifier != nil {
94 lock, _ := m.identifiers.LoadOrStore(*cmd.Identifier, &sync.Mutex{})
95
96 identifierLock = lock.(*sync.Mutex)
97
98 identifierLock.Lock()
99 }
100
101 err := cmd.Start()
102 if err != nil {
103 log.Error(err, "fail to start process")
104 return nil, err
105 }
106
107 pid := cmd.Process.Pid
108 procState, err := process.NewProcess(int32(cmd.Process.Pid))
109 if err != nil {
110 return nil, err
111 }
112 ct, err := procState.CreateTime()
113 if err != nil {
114 return nil, err
115 }
116
117 pair := ProcessPair{
118 Pid: pid,
119 CreateTime: ct,
120 }
121
122 channel, _ := m.deathSig.LoadOrStore(pair, make(chan bool, 1))
123 deathChannel := channel.(chan bool)
124
125 stdio := &Stdio{Locker: &sync.Mutex{}}
126 if cmd.Stdin != nil {
127 if stdin, ok := cmd.Stdin.(io.ReadWriteCloser); ok {
128 stdio.Stdin = stdin
129 }
130 }
131
132 if cmd.Stdout != nil {
133 if stdout, ok := cmd.Stdout.(io.ReadWriteCloser); ok {
134 stdio.Stdout = stdout
135 }
136 }
137
138 if cmd.Stderr != nil {
139 if stderr, ok := cmd.Stderr.(io.ReadWriteCloser); ok {
140 stdio.Stderr = stderr
141 }
142 }
143
144 m.stdio.Store(pair, stdio)
145
146 log := log.WithValues("pid", pid)
147
148 go func() {
149 err := cmd.Wait()
150 if err != nil {
151 if exitErr, ok := err.(*exec.ExitError); ok {
152 status := exitErr.Sys().(syscall.WaitStatus)
153 if status.Signaled() && status.Signal() == syscall.SIGTERM {
154 log.Info("process stopped with SIGTERM signal")
155 }
156 } else {
157 log.Error(err, "process exited accidentally")
158 }
159 }
160
161 log.Info("process stopped")
162
163 deathChannel <- true
164 m.deathSig.Delete(pair)
165 if io, loaded := m.stdio.LoadAndDelete(pair); loaded {
166 if stdio, ok := io.(*Stdio); ok {
167 stdio.Lock()
168 if stdio.Stdin != nil {
169 if err = stdio.Stdin.Close(); err != nil {
170 log.Error(err, "stdin fails to be closed")
171 }
172 }
173 if stdio.Stdout != nil {
174 if err = stdio.Stdout.Close(); err != nil {
175 log.Error(err, "stdout fails to be closed")
176 }
177 }
178 if stdio.Stderr != nil {
179 if err = stdio.Stderr.Close(); err != nil {
180 log.Error(err, "stderr fails to be closed")
181 }
182 }
183 stdio.Unlock()
184 }
185 }
186
187 if identifierLock != nil {
188 identifierLock.Unlock()
189 m.identifiers.Delete(*cmd.Identifier)
190 }
191 }()
192
193 return procState, nil
194 }
195
196
197 func (m *BackgroundProcessManager) KillBackgroundProcess(ctx context.Context, pid int, startTime int64) error {
198 log := log.WithValues("pid", pid)
199
200 p, err := os.FindProcess(int(pid))
201 if err != nil {
202 log.Error(err, "unreachable path. `os.FindProcess` will never return an error on unix")
203 return err
204 }
205
206 procState, err := process.NewProcess(int32(pid))
207 if err != nil {
208
209 return nil
210 }
211 ct, err := procState.CreateTime()
212 if err != nil {
213 log.Error(err, "fail to read create time")
214
215 return nil
216 }
217
218
219
220 if startTime-ct > 1000 || ct-startTime > 1000 {
221 log.Info("process has already been killed", "startTime", ct, "expectedStartTime", startTime)
222
223 return nil
224 }
225
226 ppid, err := procState.Ppid()
227 if err != nil {
228 log.Error(err, "fail to read parent id")
229
230 return nil
231 }
232 if ppid != int32(os.Getpid()) {
233 log.Info("process has already been killed", "ppid", ppid)
234
235 return nil
236 }
237
238 err = p.Signal(syscall.SIGTERM)
239
240 if err != nil && err.Error() != "os: process already finished" {
241 log.Error(err, "error while killing process")
242 return err
243 }
244
245 pair := ProcessPair{
246 Pid: pid,
247 CreateTime: startTime,
248 }
249 channel, ok := m.deathSig.Load(pair)
250 if ok {
251 deathChannel := channel.(chan bool)
252 select {
253 case <-deathChannel:
254 case <-ctx.Done():
255 return ctx.Err()
256 }
257 }
258
259 log.Info("Successfully killed process")
260 return nil
261 }
262
263 func (m *BackgroundProcessManager) Stdio(pid int, startTime int64) *Stdio {
264 log := log.WithValues("pid", pid)
265
266 procState, err := process.NewProcess(int32(pid))
267 if err != nil {
268 log.Info("fail to get process information", "pid", pid)
269
270 return nil
271 }
272 ct, err := procState.CreateTime()
273 if err != nil {
274 log.Error(err, "fail to read create time")
275
276 return nil
277 }
278
279
280
281 if startTime-ct > 1000 || ct-startTime > 1000 {
282 log.Info("process has exited", "startTime", ct, "expectedStartTime", startTime)
283
284 return nil
285 }
286
287 pair := ProcessPair{
288 Pid: pid,
289 CreateTime: startTime,
290 }
291
292 io, ok := m.stdio.Load(pair)
293 if !ok {
294 log.Info("fail to load with pair", "pair", pair)
295
296 return nil
297 }
298
299 return io.(*Stdio)
300 }
301
302
303 func DefaultProcessBuilder(cmd string, args ...string) *ProcessBuilder {
304 return &ProcessBuilder{
305 cmd: cmd,
306 args: args,
307 nsOptions: []nsOption{},
308 pause: false,
309 identifier: nil,
310 ctx: context.Background(),
311 }
312 }
313
314
315 type ProcessBuilder struct {
316 cmd string
317 args []string
318 env []string
319
320 nsOptions []nsOption
321
322 pause bool
323 localMnt bool
324
325 identifier *string
326 stdin io.ReadWriteCloser
327 stdout io.ReadWriteCloser
328 stderr io.ReadWriteCloser
329
330 ctx context.Context
331 }
332
333
334 func GetNsPath(pid uint32, typ NsType) string {
335 return fmt.Sprintf("%s/%d/ns/%s", DefaultProcPrefix, pid, string(typ))
336 }
337
338
339 func (b *ProcessBuilder) SetEnv(key, value string) *ProcessBuilder {
340 b.env = append(b.env, fmt.Sprintf("%s=%s", key, value))
341 return b
342 }
343
344
345 func (b *ProcessBuilder) SetNS(pid uint32, typ NsType) *ProcessBuilder {
346 return b.SetNSOpt([]nsOption{{
347 Typ: typ,
348 Path: GetNsPath(pid, typ),
349 }})
350 }
351
352
353 func (b *ProcessBuilder) SetNSOpt(options []nsOption) *ProcessBuilder {
354 b.nsOptions = append(b.nsOptions, options...)
355
356 return b
357 }
358
359
360
361
362
363 func (b *ProcessBuilder) SetIdentifier(id string) *ProcessBuilder {
364 b.identifier = &id
365
366 return b
367 }
368
369
370 func (b *ProcessBuilder) EnablePause() *ProcessBuilder {
371 b.pause = true
372
373 return b
374 }
375
376 func (b *ProcessBuilder) EnableLocalMnt() *ProcessBuilder {
377 b.localMnt = true
378
379 return b
380 }
381
382
383 func (b *ProcessBuilder) SetContext(ctx context.Context) *ProcessBuilder {
384 b.ctx = ctx
385
386 return b
387 }
388
389
390 func (b *ProcessBuilder) SetStdin(stdin io.ReadWriteCloser) *ProcessBuilder {
391 b.stdin = stdin
392
393 return b
394 }
395
396
397 func (b *ProcessBuilder) SetStdout(stdout io.ReadWriteCloser) *ProcessBuilder {
398 b.stdout = stdout
399
400 return b
401 }
402
403
404 func (b *ProcessBuilder) SetStderr(stderr io.ReadWriteCloser) *ProcessBuilder {
405 b.stderr = stderr
406
407 return b
408 }
409
410 type nsOption struct {
411 Typ NsType
412 Path string
413 }
414
415
// ManagedProcess is an exec.Cmd as accepted by
// BackgroundProcessManager.StartProcess.
type ManagedProcess struct {
	*exec.Cmd

	// Identifier, when non-nil, makes StartProcess serialize this process
	// with other processes started under the same identifier.
	Identifier *string
}
423