...

Source file src/github.com/chaos-mesh/chaos-mesh/pkg/ptrace/ptrace_linux.go

Documentation: github.com/chaos-mesh/chaos-mesh/pkg/ptrace

     1  // Copyright 2021 Chaos Mesh Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  //go:build cgo
    16  
    17  package ptrace
    18  
    19  /*
    20  #include <stdint.h>
    21  struct iovec {
    22  	intptr_t iov_base;
    23  	size_t iov_len;
    24  };
    25  */
    26  import "C"
    27  
    28  import (
    29  	"bytes"
    30  	"debug/elf"
    31  	"fmt"
    32  	"os"
    33  	"strconv"
    34  	"strings"
    35  	"syscall"
    36  	"unsafe"
    37  
    38  	"github.com/go-logr/logr"
    39  	"github.com/pkg/errors"
    40  
    41  	"github.com/chaos-mesh/chaos-mesh/pkg/mapreader"
    42  )
    43  
// waitPidErrorMessage is the format of the error produced by waitPid when the
// value returned by waitpid differs from the pid that was waited for.
const waitPidErrorMessage = "waitpid ret value: %d"

// ptrSize is the size of a pointer, in bytes, on the current platform.
//
// On a 64-bit platform, `^uintptr(0)` is a 64-bit number with every bit set.
// After shifting it right by 63 bits only 1 is left, so we get 4 << 1 = 8.
// On a 32-bit platform nothing is left after the shift, so we get 4 << 0 = 4.
const ptrSize = 4 << uintptr(^uintptr(0)>>63)

// threadRetryLimit is the number of failed attach attempts per thread after
// which Trace stops retrying that thread.
var threadRetryLimit = 10
    52  
// TracedProgram is a program traced by ptrace
type TracedProgram struct {
	// pid is the process id that was handed to Trace; the register, memory,
	// single-step and wait helpers below all operate on it.
	pid     int
	// tids holds the thread ids Trace attached to; Detach walks this list.
	tids    []int
	// Entries is the result of mapreader.Read(pid) taken when Trace finished.
	Entries []mapreader.Entry

	// backupRegs is the register snapshot filled by Protect and written back by Restore.
	backupRegs *syscall.PtraceRegs
	// backupCode holds the bytes Protect read at the saved instruction pointer;
	// Restore writes them back to the same address.
	backupCode []byte

	// logger receives the progress and error messages emitted by the methods.
	logger logr.Logger
}
    64  
