Source file
src/cmd/fix/typecheck.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 "os"
13 "os/exec"
14 "path/filepath"
15 "reflect"
16 "runtime"
17 "strings"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// mkType returns the typeof entry that marks t as denoting a type rather
// than a value of type t: the string "type " followed by t.
func mkType(t string) string {
	const prefix = "type "
	return prefix + t
}

// getType undoes mkType: for an entry of the form "type T" it returns T,
// and for any other string it returns "".
func getType(t string) string {
	if isType(t) {
		return t[len("type "):]
	}
	return ""
}

// isType reports whether t is an entry produced by mkType, that is,
// whether it starts with "type ".
func isType(t string) bool {
	const prefix = "type "
	return len(t) >= len(prefix) && t[:len(prefix)] == prefix
}
71
72
73
74
75
76
// TypeConfig describes the types, variables and functions that typecheck
// knows about beyond what it can infer from the file being checked.
// All types are referred to by their printed names (for example
// "reflect.Value" or "*T").
type TypeConfig struct {
	// Type maps a type name to its description.
	Type map[string]*Type
	// Var maps a variable name, either "x" or package-qualified "p.X",
	// to its type.
	Var map[string]string
	// Func maps a function name, either "f" or package-qualified "p.F",
	// to the string that typeof appends to "func()" to form its type
	// (its result type or parenthesized result list).
	Func map[string]string

	// External maps a name to its type. It supplies typings that are not
	// present in the Go source itself; typecheck resets it on every call
	// and fills it with the C functions reported by cgo when the file
	// imports "C".
	External map[string]string
}

// typeof returns the type recorded for name, which may be of the form "x"
// or "p.X". Var entries take precedence; a Func entry is reported as
// "func()" followed by the recorded string. It returns "" if name is not
// known. Reading a nil map yields "", so an unset Var or Func is fine.
func (cfg *TypeConfig) typeof(name string) string {
	if t := cfg.Var[name]; t != "" {
		return t
	}
	if t := cfg.Func[name]; t != "" {
		return "func()" + t
	}
	return ""
}

// Type describes the fields, methods and embedded types of a named type,
// again using printed type names.
type Type struct {
	Field  map[string]string // field name -> field type
	Method map[string]string // method name -> method type
	Embed  []string          // names (keys of TypeConfig.Type) of embedded types
	Def    string            // underlying type of a named array, slice, pointer or map type; "" otherwise
}

