1
2
3
4
5 package astutil
6
7
8
9 import (
10 "fmt"
11 "go/ast"
12 "go/token"
13 "sort"
14 )
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60 func PathEnclosingInterval(root *ast.File, start, end token.Pos) (path []ast.Node, exact bool) {
61
62
63
64 var visit func(node ast.Node) bool
65 visit = func(node ast.Node) bool {
66 path = append(path, node)
67
68 nodePos := node.Pos()
69 nodeEnd := node.End()
70
71
72
73
74 if start < nodePos {
75 start = nodePos
76 }
77 if end > nodeEnd {
78 end = nodeEnd
79 }
80
81
82 children := childrenOf(node)
83 l := len(children)
84 for i, child := range children {
85
86 childPos := child.Pos()
87 childEnd := child.End()
88
89
90 augPos := childPos
91 augEnd := childEnd
92 if i > 0 {
93 augPos = children[i-1].End()
94 }
95 if i < l-1 {
96 nextChildPos := children[i+1].Pos()
97
98 if start >= augEnd && end <= nextChildPos {
99 return false
100 }
101 augEnd = nextChildPos
102 }
103
104
105
106
107
108 if augPos <= start && end <= augEnd {
109 _, isToken := child.(tokenNode)
110 return isToken || visit(child)
111 }
112
113
114
115
116 if start < childEnd && end > augEnd {
117 break
118 }
119 }
120
121
122
123
124
125
126
127
128 if start == nodePos && end == nodeEnd {
129 return true
130 }
131
132 return false
133 }
134
135
136 if start > end {
137 start, end = end, start
138 }
139
140 if start < root.End() && end > root.Pos() {
141 if start == end {
142 end = start + 1
143 }
144 exact = visit(root)
145
146
147 for i, l := 0, len(path); i < l/2; i++ {
148 path[i], path[l-1-i] = path[l-1-i], path[i]
149 }
150 } else {
151
152
153
154 path = append(path, root)
155 }
156
157 return
158 }
159
160
161
162
// tokenNode is a pseudo-node that childrenOf uses to represent a keyword or
// punctuation token (such as "func", "{" or ":") owned by its parent node.
// It satisfies ast.Node so it can be sorted alongside real children, but
// PathEnclosingInterval never adds it to a path.
type tokenNode struct {
	pos token.Pos // position of the token's first byte
	end token.Pos // position immediately after the token
}

// Compile-time check that tokenNode satisfies ast.Node.
var _ ast.Node = tokenNode{}

// Pos returns the position of the token's first byte.
func (n tokenNode) Pos() token.Pos {
	return n.pos
}

// End returns the position immediately after the token.
func (n tokenNode) End() token.Pos {
	return n.end
}