// Pid returns the pid of the traced program, i.e. the value stored when the
// TracedProgram was created.
func (p *TracedProgram) Pid() int {
	return p.pid
}
    69  
    70  func waitPid(pid int) error {
    71  	ret := waitpid(pid)
    72  	if ret == pid {
    73  		return nil
    74  	}
    75  
    76  	return errors.Errorf(waitPidErrorMessage, ret)
    77  }
    78  
    79  // Trace ptrace all threads of a process
    80  func Trace(pid int, logger logr.Logger) (*TracedProgram, error) {
    81  	traceSuccess := false
    82  
    83  	tidMap := make(map[int]bool)
    84  	retryCount := make(map[int]int)
    85  
    86  	// iterate over the thread group, until it doens't change
    87  	//
    88  	// we have tried several ways to ensure that we have stopped all the tasks:
    89  	// 1. iterating over and over again to make sure all of them are tracee
    90  	// 2. send `SIGSTOP` signal
    91  	// ...
    92  	// only the first way finally worked for every situations
    93  	for {
    94  		threads, err := os.ReadDir(fmt.Sprintf("/proc/%d/task", pid))
    95  		if err != nil {
    96  			return nil, errors.WithStack(err)
    97  		}
    98  
    99  		// judge whether `threads` is a subset of `tidMap`
   100  		subset := true
   101  
   102  		tids := make(map[int]bool)
   103  		for _, thread := range threads {
   104  			tid64, err := strconv.ParseInt(thread.Name(), 10, 32)
   105  			if err != nil {
   106  				return nil, errors.WithStack(err)
   107  			}
   108  			tid := int(tid64)
   109  
   110  			_, ok := tidMap[tid]
   111  			if ok {
   112  				tids[tid] = true
   113  				continue
   114  			}
   115  			subset = false
   116  
   117  			err = syscall.PtraceAttach(tid)
   118  			if err != nil {
   119  				_, ok := retryCount[tid]
   120  				if !ok {
   121  					retryCount[tid] = 1
   122  				} else {
   123  					retryCount[tid]++
   124  				}
   125  				if retryCount[tid] < threadRetryLimit {
   126  					logger.Info("retry attaching thread", "tid", tid, "retryCount", retryCount[tid], "limit", threadRetryLimit)
   127  					continue
   128  				}
   129  
   130  				if !strings.Contains(err.Error(), "no such process") {
   131  					return nil, errors.WithStack(err)
   132  				}
   133  				continue
   134  			}
   135  			defer func() {
   136  				if !traceSuccess {
   137  					err = syscall.PtraceDetach(tid)
   138  					if err != nil {
   139  						if !strings.Contains(err.Error(), "no such process") {
   140  							logger.Error(err, "detach failed", "tid", tid)
   141  						}
   142  					}
   143  				}
   144  			}()
   145  
   146  			err = waitPid(tid)
   147  			if err != nil {
   148  				return nil, errors.WithStack(err)
   149  			}
   150  
   151  			logger.Info("attach successfully", "tid", tid)
   152  			tids[tid] = true
   153  			tidMap[tid] = true
   154  		}
   155  
   156  		if subset {
   157  			tidMap = tids
   158  			break
   159  		}
   160  	}
   161  
   162  	var tids []int
   163  	for key := range tidMap {
   164  		tids = append(tids, key)
   165  	}
   166  
   167  	entries, err := mapreader.Read(pid)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	program := &TracedProgram{
   173  		pid:        pid,
   174  		tids:       tids,
   175  		Entries:    entries,
   176  		backupRegs: &syscall.PtraceRegs{},
   177  		backupCode: make([]byte, syscallInstrSize),
   178  		logger:     logger,
   179  	}
   180  
   181  	traceSuccess = true
   182  
   183  	return program, nil
   184  }
   185  
   186  // Detach detaches from all threads of the processes
   187  func (p *TracedProgram) Detach() error {
   188  	for _, tid := range p.tids {
   189  		p.logger.Info("detaching", "tid", tid)
   190  
   191  		err := syscall.PtraceDetach(tid)
   192  
   193  		if err != nil {
   194  			if !strings.Contains(err.Error(), "no such process") {
   195  				return errors.WithStack(err)
   196  			}
   197  		}
   198  	}
   199  
   200  	p.logger.Info("Successfully detach and rerun process", "pid", p.pid)
   201  	return nil
   202  }
   203  
   204  // Protect will backup regs and rip into fields
   205  func (p *TracedProgram) Protect() error {
   206  	err := getRegs(p.pid, p.backupRegs)
   207  	if err != nil {
   208  		return errors.WithStack(err)
   209  	}
   210  
   211  	_, err = syscall.PtracePeekData(p.pid, getIp(p.backupRegs), p.backupCode)
   212  	if err != nil {
   213  		return errors.WithStack(err)
   214  	}
   215  
   216  	return nil
   217  }
   218  
   219  // Restore will restore regs and rip from fields
   220  func (p *TracedProgram) Restore() error {
   221  	err := setRegs(p.pid, p.backupRegs)
   222  	if err != nil {
   223  		return errors.WithStack(err)
   224  	}
   225  
   226  	_, err = syscall.PtracePokeData(p.pid, getIp(p.backupRegs), p.backupCode)
   227  	if err != nil {
   228  		return errors.WithStack(err)
   229  	}
   230  
   231  	return nil
   232  }
   233  