// dot returns the type of the field or method called name as selected from
// a value of type typ, or "" if there is none.
//
// Like a Go selector, it prefers the shallowest match: typ's own fields and
// methods first, then those of its embedded types, then theirs, and so on.
// Within one depth, types are examined in Embed order and, for each type,
// fields are checked before methods; the first match wins (Go would reject
// a same-depth ambiguity, which is not diagnosed here). Embedded names
// missing from cfg.Type are ignored, and each type is visited at most once,
// so cyclic Embed entries terminate.
func (typ *Type) dot(cfg *TypeConfig, name string) string {
	seen := map[*Type]bool{typ: true}
	level := []*Type{typ}
	for len(level) > 0 {
		var next []*Type
		for _, cur := range level {
			if t := cur.Field[name]; t != "" {
				return t
			}
			if t := cur.Method[name]; t != "" {
				return t
			}
			for _, e := range cur.Embed {
				if etyp := cfg.Type[e]; etyp != nil && !seen[etyp] {
					seen[etyp] = true
					next = append(next, etyp)
				}
			}
		}
		level = next
	}
	return ""
}
139
140
141
142
143
144
145 func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[any]string, assign map[string][]any) {
146 typeof = make(map[any]string)
147 assign = make(map[string][]any)
148 cfg1 := &TypeConfig{}
149 *cfg1 = *cfg
150 copied := false
151
152
153 cfg.External = map[string]string{}
154 cfg1.External = cfg.External
155 if imports(f, "C") {
156
157
158
159 err := func() error {
160 txt, err := gofmtFile(f)
161 if err != nil {
162 return err
163 }
164 dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
165 if err != nil {
166 return err
167 }
168 defer os.RemoveAll(dir)
169 err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
170 if err != nil {
171 return err
172 }
173 goCmd := "go"
174 if goroot := runtime.GOROOT(); goroot != "" {
175 goCmd = filepath.Join(goroot, "bin", "go")
176 }
177 cmd := exec.Command(goCmd, "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
178 if reportCgoError != nil {
179
180
181 cmd.Stderr = os.Stderr
182 }
183 err = cmd.Run()
184 if err != nil {
185 return err
186 }
187 out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
188 if err != nil {
189 return err
190 }
191 cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
192 if err != nil {
193 return err
194 }
195 for _, decl := range cgo.Decls {
196 fn, ok := decl.(*ast.FuncDecl)
197 if !ok {
198 continue
199 }
200 if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
201 var params, results []string
202 for _, p := range fn.Type.Params.List {
203 t := gofmt(p.Type)
204 t = strings.ReplaceAll(t, "_Ctype_", "C.")
205 params = append(params, t)
206 }
207 for _, r := range fn.Type.Results.List {
208 t := gofmt(r.Type)
209 t = strings.ReplaceAll(t, "_Ctype_", "C.")
210 results = append(results, t)
211 }
212 cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
213 }
214 }
215 return nil
216 }()
217 if err != nil {
218 if reportCgoError == nil {
219 fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
220 } else {
221 reportCgoError(err)
222 }
223 }
224 }
225
226
227 for _, decl := range f.Decls {
228 fn, ok := decl.(*ast.FuncDecl)
229 if !ok {
230 continue
231 }
232 typecheck1(cfg, fn.Type, typeof, assign)
233 t := typeof[fn.Type]
234 if fn.Recv != nil {
235
236 rcvr := typeof[fn.Recv]
237 if !isType(rcvr) {
238 if len(fn.Recv.List) != 1 {
239 continue
240 }
241 rcvr = mkType(gofmt(fn.Recv.List[0].Type))
242 typeof[fn.Recv.List[0].Type] = rcvr
243 }
244 rcvr = getType(rcvr)
245 if rcvr != "" && rcvr[0] == '*' {
246 rcvr = rcvr[1:]
247 }
248 typeof[rcvr+"."+fn.Name.Name] = t
249 } else {
250 if isType(t) {
251 t = getType(t)
252 } else {
253 t = gofmt(fn.Type)
254 }
255 typeof[fn.Name] = t
256
257
258 typeof[fn.Name.Obj] = t
259 }
260 }
261
262
263 for _, decl := range f.Decls {
264 d, ok := decl.(*ast.GenDecl)
265 if ok {
266 for _, s := range d.Specs {
267 switch s := s.(type) {
268 case *ast.TypeSpec:
269 if cfg1.Type[s.Name.Name] != nil {
270 break
271 }
272 if !copied {
273 copied = true
274
275 cfg1.Type = make(map[string]*Type)
276 for k, v := range cfg.Type {
277 cfg1.Type[k] = v
278 }
279 }
280 t := &Type{Field: map[string]string{}}
281 cfg1.Type[s.Name.Name] = t
282 switch st := s.Type.(type) {
283 case *ast.StructType:
284 for _, f := range st.Fields.List {
285 for _, n := range f.Names {
286 t.Field[n.Name] = gofmt(f.Type)
287 }
288 }
289 case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
290 t.Def = gofmt(st)
291 }
292 }
293 }
294 }
295 }
296
297 typecheck1(cfg1, f, typeof, assign)
298 return typeof, assign
299 }
300
301
302
// reportCgoError, when non-nil, is called by typecheck with any error
// encountered while extracting cgo types for a file that imports "C"; it
// also makes typecheck forward the cgo command's standard error to
// os.Stderr. When nil, typecheck prints a warning to standard error and
// continues without cgo types.
var reportCgoError func(error)
304
// makeExprList converts a list of identifiers into the equivalent list of
// expressions. It returns nil when a is empty.
func makeExprList(a []*ast.Ident) []ast.Expr {
	if len(a) == 0 {
		return nil
	}
	exprs := make([]ast.Expr, len(a))
	for i, id := range a {
		exprs[i] = id
	}
	return exprs
}
312
313
314
315
316 func typecheck1(cfg *TypeConfig, f any, typeof map[any]string, assign map[string][]any) {
317
318
319 set := func(n ast.Expr, typ string, isDecl bool) {
320 if typeof[n] != "" || typ == "" {
321 if typeof[n] != typ {
322 assign[typ] = append(assign[typ], n)
323 }
324 return
325 }
326 typeof[n] = typ
327
328
329
330
331
332
333
334 if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
335 typeof[id.Obj] = typ
336 }
337 }
338
339
340
341
342 typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
343 if len(lhs) > 1 && len(rhs) == 1 {
344 if _, ok := rhs[0].(*ast.CallExpr); ok {
345 t := split(typeof[rhs[0]])
346
347 for i := 0; i < len(lhs) && i < len(t); i++ {
348 set(lhs[i], t[i], isDecl)
349 }
350 return
351 }
352 }
353 if len(lhs) == 1 && len(rhs) == 2 {
354
355 rhs = rhs[:1]
356 } else if len(lhs) == 2 && len(rhs) == 1 {
357
358 lhs = lhs[:1]
359 }
360
361
362 for i := 0; i < len(lhs) && i < len(rhs); i++ {
363 x, y := lhs[i], rhs[i]
364 if typeof[y] != "" {
365 set(x, typeof[y], isDecl)
366 } else {
367 set(y, typeof[x], false)
368 }
369 }
370 }
371
372 expand := func(s string) string {
373 typ := cfg.Type[s]
374 if typ != nil && typ.Def != "" {
375 return typ.Def
376 }
377 return s
378 }
379
380
381
382
383
384
385
386 var curfn []*ast.FuncType
387
388 before := func(n any) {
389
390 switch n := n.(type) {
391 case *ast.FuncDecl:
392 curfn = append(curfn, n.Type)
393 case *ast.FuncLit:
394 curfn = append(curfn, n.Type)
395 }
396 }
397
398
399 after := func(n any) {
400 if n == nil {
401 return
402 }
403 if false && reflect.TypeOf(n).Kind() == reflect.Pointer {
404 defer func() {
405 if t := typeof[n]; t != "" {
406 pos := fset.Position(n.(ast.Node).Pos())
407 fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
408 }
409 }()
410 }
411
412 switch n := n.(type) {
413 case *ast.FuncDecl, *ast.FuncLit:
414
415 curfn = curfn[:len(curfn)-1]
416
417 case *ast.FuncType:
418 typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
419
420 case *ast.FieldList:
421
422 t := ""
423 for _, field := range n.List {
424 if t != "" {
425 t += ", "
426 }
427 t += typeof[field]
428 }
429 typeof[n] = t
430
431 case *ast.Field:
432
433 all := ""
434 t := typeof[n.Type]
435 if !isType(t) {
436
437
438 t = mkType(gofmt(n.Type))
439 typeof[n.Type] = t
440 }
441 t = getType(t)
442 if len(n.Names) == 0 {
443 all = t
444 } else {
445 for _, id := range n.Names {
446 if all != "" {
447 all += ", "
448 }
449 all += t
450 typeof[id.Obj] = t
451 typeof[id] = t
452 }
453 }
454 typeof[n] = all
455
456 case *ast.ValueSpec:
457
458 if n.Type != nil {
459 t := typeof[n.Type]
460 if !isType(t) {
461 t = mkType(gofmt(n.Type))
462 typeof[n.Type] = t
463 }
464 t = getType(t)
465 for _, id := range n.Names {
466 set(id, t, true)
467 }
468 }
469
470 typecheckAssign(makeExprList(n.Names), n.Values, true)
471
472 case *ast.AssignStmt:
473 typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
474
475 case *ast.Ident:
476
477 if t := typeof[n.Obj]; t != "" {
478 typeof[n] = t
479 }
480
481 case *ast.SelectorExpr:
482
483 name := n.Sel.Name
484 if t := typeof[n.X]; t != "" {
485 t = strings.TrimPrefix(t, "*")
486 if typ := cfg.Type[t]; typ != nil {
487 if t := typ.dot(cfg, name); t != "" {
488 typeof[n] = t
489 return
490 }
491 }
492 tt := typeof[t+"."+name]
493 if isType(tt) {
494 typeof[n] = getType(tt)
495 return
496 }
497 }
498
499 if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
500 str := x.Name + "." + name
501 if cfg.Type[str] != nil {
502 typeof[n] = mkType(str)
503 return
504 }
505 if t := cfg.typeof(x.Name + "." + name); t != "" {
506 typeof[n] = t
507 return
508 }
509 }
510
511 case *ast.CallExpr:
512
513 if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
514 typeof[n] = gofmt(n.Args[0])
515 return
516 }
517
518 if isTopName(n.Fun, "new") && len(n.Args) == 1 {
519 typeof[n] = "*" + gofmt(n.Args[0])
520 return
521 }
522
523 t := typeof[n.Fun]
524 if t == "" {
525 t = cfg.External[gofmt(n.Fun)]
526 }
527 in, out := splitFunc(t)
528 if in == nil && out == nil {
529 return
530 }
531 typeof[n] = join(out)
532 for i, arg := range n.Args {
533 if i >= len(in) {
534 break
535 }
536 if typeof[arg] == "" {
537 typeof[arg] = in[i]
538 }
539 }
540
541 case *ast.TypeAssertExpr:
542
543 if n.Type == nil {
544 typeof[n] = typeof[n.X]
545 return
546 }
547
548 if t := typeof[n.Type]; isType(t) {
549 typeof[n] = getType(t)
550 } else {
551 typeof[n] = gofmt(n.Type)
552 }
553
554 case *ast.SliceExpr:
555
556 typeof[n] = typeof[n.X]
557
558 case *ast.IndexExpr:
559
560 t := expand(typeof[n.X])
561 if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
562
563
564 if _, elem, ok := strings.Cut(t, "]"); ok {
565 typeof[n] = elem
566 }
567 }
568
569 case *ast.StarExpr:
570
571
572
573 t := expand(typeof[n.X])
574 if isType(t) {
575 typeof[n] = "type *" + getType(t)
576 } else if strings.HasPrefix(t, "*") {
577 typeof[n] = t[len("*"):]
578 }
579
580 case *ast.UnaryExpr:
581
582 t := typeof[n.X]
583 if t != "" && n.Op == token.AND {
584 typeof[n] = "*" + t
585 }
586
587 case *ast.CompositeLit:
588
589 typeof[n] = gofmt(n.Type)
590
591
592 t := expand(typeof[n])
593 if strings.HasPrefix(t, "[") {
594
595 if _, et, ok := strings.Cut(t, "]"); ok {
596 for _, e := range n.Elts {
597 if kv, ok := e.(*ast.KeyValueExpr); ok {
598 e = kv.Value
599 }
600 if typeof[e] == "" {
601 typeof[e] = et
602 }
603 }
604 }
605 }
606 if strings.HasPrefix(t, "map[") {
607
608 if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok {
609 for _, e := range n.Elts {
610 if kv, ok := e.(*ast.KeyValueExpr); ok {
611 if typeof[kv.Key] == "" {
612 typeof[kv.Key] = kt
613 }
614 if typeof[kv.Value] == "" {
615 typeof[kv.Value] = vt
616 }
617 }
618 }
619 }
620 }
621 if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 {
622 for _, e := range n.Elts {
623 if kv, ok := e.(*ast.KeyValueExpr); ok {
624 if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
625 if typeof[kv.Value] == "" {
626 typeof[kv.Value] = ft
627 }
628 }
629 }
630 }
631 }
632
633 case *ast.ParenExpr:
634
635 typeof[n] = typeof[n.X]
636
637 case *ast.RangeStmt:
638 t := expand(typeof[n.X])
639 if t == "" {
640 return
641 }
642 var key, value string
643 if t == "string" {
644 key, value = "int", "rune"
645 } else if strings.HasPrefix(t, "[") {
646 key = "int"
647 _, value, _ = strings.Cut(t, "]")
648 } else if strings.HasPrefix(t, "map[") {
649 if k, v, ok := strings.Cut(t[len("map["):], "]"); ok {
650 key, value = k, v
651 }
652 }
653 changed := false
654 if n.Key != nil && key != "" {
655 changed = true
656 set(n.Key, key, n.Tok == token.DEFINE)
657 }
658 if n.Value != nil && value != "" {
659 changed = true
660 set(n.Value, value, n.Tok == token.DEFINE)
661 }
662
663
664 if changed {
665 typecheck1(cfg, n.Body, typeof, assign)
666 }
667
668 case *ast.TypeSwitchStmt:
669
670
671
672
673 as, ok := n.Assign.(*ast.AssignStmt)
674 if !ok {
675 return
676 }
677 varx, ok := as.Lhs[0].(*ast.Ident)
678 if !ok {
679 return
680 }
681 t := typeof[varx]
682 for _, cas := range n.Body.List {
683 cas := cas.(*ast.CaseClause)
684 if len(cas.List) == 1 {
685
686
687 if tt := typeof[cas.List[0]]; isType(tt) {
688 tt = getType(tt)
689 typeof[varx] = tt
690 typeof[varx.Obj] = tt
691 typecheck1(cfg, cas.Body, typeof, assign)
692 }
693 }
694 }
695
696 typeof[varx] = t
697 typeof[varx.Obj] = t
698
699 case *ast.ReturnStmt:
700 if len(curfn) == 0 {
701
702 return
703 }
704 f := curfn[len(curfn)-1]
705 res := n.Results
706 if f.Results != nil {
707 t := split(typeof[f.Results])
708 for i := 0; i < len(res) && i < len(t); i++ {
709 set(res[i], t[i], false)
710 }
711 }
712
713 case *ast.BinaryExpr:
714
715 switch n.Op {
716 case token.EQL, token.NEQ:
717 if typeof[n.X] != "" && typeof[n.Y] == "" {
718 typeof[n.Y] = typeof[n.X]
719 }
720 if typeof[n.X] == "" && typeof[n.Y] != "" {
721 typeof[n.X] = typeof[n.Y]
722 }
723 }
724 }
725 }
726 walkBeforeAfter(f, before, after)
727 }
728
729
730
731
732
733
734
// splitFunc splits a function type string such as "func(a, b) (c, d)" or
// "func(a) c" (the forms joinFunc produces; "func()c" is accepted too)
// into its parameter and result type lists. It returns nil, nil if s does
// not start with "func(" or its parameter list is never closed.
func splitFunc(s string) (in, out []string) {
	if !strings.HasPrefix(s, "func(") {
		return nil, nil
	}

	i := len("func(")
	nparen := 0
	for j := i; j < len(s); j++ {
		switch s[j] {
		case '(':
			nparen++
		case ')':
			nparen--
			if nparen < 0 {
				// s[i:j] is the parameter list; what follows is the result,
				// parenthesized when there are several.
				out := strings.TrimSpace(s[j+1:])
				if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
					out = out[1 : len(out)-1]
				}
				return split(s[i:j]), split(out)
			}
		}
	}
	return nil, nil
}

