Source file
src/cmd/gofmt/rewrite.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 "os"
13 "reflect"
14 "strings"
15 "unicode"
16 "unicode/utf8"
17 )
18
19 func initRewrite() {
20 if *rewriteRule == "" {
21 rewrite = nil
22 return
23 }
24 f := strings.Split(*rewriteRule, "->")
25 if len(f) != 2 {
26 fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
27 os.Exit(2)
28 }
29 pattern := parseExpr(f[0], "pattern")
30 replace := parseExpr(f[1], "replacement")
31 rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
32 return rewriteFile(fset, pattern, replace, p)
33 }
34 }
35
36
37
38
39
// parseExpr parses s as a single Go expression and returns it.
//
// what describes the role of s (callers in this file pass "pattern" or
// "replacement") and is used only in the error message. If s does not
// parse, the error is printed to stderr and the process exits with
// status 2.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return nil // unreachable: os.Exit does not return
}
48
49
50
57
58
59 func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
60 cmap := ast.NewCommentMap(fileSet, p, p.Comments)
61 m := make(map[string]reflect.Value)
62 pat := reflect.ValueOf(pattern)
63 repl := reflect.ValueOf(replace)
64
65 var rewriteVal func(val reflect.Value) reflect.Value
66 rewriteVal = func(val reflect.Value) reflect.Value {
67
68 if !val.IsValid() {
69 return reflect.Value{}
70 }
71 val = apply(rewriteVal, val)
72 clear(m)
73 if match(m, pat, val) {
74 val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
75 }
76 return val
77 }
78
79 r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
80 r.Comments = cmap.Filter(r).Comments()
81 return r
82 }
83
84
// set performs x.Set(y) when that is possible and silently does nothing
// otherwise. The no-op cases are:
//   - x is not settable,
//   - y is the invalid (zero) Value,
//   - y's type is not assignable to x's type.
//
// The last case happens during rewriting when a replacement node lands in a
// field whose static type cannot hold it. Such a rewrite is simply skipped.
//
// Assignability is checked up front with Type.AssignableTo. This is the
// same rule reflect.Value.Set enforces before it panics. The earlier
// approach was to recover from that panic and inspect its message text,
// which is not a stable API.
func set(x, y reflect.Value) {
	if !x.CanSet() || !y.IsValid() {
		return
	}
	if !y.Type().AssignableTo(x.Type()) {
		// x cannot hold y: ignore this rewrite.
		return
	}
	x.Set(y)
}
102
103
// Reflection handles for the go/ast types that apply, match and subst
// treat specially, plus typed nil values that apply returns in place of
// *ast.Object and *ast.Scope links.
var (
	identType     = reflect.TypeOf((*ast.Ident)(nil))
	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
	positionType  = reflect.TypeOf(token.NoPos)
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))

	// Typed nil pointers (zero values of the pointer types above).
	objectPtrNil = reflect.Zero(objectPtrType)
	scopePtrNil  = reflect.Zero(scopePtrType)
)
114
115
116
117 func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
118 if !val.IsValid() {
119 return reflect.Value{}
120 }
121
122
123
124 if val.Type() == objectPtrType {
125 return objectPtrNil
126 }
127
128
129
130 if val.Type() == scopePtrType {
131 return scopePtrNil
132 }
133
134 switch v := reflect.Indirect(val); v.Kind() {
135 case reflect.Slice:
136 for i := 0; i < v.Len(); i++ {
137 e := v.Index(i)
138 set(e, f(e))
139 }
140 case reflect.Struct:
141 for i := 0; i < v.NumField(); i++ {
142 e := v.Field(i)
143 set(e, f(e))
144 }
145 case reflect.Interface:
146 e := v.Elem()
147 set(v, f(e))
148 }
149 return val
150 }
151
// isWildcard reports whether s is a wildcard name in a rewrite rule. A
// wildcard is a string consisting of exactly one Unicode lower-case
// letter, such as "x" or "ä". The empty string is not a wildcard.
func isWildcard(s string) bool {
	r, width := utf8.DecodeRuneInString(s)
	if width != len(s) {
		return false // more than one rune
	}
	return unicode.IsLower(r)
}
156
157
158
159
160 func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
161
162
163
164 if m != nil && pattern.IsValid() && pattern.Type() == identType {
165 name := pattern.Interface().(*ast.Ident).Name
166 if isWildcard(name) && val.IsValid() {
167
168 if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
169 if old, ok := m[name]; ok {
170 return match(nil, old, val)
171 }
172 m[name] = val
173 return true
174 }
175 }
176 }
177
178
179 if !pattern.IsValid() || !val.IsValid() {
180 return !pattern.IsValid() && !val.IsValid()
181 }
182 if pattern.Type() != val.Type() {
183 return false
184 }
185
186
187 switch pattern.Type() {
188 case identType:
189
190
191
192
193 p := pattern.Interface().(*ast.Ident)
194 v := val.Interface().(*ast.Ident)
195 return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
196 case objectPtrType, positionType:
197
198 return true
199 case callExprType:
200
201
202
203 p := pattern.Interface().(*ast.CallExpr)
204 v := val.Interface().(*ast.CallExpr)
205 if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
206 return false
207 }
208 }
209
210 p := reflect.Indirect(pattern)
211 v := reflect.Indirect(val)
212 if !p.IsValid() || !v.IsValid() {
213 return !p.IsValid() && !v.IsValid()
214 }
215
216 switch p.Kind() {
217 case reflect.Slice:
218 if p.Len() != v.Len() {
219 return false
220 }
221 for i := 0; i < p.Len(); i++ {
222 if !match(m, p.Index(i), v.Index(i)) {
223 return false
224 }
225 }
226 return true
227
228 case reflect.Struct:
229 for i := 0; i < p.NumField(); i++ {
230 if !match(m, p.Field(i), v.Field(i)) {
231 return false
232 }
233 }
234 return true
235
236 case reflect.Interface:
237 return match(m, p.Elem(), v.Elem())
238 }
239
240
241 return p.Interface() == v.Interface()
242 }
243
244
245
246
247
248 func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
249 if !pattern.IsValid() {
250 return reflect.Value{}
251 }
252
253
254 if m != nil && pattern.Type() == identType {
255 name := pattern.Interface().(*ast.Ident).Name
256 if isWildcard(name) {
257 if old, ok := m[name]; ok {
258 return subst(nil, old, reflect.Value{})
259 }
260 }
261 }
262
263 if pos.IsValid() && pattern.Type() == positionType {
264
265 if old := pattern.Interface().(token.Pos); !old.IsValid() {
266 return pattern
267 }
268 return pos
269 }
270
271
272 switch p := pattern; p.Kind() {
273 case reflect.Slice:
274 if p.IsNil() {
275
276
277
278 return reflect.Zero(p.Type())
279 }
280 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
281 for i := 0; i < p.Len(); i++ {
282 v.Index(i).Set(subst(m, p.Index(i), pos))
283 }
284 return v
285
286 case reflect.Struct:
287 v := reflect.New(p.Type()).Elem()
288 for i := 0; i < p.NumField(); i++ {
289 v.Field(i).Set(subst(m, p.Field(i), pos))
290 }
291 return v
292
293 case reflect.Pointer:
294 v := reflect.New(p.Type()).Elem()
295 if elem := p.Elem(); elem.IsValid() {
296 v.Set(subst(m, elem, pos).Addr())
297 }
298 return v
299
300 case reflect.Interface:
301 v := reflect.New(p.Type()).Elem()
302 if elem := p.Elem(); elem.IsValid() {
303 v.Set(subst(m, elem, pos))
304 }
305 return v
306 }
307
308 return pattern
309 }
310
View as plain text