// Wait waits until the process stops. It delegates to waitPid on p.pid and
// returns whatever error that helper reports.
func (p *TracedProgram) Wait() error {
	return waitPid(p.pid)
}
   238  
   239  // Step moves one step forward
   240  func (p *TracedProgram) Step() error {
   241  	err := syscall.PtraceSingleStep(p.pid)
   242  	if err != nil {
   243  		return errors.WithStack(err)
   244  	}
   245  
   246  	return p.Wait()
   247  }
   248  
   249  // Mmap runs mmap syscall
   250  func (p *TracedProgram) Mmap(length uint64, fd uint64) (uint64, error) {
   251  	return p.Syscall(syscall.SYS_MMAP, 0, length, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC, syscall.MAP_ANON|syscall.MAP_PRIVATE, fd, 0)
   252  }
   253  
   254  // ReadSlice reads from addr and return a slice
   255  func (p *TracedProgram) ReadSlice(addr uint64, size uint64) (*[]byte, error) {
   256  	buffer := make([]byte, size)
   257  
   258  	localIov := C.struct_iovec{
   259  		iov_base: C.long(uintptr(unsafe.Pointer(&buffer[0]))),
   260  		iov_len:  C.ulong(size),
   261  	}
   262  
   263  	remoteIov := C.struct_iovec{
   264  		iov_base: C.long(addr),
   265  		iov_len:  C.ulong(size),
   266  	}
   267  
   268  	_, _, errno := syscall.Syscall6(nrProcessVMReadv, uintptr(p.pid), uintptr(unsafe.Pointer(&localIov)), uintptr(1), uintptr(unsafe.Pointer(&remoteIov)), uintptr(1), uintptr(0))
   269  	if errno != 0 {
   270  		return nil, errors.WithStack(errno)
   271  	}
   272  	// TODO: check size and warn
   273  
   274  	return &buffer, nil
   275  }
   276  
   277  // WriteSlice writes a buffer into addr
   278  func (p *TracedProgram) WriteSlice(addr uint64, buffer []byte) error {
   279  	size := len(buffer)
   280  
   281  	localIov := C.struct_iovec{
   282  		iov_base: C.long(uintptr(unsafe.Pointer(&buffer[0]))),
   283  		iov_len:  C.ulong(size),
   284  	}
   285  
   286  	remoteIov := C.struct_iovec{
   287  		iov_base: C.long(addr),
   288  		iov_len:  C.ulong(size),
   289  	}
   290  
   291  	_, _, errno := syscall.Syscall6(nrProcessVMWritev, uintptr(p.pid), uintptr(unsafe.Pointer(&localIov)), uintptr(1), uintptr(unsafe.Pointer(&remoteIov)), uintptr(1), uintptr(0))
   292  	if errno != 0 {
   293  		return errors.WithStack(errno)
   294  	}
   295  	// TODO: check size and warn
   296  
   297  	return nil
   298  }
   299  
   300  func alignBuffer(buffer []byte) []byte {
   301  	if buffer == nil {
   302  		return nil
   303  	}
   304  
   305  	alignedSize := (len(buffer) / ptrSize) * ptrSize
   306  	if alignedSize < len(buffer) {
   307  		alignedSize += ptrSize
   308  	}
   309  	clonedBuffer := make([]byte, alignedSize)
   310  	copy(clonedBuffer, buffer)
   311  
   312  	return clonedBuffer
   313  }
   314  
   315  // PtraceWriteSlice uses ptrace rather than process_vm_write to write a buffer into addr
   316  func (p *TracedProgram) PtraceWriteSlice(addr uint64, buffer []byte) error {
   317  	wroteSize := 0
   318  
   319  	buffer = alignBuffer(buffer)
   320  
   321  	for wroteSize+ptrSize <= len(buffer) {
   322  		addr := uintptr(addr + uint64(wroteSize))
   323  		data := buffer[wroteSize : wroteSize+ptrSize]
   324  
   325  		_, err := syscall.PtracePokeData(p.pid, addr, data)
   326  		if err != nil {
   327  			err = errors.WithStack(err)
   328  			return errors.WithMessagef(err, "write to addr %x with %+v failed", addr, data)
   329  		}
   330  
   331  		wroteSize += ptrSize
   332  	}
   333  
   334  	return nil
   335  }
   336  
   337  // GetLibBuffer reads an entry
   338  func (p *TracedProgram) GetLibBuffer(entry *mapreader.Entry) (*[]byte, error) {
   339  	if entry.PaddingSize > 0 {
   340  		return nil, errors.New("entry with padding size is not supported")
   341  	}
   342  
   343  	size := entry.EndAddress - entry.StartAddress
   344  
   345  	return p.ReadSlice(entry.StartAddress, size)
   346  }
   347  
   348  // MmapSlice mmaps a slice and return it's addr
   349  func (p *TracedProgram) MmapSlice(slice []byte) (*mapreader.Entry, error) {
   350  	size := uint64(len(slice))
   351  
   352  	addr, err := p.Mmap(size, 0)
   353  	if err != nil {
   354  		return nil, errors.WithStack(err)
   355  	}
   356  
   357  	err = p.WriteSlice(addr, slice)
   358  	if err != nil {
   359  		return nil, errors.WithStack(err)
   360  	}
   361  
   362  	return &mapreader.Entry{
   363  		StartAddress: addr,
   364  		EndAddress:   addr + size,
   365  		Privilege:    "rwxp",
   366  		PaddingSize:  0,
   367  		Path:         "",
   368  	}, nil
   369  }
   370  
   371  // FindSymbolInEntry finds symbol in entry through parsing elf
   372  func (p *TracedProgram) FindSymbolInEntry(symbolName string, entry *mapreader.Entry) (uint64, uint64, error) {
   373  	libBuffer, err := p.GetLibBuffer(entry)
   374  	if err != nil {
   375  		return 0, 0, err
   376  	}
   377  
   378  	reader := bytes.NewReader(*libBuffer)
   379  	vdsoElf, err := elf.NewFile(reader)
   380  	if err != nil {
   381  		return 0, 0, errors.WithStack(err)
   382  	}
   383  
   384  	loadOffset := uint64(0)
   385  
   386  	for _, prog := range vdsoElf.Progs {
   387  		if prog.Type == elf.PT_LOAD {
   388  			loadOffset = prog.Vaddr - prog.Off
   389  
   390  			// break here is enough for vdso
   391  			break
   392  		}
   393  	}
   394  
   395  	symbols, err := vdsoElf.DynamicSymbols()
   396  	if err != nil {
   397  		return 0, 0, errors.WithStack(err)
   398  	}
   399  	for _, symbol := range symbols {
   400  		if strings.Contains(symbol.Name, symbolName) {
   401  			offset := symbol.Value
   402  
   403  			return entry.StartAddress + (offset - loadOffset), symbol.Size, nil
   404  		}
   405  	}
   406  	return 0, 0, errors.New("cannot find symbol")
   407  }
   408  
   409  // WriteUint64ToAddr writes uint64 to addr
   410  func (p *TracedProgram) WriteUint64ToAddr(addr uint64, value uint64) error {
   411  	valueSlice := make([]byte, 8)
   412  	endian.PutUint64(valueSlice, value)
   413  	err := p.WriteSlice(addr, valueSlice)
   414  	return err
   415  }
   416