1
2
3
4
5
6
7
8
9
10
11
12
13
14 package grpc
15
16 import (
17 "context"
18 "crypto/tls"
19 "crypto/x509"
20 "fmt"
21 "io/ioutil"
22 "net"
23 "strconv"
24 "time"
25
26 "google.golang.org/grpc"
27 "google.golang.org/grpc/credentials"
28 )
29
30
31 const DefaultRPCTimeout = 60 * time.Second
32
33
34 var RPCTimeout = DefaultRPCTimeout
35
36 const ChaosDaemonServerName = "chaos-daemon.chaos-mesh.org"
37
// TLSRaw carries TLS material as in-memory byte slices.
//
// RawProvider feeds CaCert to an x509 certificate pool as PEM data and passes
// Cert and Key to tls.X509KeyPair as the client key pair.
type TLSRaw struct {
	CaCert, Cert, Key []byte
}
43
// TLSFile carries TLS material as file paths.
//
// FileProvider reads the file at CaCert to populate the root certificate pool
// and hands Cert and Key to tls.LoadX509KeyPair as the client key pair.
type TLSFile struct {
	CaCert, Cert, Key string
}
49
50 type FileProvider struct {
51 file TLSFile
52 }
53
54 type RawProvider struct {
55 raw TLSRaw
56 }
57
// InsecureProvider is the CredentialProvider whose dial option is
// grpc.WithInsecure(). It carries no state.
type InsecureProvider struct{}
60
61 type CredentialProvider interface {
62 getCredentialOption() (grpc.DialOption, error)
63 }
64
65 func (it *FileProvider) getCredentialOption() (grpc.DialOption, error) {
66 caCert, err := ioutil.ReadFile(it.file.CaCert)
67 if err != nil {
68 return nil, err
69 }
70 caCertPool := x509.NewCertPool()
71 caCertPool.AppendCertsFromPEM(caCert)
72
73 clientCert, err := tls.LoadX509KeyPair(it.file.Cert, it.file.Key)
74 if err != nil {
75 return nil, err
76 }
77
78 creds := credentials.NewTLS(&tls.Config{
79 Certificates: []tls.Certificate{clientCert},
80 RootCAs: caCertPool,
81 ServerName: ChaosDaemonServerName,
82 })
83 return grpc.WithTransportCredentials(creds), nil
84 }
85
86 func (it *RawProvider) getCredentialOption() (grpc.DialOption, error) {
87 caCertPool := x509.NewCertPool()
88 caCertPool.AppendCertsFromPEM(it.raw.CaCert)
89
90 clientCert, err := tls.X509KeyPair(it.raw.Cert, it.raw.Key)
91 if err != nil {
92 return nil, err
93 }
94
95 creds := credentials.NewTLS(&tls.Config{
96 Certificates: []tls.Certificate{clientCert},
97 RootCAs: caCertPool,
98 ServerName: ChaosDaemonServerName,
99 })
100 return grpc.WithTransportCredentials(creds), nil
101 }
102
103 func (it *InsecureProvider) getCredentialOption() (grpc.DialOption, error) {
104 return grpc.WithInsecure(), nil
105 }
106
107 type GrpcBuilder struct {
108 options []grpc.DialOption
109 credentialProvider CredentialProvider
110 address string
111 port int
112 }
113
114 func Builder(address string, port int) *GrpcBuilder {
115 return &GrpcBuilder{options: []grpc.DialOption{}, address: address, port: port}
116 }
117
118 func (it *GrpcBuilder) WithDefaultTimeout() *GrpcBuilder {
119 it.options = append(it.options, grpc.WithUnaryInterceptor(TimeoutClientInterceptor(DefaultRPCTimeout)))
120 return it
121 }
122
123 func (it *GrpcBuilder) WithTimeout(timeout time.Duration) *GrpcBuilder {
124 it.options = append(it.options, grpc.WithUnaryInterceptor(TimeoutClientInterceptor(timeout)))
125 return it
126 }
127
128 func (it *GrpcBuilder) Insecure() *GrpcBuilder {
129 it.credentialProvider = &InsecureProvider{}
130 return it
131 }
132
133 func (it *GrpcBuilder) TLSFromRaw(caCert []byte, cert []byte, key []byte) *GrpcBuilder {
134 it.credentialProvider = &RawProvider{
135 raw: TLSRaw{
136 CaCert: caCert,
137 Cert: cert,
138 Key: key,
139 },
140 }
141
142 return it
143 }
144
145 func (it *GrpcBuilder) TLSFromFile(caCertPath string, certPath string, keyPath string) *GrpcBuilder {
146 it.credentialProvider = &FileProvider{
147 file: TLSFile{
148 CaCert: caCertPath,
149 Cert: certPath,
150 Key: keyPath,
151 },
152 }
153 return it
154 }
155
156 func (it *GrpcBuilder) Build() (*grpc.ClientConn, error) {
157 if it.credentialProvider == nil {
158 return nil, fmt.Errorf("an authorization method must be specified")
159 }
160 option, err := it.credentialProvider.getCredentialOption()
161 if err != nil {
162 return nil, err
163 }
164 it.options = append(it.options, option)
165 return grpc.Dial(net.JoinHostPort(it.address, strconv.Itoa(it.port)), it.options...)
166 }
167
168
169 func TimeoutClientInterceptor(timeout time.Duration) func(context.Context, string, interface{}, interface{},
170 *grpc.ClientConn, grpc.UnaryInvoker, ...grpc.CallOption) error {
171 return func(ctx context.Context, method string, req, reply interface{},
172 cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
173 ctx, cancel := context.WithTimeout(ctx, timeout)
174 defer cancel()
175 return invoker(ctx, method, req, reply, cc, opts...)
176 }
177 }
178
179
180
181 func TimeoutServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
182 handler grpc.UnaryHandler) (interface{}, error) {
183 if ctx.Err() != nil {
184 return nil, ctx.Err()
185 }
186 return handler(ctx, req)
187 }
188