// tok returns a tokenNode spanning the half-open interval [pos, pos+length).
// The parameter is named length rather than len to avoid shadowing the
// predeclared len function.
func tok(pos token.Pos, length int) ast.Node {
	return tokenNode{pos: pos, end: pos + token.Pos(length)}
}
179
180
181
182
183 func childrenOf(n ast.Node) []ast.Node {
184 var children []ast.Node
185
186
187 ast.Inspect(n, func(node ast.Node) bool {
188 if node == n {
189 return true
190 }
191 if node != nil {
192 children = append(children, node)
193 }
194 return false
195 })
196
197
198 switch n := n.(type) {
199 case *ast.ArrayType:
200 children = append(children,
201 tok(n.Lbrack, len("[")),
202 tok(n.Elt.End(), len("]")))
203
204 case *ast.AssignStmt:
205 children = append(children,
206 tok(n.TokPos, len(n.Tok.String())))
207
208 case *ast.BasicLit:
209 children = append(children,
210 tok(n.ValuePos, len(n.Value)))
211
212 case *ast.BinaryExpr:
213 children = append(children, tok(n.OpPos, len(n.Op.String())))
214
215 case *ast.BlockStmt:
216 children = append(children,
217 tok(n.Lbrace, len("{")),
218 tok(n.Rbrace, len("}")))
219
220 case *ast.BranchStmt:
221 children = append(children,
222 tok(n.TokPos, len(n.Tok.String())))
223
224 case *ast.CallExpr:
225 children = append(children,
226 tok(n.Lparen, len("(")),
227 tok(n.Rparen, len(")")))
228 if n.Ellipsis != 0 {
229 children = append(children, tok(n.Ellipsis, len("...")))
230 }
231
232 case *ast.CaseClause:
233 if n.List == nil {
234 children = append(children,
235 tok(n.Case, len("default")))
236 } else {
237 children = append(children,
238 tok(n.Case, len("case")))
239 }
240 children = append(children, tok(n.Colon, len(":")))
241
242 case *ast.ChanType:
243 switch n.Dir {
244 case ast.RECV:
245 children = append(children, tok(n.Begin, len("<-chan")))
246 case ast.SEND:
247 children = append(children, tok(n.Begin, len("chan<-")))
248 case ast.RECV | ast.SEND:
249 children = append(children, tok(n.Begin, len("chan")))
250 }
251
252 case *ast.CommClause:
253 if n.Comm == nil {
254 children = append(children,
255 tok(n.Case, len("default")))
256 } else {
257 children = append(children,
258 tok(n.Case, len("case")))
259 }
260 children = append(children, tok(n.Colon, len(":")))
261
262 case *ast.Comment:
263
264
265 case *ast.CommentGroup:
266
267
268 case *ast.CompositeLit:
269 children = append(children,
270 tok(n.Lbrace, len("{")),
271 tok(n.Rbrace, len("{")))
272
273 case *ast.DeclStmt:
274
275
276 case *ast.DeferStmt:
277 children = append(children,
278 tok(n.Defer, len("defer")))
279
280 case *ast.Ellipsis:
281 children = append(children,
282 tok(n.Ellipsis, len("...")))
283
284 case *ast.EmptyStmt:
285
286
287 case *ast.ExprStmt:
288
289
290 case *ast.Field:
291
292
293 case *ast.FieldList:
294 children = append(children,
295 tok(n.Opening, len("(")),
296 tok(n.Closing, len(")")))
297
298 case *ast.File:
299
300 children = append(children,
301 tok(n.Package, len("package")))
302
303 case *ast.ForStmt:
304 children = append(children,
305 tok(n.For, len("for")))
306
307 case *ast.FuncDecl:
308
309
310
311
312
313
314
315
316
317 children = nil
318 children = append(children, tok(n.Type.Func, len("func")))
319 if n.Recv != nil {
320 children = append(children, n.Recv)
321 }
322 children = append(children, n.Name)
323 if tparams := n.Type.TypeParams; tparams != nil {
324 children = append(children, tparams)
325 }
326 if n.Type.Params != nil {
327 children = append(children, n.Type.Params)
328 }
329 if n.Type.Results != nil {
330 children = append(children, n.Type.Results)
331 }
332 if n.Body != nil {
333 children = append(children, n.Body)
334 }
335
336 case *ast.FuncLit:
337
338
339 case *ast.FuncType:
340 if n.Func != 0 {
341 children = append(children,
342 tok(n.Func, len("func")))
343 }
344
345 case *ast.GenDecl:
346 children = append(children,
347 tok(n.TokPos, len(n.Tok.String())))
348 if n.Lparen != 0 {
349 children = append(children,
350 tok(n.Lparen, len("(")),
351 tok(n.Rparen, len(")")))
352 }
353
354 case *ast.GoStmt:
355 children = append(children,
356 tok(n.Go, len("go")))
357
358 case *ast.Ident:
359 children = append(children,
360 tok(n.NamePos, len(n.Name)))
361
362 case *ast.IfStmt:
363 children = append(children,
364 tok(n.If, len("if")))
365
366 case *ast.ImportSpec:
367
368
369 case *ast.IncDecStmt:
370 children = append(children,
371 tok(n.TokPos, len(n.Tok.String())))
372
373 case *ast.IndexExpr:
374 children = append(children,
375 tok(n.Lbrack, len("[")),
376 tok(n.Rbrack, len("]")))
377
378 case *ast.IndexListExpr:
379 children = append(children,
380 tok(n.Lbrack, len("[")),
381 tok(n.Rbrack, len("]")))
382
383 case *ast.InterfaceType:
384 children = append(children,
385 tok(n.Interface, len("interface")))
386
387 case *ast.KeyValueExpr:
388 children = append(children,
389 tok(n.Colon, len(":")))
390
391 case *ast.LabeledStmt:
392 children = append(children,
393 tok(n.Colon, len(":")))
394
395 case *ast.MapType:
396 children = append(children,
397 tok(n.Map, len("map")))
398
399 case *ast.ParenExpr:
400 children = append(children,
401 tok(n.Lparen, len("(")),
402 tok(n.Rparen, len(")")))
403
404 case *ast.RangeStmt:
405 children = append(children,
406 tok(n.For, len("for")),
407 tok(n.TokPos, len(n.Tok.String())))
408
409 case *ast.ReturnStmt:
410 children = append(children,
411 tok(n.Return, len("return")))
412
413 case *ast.SelectStmt:
414 children = append(children,
415 tok(n.Select, len("select")))
416
417 case *ast.SelectorExpr:
418
419
420 case *ast.SendStmt:
421 children = append(children,
422 tok(n.Arrow, len("<-")))
423
424 case *ast.SliceExpr:
425 children = append(children,
426 tok(n.Lbrack, len("[")),
427 tok(n.Rbrack, len("]")))
428
429 case *ast.StarExpr:
430 children = append(children, tok(n.Star, len("*")))
431
432 case *ast.StructType:
433 children = append(children, tok(n.Struct, len("struct")))
434
435 case *ast.SwitchStmt:
436 children = append(children, tok(n.Switch, len("switch")))
437
438 case *ast.TypeAssertExpr:
439 children = append(children,
440 tok(n.Lparen-1, len(".")),
441 tok(n.Lparen, len("(")),
442 tok(n.Rparen, len(")")))
443
444 case *ast.TypeSpec:
445
446
447 case *ast.TypeSwitchStmt:
448 children = append(children, tok(n.Switch, len("switch")))
449
450 case *ast.UnaryExpr:
451 children = append(children, tok(n.OpPos, len(n.Op.String())))
452
453 case *ast.ValueSpec:
454
455
456 case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt:
457
458 }
459
460
461
462
463
464 sort.Sort(byPos(children))
465
466 return children
467 }
468
// byPos implements sort.Interface over a slice of nodes, ordering them by
// ascending start position (Pos). childrenOf uses it to put a node's children
// into lexical order.
type byPos []ast.Node

