...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package apivalidator
17
18 import (
19 "fmt"
20 "reflect"
21 "strconv"
22 "strings"
23
24 "github.com/go-playground/validator/v10"
25 )
26
27
28 func RequiredFieldEqualValid(fl validator.FieldLevel) bool {
29 param := strings.Split(fl.Param(), `:`)
30 paramField := param[0]
31 paramValue := param[1]
32
33 if paramField == `` {
34 return true
35 }
36
37 var paramFieldValue reflect.Value
38 if fl.Parent().Kind() == reflect.Ptr {
39 paramFieldValue = fl.Parent().Elem().FieldByName(paramField)
40 } else {
41 paramFieldValue = fl.Parent().FieldByName(paramField)
42 }
43
44 if !isEq(paramFieldValue, paramValue) {
45 return true
46 }
47
48 return hasValue(fl)
49 }
50
51
52 func hasValue(fl validator.FieldLevel) bool {
53 return requireCheckFieldKind(fl, "")
54 }
55
56 func requireCheckFieldKind(fl validator.FieldLevel, param string) bool {
57 field := fl.Field()
58 if len(param) > 0 {
59 if fl.Parent().Kind() == reflect.Ptr {
60 field = fl.Parent().Elem().FieldByName(param)
61 } else {
62 field = fl.Parent().FieldByName(param)
63 }
64 }
65 switch field.Kind() {
66 case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
67 return !field.IsNil()
68 default:
69 _, _, nullable := fl.ExtractType(field)
70 if nullable && field.Interface() != nil {
71 return true
72 }
73 return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface()
74 }
75 }
76
77 func isEq(field reflect.Value, value string) bool {
78 switch field.Kind() {
79 case reflect.String:
80 return field.String() == value
81 case reflect.Slice, reflect.Map, reflect.Array:
82 p := asInt(value)
83 return int64(field.Len()) == p
84 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
85 p := asInt(value)
86 return field.Int() == p
87 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
88 p := asUint(value)
89 return field.Uint() == p
90 case reflect.Float32, reflect.Float64:
91 p := asFloat(value)
92 return field.Float() == p
93 }
94
95 panic(fmt.Sprintf("Bad field type %T", field.Interface()))
96 }
97
98 func asInt(param string) int64 {
99 i, err := strconv.ParseInt(param, 0, 64)
100 panicIf(err)
101
102 return i
103 }
104
105 func asUint(param string) uint64 {
106 i, err := strconv.ParseUint(param, 0, 64)
107 panicIf(err)
108
109 return i
110 }
111
112 func asFloat(param string) float64 {
113 i, err := strconv.ParseFloat(param, 64)
114 panicIf(err)
115
116 return i
117 }
118
// panicIf does nothing for a nil err. Otherwise it panics with err's message,
// passing the message string (not the error value) to panic.
func panicIf(err error) {
	if err == nil {
		return
	}
	panic(err.Error())
}
124