1
2
3
4
5
6
7
8
9
10
11
12
13
14 package utils
15
import (
	"context"
	"fmt"
	"net"
	"os"
	"time"

	"github.com/pkg/errors"
	"google.golang.org/grpc"
	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/types"

	"sigs.k8s.io/controller-runtime/pkg/client"
)
29
30
// DefaultRPCTimeout is the initial value of RPCTimeout (one minute).
const DefaultRPCTimeout = time.Minute

// RPCTimeout is the per-call deadline that TimeoutClientInterceptor applies to
// every unary RPC made on connections created by CreateGrpcConnection and
// CreateGrpcConnectionWithAddress. It is read at call time, so a new value
// affects subsequent calls.
//
// NOTE(review): this is a plain package variable read without synchronization;
// presumably it is only assigned during start-up — confirm.
var RPCTimeout = DefaultRPCTimeout
35
36
37 func CreateGrpcConnection(ctx context.Context, c client.Client, pod *v1.Pod, port int) (*grpc.ClientConn, error) {
38 nodeName := pod.Spec.NodeName
39 log.Info("Creating client to chaos-daemon", "node", nodeName)
40
41 ns := os.Getenv("NAMESPACE")
42 if len(ns) == 0 {
43 return nil, errors.Errorf("fail to find NAMESPACE")
44 }
45 var endpoints v1.Endpoints
46 err := c.Get(ctx, types.NamespacedName{
47 Namespace: ns,
48 Name: "chaos-daemon",
49 }, &endpoints)
50 if err != nil {
51 return nil, err
52 }
53
54 daemonIP := findIPOnEndpoints(&endpoints, nodeName)
55 if len(daemonIP) == 0 {
56 return nil, errors.Errorf("cannot find daemonIP on node %s in related Endpoints %v", nodeName, endpoints)
57 }
58 return CreateGrpcConnectionWithAddress(daemonIP, port)
59 }
60
61
62 func CreateGrpcConnectionWithAddress(address string, port int) (*grpc.ClientConn, error) {
63 conn, err := grpc.Dial(fmt.Sprintf("%s:%d", address, port),
64 grpc.WithInsecure(),
65 grpc.WithUnaryInterceptor(TimeoutClientInterceptor))
66 if err != nil {
67 return nil, err
68 }
69 return conn, nil
70 }
71
72
73 func TimeoutClientInterceptor(ctx context.Context, method string, req, reply interface{},
74 cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
75 ctx, cancel := context.WithTimeout(ctx, RPCTimeout)
76 defer cancel()
77 return invoker(ctx, method, req, reply, cc, opts...)
78 }
79
80
81
82 func TimeoutServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
83 handler grpc.UnaryHandler) (interface{}, error) {
84 if ctx.Err() != nil {
85 return nil, ctx.Err()
86 }
87 return handler(ctx, req)
88 }
89
90 func findIPOnEndpoints(e *v1.Endpoints, nodeName string) string {
91 for _, subset := range e.Subsets {
92 for _, addr := range subset.Addresses {
93 if addr.NodeName != nil && *addr.NodeName == nodeName {
94 return addr.IP
95 }
96 }
97 }
98
99 return ""
100 }
101