// joinFunc is the inverse of splitFunc: it builds a function type string
// from parameter and result type lists, parenthesizing the results only
// when there is more than one.
func joinFunc(in, out []string) string {
	outs := ""
	switch {
	case len(out) == 1:
		outs = " " + out[0]
	case len(out) > 1:
		outs = " (" + join(out) + ")"
	}
	return "func(" + join(in) + ")" + outs
}

// split splits a comma-separated list of types into its elements, dropping
// spaces at the start of each element. Commas nested inside (), [] or {}
// do not separate elements, so "func(a, b) c", "Pair[K, V]" and
// "struct{ a, b int }" each stay whole. It returns an empty, non-nil slice
// for "" and nil if the brackets are unbalanced.
func split(s string) []string {
	out := []string{}
	i := 0     // start of the current element
	depth := 0 // nesting depth of (), [] and {}
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if i == j {
				i++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				// Unbalanced closing bracket.
				return nil
			}
		case ',':
			if depth == 0 {
				if i < j {
					out = append(out, s[i:j])
				}
				i = j + 1
			}
		}
	}
	if depth != 0 {
		// Unbalanced opening bracket.
		return nil
	}
	if i < len(s) {
		out = append(out, s[i:])
	}
	return out
}

// join joins a list of types with ", ", the separator that split splits on.
func join(x []string) string {
	return strings.Join(x, ", ")
}
814
View as plain text