// Len reports how many nodes are in the slice.
func (p byPos) Len() int { return len(p) }

// Less reports whether node i starts strictly before node j.
func (p byPos) Less(i, j int) bool {
	left, right := p[i].Pos(), p[j].Pos()
	return left < right
}

// Swap exchanges nodes i and j in place.
func (p byPos) Swap(i, j int) { p[j], p[i] = p[i], p[j] }
480
481
482
483
484
485
486
// NodeDescription returns a short English noun phrase describing the
// syntax of node n, such as "binary + operation" or "if statement".
//
// It panics if n is nil, is not one of the go/ast node types listed below,
// or is a BranchStmt or GenDecl whose token is not one of the expected kinds.
func NodeDescription(n ast.Node) string {
	var desc string

	switch n := n.(type) {
	// Nodes whose description depends on their contents.
	case *ast.BinaryExpr:
		desc = fmt.Sprintf("binary %s operation", n.Op)
	case *ast.UnaryExpr:
		desc = fmt.Sprintf("unary %s operation", n.Op)
	case *ast.BranchStmt:
		switch n.Tok {
		case token.BREAK:
			desc = "break statement"
		case token.CONTINUE:
			desc = "continue statement"
		case token.GOTO:
			desc = "goto statement"
		case token.FALLTHROUGH:
			desc = "fall-through statement"
		}
	case *ast.CallExpr:
		desc = "function call"
		if len(n.Args) == 1 && !n.Ellipsis.IsValid() {
			// A one-argument call is syntactically indistinguishable
			// from a conversion T(x).
			desc = "function call (or conversion)"
		}
	case *ast.DeclStmt:
		desc = NodeDescription(n.Decl) + " statement"
	case *ast.GenDecl:
		switch n.Tok {
		case token.IMPORT:
			desc = "import declaration"
		case token.CONST:
			desc = "constant declaration"
		case token.TYPE:
			desc = "type declaration"
		case token.VAR:
			desc = "variable declaration"
		}
	case *ast.IncDecStmt:
		desc = "decrement statement"
		if n.Tok == token.INC {
			desc = "increment statement"
		}
	case *ast.ParenExpr:
		desc = "parenthesized " + NodeDescription(n.X)

	// Nodes with a fixed description.
	case *ast.ArrayType:
		desc = "array type"
	case *ast.AssignStmt:
		desc = "assignment"
	case *ast.BadDecl:
		desc = "bad declaration"
	case *ast.BadExpr:
		desc = "bad expression"
	case *ast.BadStmt:
		desc = "bad statement"
	case *ast.BasicLit:
		desc = "basic literal"
	case *ast.BlockStmt:
		desc = "block"
	case *ast.CaseClause:
		desc = "case clause"
	case *ast.ChanType:
		desc = "channel type"
	case *ast.CommClause:
		desc = "communication clause"
	case *ast.Comment:
		desc = "comment"
	case *ast.CommentGroup:
		desc = "comment group"
	case *ast.CompositeLit:
		desc = "composite literal"
	case *ast.DeferStmt:
		desc = "defer statement"
	case *ast.Ellipsis:
		desc = "ellipsis"
	case *ast.EmptyStmt:
		desc = "empty statement"
	case *ast.ExprStmt:
		desc = "expression statement"
	case *ast.Field:
		// Without the enclosing context we cannot tell whether this
		// is a struct field, an interface method, or a parameter.
		desc = "field/method/parameter"
	case *ast.FieldList:
		desc = "field/method/parameter list"
	case *ast.File:
		desc = "source file"
	case *ast.ForStmt:
		desc = "for loop"
	case *ast.FuncDecl:
		desc = "function declaration"
	case *ast.FuncLit:
		desc = "function literal"
	case *ast.FuncType:
		desc = "function type"
	case *ast.GoStmt:
		desc = "go statement"
	case *ast.Ident:
		desc = "identifier"
	case *ast.IfStmt:
		desc = "if statement"
	case *ast.ImportSpec:
		desc = "import specification"
	case *ast.IndexExpr:
		desc = "index expression"
	case *ast.IndexListExpr:
		desc = "index list expression"
	case *ast.InterfaceType:
		desc = "interface type"
	case *ast.KeyValueExpr:
		desc = "key/value association"
	case *ast.LabeledStmt:
		desc = "statement label"
	case *ast.MapType:
		desc = "map type"
	case *ast.Package:
		desc = "package"
	case *ast.RangeStmt:
		desc = "range loop"
	case *ast.ReturnStmt:
		desc = "return statement"
	case *ast.SelectStmt:
		desc = "select statement"
	case *ast.SelectorExpr:
		desc = "selector"
	case *ast.SendStmt:
		desc = "channel send"
	case *ast.SliceExpr:
		desc = "slice expression"
	case *ast.StarExpr:
		desc = "*-operation"
	case *ast.StructType:
		desc = "struct type"
	case *ast.SwitchStmt:
		desc = "switch statement"
	case *ast.TypeAssertExpr:
		desc = "type assertion"
	case *ast.TypeSpec:
		desc = "type specification"
	case *ast.TypeSwitchStmt:
		desc = "type switch"
	case *ast.ValueSpec:
		desc = "value specification"
	}

	// Every recognized case yields a non-empty description, so an empty
	// one means n was nil, of an unknown type, or had an unexpected token.
	if desc == "" {
		panic(fmt.Sprintf("unexpected node type: %T", n))
	}
	return desc
}
635
View as plain text