1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package physicalmachine
17
import (
	"fmt"
	"net"
	"os"
	"path"
	"path/filepath"
	"syscall"
	"time"

	"github.com/pkg/errors"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/knownhosts"
	"golang.org/x/term"
)
32
33 type SshTunnel struct {
34 config *ssh.ClientConfig
35 host string
36 port string
37 client *ssh.Client
38 }
39
40 func NewSshTunnel(ip, port string, user, privateKeyFile string) (*SshTunnel, error) {
41 hostKeyCallback, err := knownhosts.New(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
42 if err != nil {
43 return nil, err
44 }
45 config := ssh.ClientConfig{
46 Timeout: 5 * time.Minute,
47 User: user,
48 Auth: []ssh.AuthMethod{
49 ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
50 key, err := os.ReadFile(privateKeyFile)
51 if err != nil {
52 return nil, errors.Wrap(err, "ssh key file read failed")
53 }
54
55 signer, err := ssh.ParsePrivateKey(key)
56 if err != nil {
57 return nil, errors.Wrap(err, "ssh key signer failed")
58 }
59 return []ssh.Signer{signer}, nil
60 }),
61 ssh.PasswordCallback(func() (secret string, err error) {
62 fmt.Printf("please input the password: ")
63 password, err := term.ReadPassword(int(syscall.Stdin))
64 if err != nil {
65 return "", errors.Wrap(err, "read ssh password failed")
66 }
67 return string(password), nil
68 }),
69 },
70 HostKeyCallback: hostKeyCallback,
71 }
72 return &SshTunnel{
73 config: &config,
74 host: ip,
75 port: port,
76 }, nil
77 }
78
79 func (s *SshTunnel) Open() error {
80 conn, err := ssh.Dial("tcp", net.JoinHostPort(s.host, s.port), s.config)
81 if err != nil {
82 return errors.Wrap(err, "open ssh tunnel failed")
83 }
84 s.client = conn
85 return nil
86 }
87
88 func (s *SshTunnel) Close() error {
89 if s.client == nil {
90 return nil
91 }
92 return s.client.Close()
93 }
94
95 func (s *SshTunnel) SFTP(filename string, data []byte) error {
96 if s.client == nil {
97 return errors.New("tunnel is not opened")
98 }
99
100
101 client, err := sftp.NewClient(s.client)
102 if err != nil {
103 return errors.Wrap(err, "create sftp client failed")
104 }
105 defer client.Close()
106
107 if err := client.MkdirAll(filepath.Dir(filename)); err != nil {
108 return errors.Wrapf(err, "make directory %s failed", filepath.Dir(filename))
109 }
110
111 f, err := client.Create(filename)
112 if err != nil {
113 return errors.Wrapf(err, "create file %s failed", filename)
114 }
115 defer f.Close()
116
117 if _, err := f.Write(data); err != nil {
118 return errors.Wrapf(err, "write file %s failed", filename)
119 }
120 return nil
121 }
122