Source file src/cmd/fix/fix.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
import (
	"fmt"
	"go/ast"
	"go/token"
	"path"
	"sort"
	"strconv"
)
    14  
    15  type fix struct {
    16  	name     string
    17  	date     string // date that fix was introduced, in YYYY-MM-DD format
    18  	f        func(*ast.File) bool
    19  	desc     string
    20  	disabled bool // whether this fix should be disabled by default
    21  }
    22  
    23  // main runs sort.Sort(byName(fixes)) before printing list of fixes.
    24  type byName []fix
    25  
    26  func (f byName) Len() int           { return len(f) }
    27  func (f byName) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    28  func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
    29  
    30  // main runs sort.Sort(byDate(fixes)) before applying fixes.
    31  type byDate []fix
    32  
    33  func (f byDate) Len() int           { return len(f) }
    34  func (f byDate) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    35  func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
    36  
    37  var fixes []fix
    38  
    39  func register(f fix) {
    40  	fixes = append(fixes, f)
    41  }
    42  
    43  // walk traverses the AST x, calling visit(y) for each node y in the tree but
    44  // also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
    45  // in a bottom-up traversal.
    46  func walk(x any, visit func(any)) {
    47  	walkBeforeAfter(x, nop, visit)
    48  }
    49  
    50  func nop(any) {}
    51  
    52  // walkBeforeAfter is like walk but calls before(x) before traversing
    53  // x's children and after(x) afterward.
    54  func walkBeforeAfter(x any, before, after func(any)) {
    55  	before(x)
    56  
    57  	switch n := x.(type) {
    58  	default:
    59  		panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
    60  
    61  	case nil:
    62  
    63  	// pointers to interfaces
    64  	case *ast.Decl:
    65  		walkBeforeAfter(*n, before, after)
    66  	case *ast.Expr:
    67  		walkBeforeAfter(*n, before, after)
    68  	case *ast.Spec:
    69  		walkBeforeAfter(*n, before, after)
    70  	case *ast.Stmt:
    71  		walkBeforeAfter(*n, before, after)
    72  
    73  	// pointers to struct pointers
    74  	case **ast.BlockStmt:
    75  		walkBeforeAfter(*n, before, after)
    76  	case **ast.CallExpr:
    77  		walkBeforeAfter(*n, before, after)
    78  	case **ast.FieldList:
    79  		walkBeforeAfter(*n, before, after)
    80  	case **ast.FuncType:
    81  		walkBeforeAfter(*n, before, after)
    82  	case **ast.Ident:
    83  		walkBeforeAfter(*n, before, after)
    84  	case **ast.BasicLit:
    85  		walkBeforeAfter(*n, before, after)
    86  
    87  	// pointers to slices
    88  	case *[]ast.Decl:
    89  		walkBeforeAfter(*n, before, after)
    90  	case *[]ast.Expr:
    91  		walkBeforeAfter(*n, before, after)
    92  	case *[]*ast.File:
    93  		walkBeforeAfter(*n, before, after)
    94  	case *[]*ast.Ident:
    95  		walkBeforeAfter(*n, before, after)
    96  	case *[]ast.Spec:
    97  		walkBeforeAfter(*n, before, after)
    98  	case *[]ast.Stmt:
    99  		walkBeforeAfter(*n, before, after)
   100  
   101  	// These are ordered and grouped to match ../../go/ast/ast.go
   102  	case *ast.Field:
   103  		walkBeforeAfter(&n.Names, before, after)
   104  		walkBeforeAfter(&n.Type, before, after)
   105  		walkBeforeAfter(&n.Tag, before, after)
   106  	case *ast.FieldList:
   107  		for _, field := range n.List {
   108  			walkBeforeAfter(field, before, after)
   109  		}
   110  	case *ast.BadExpr:
   111  	case *ast.Ident:
   112  	case *ast.Ellipsis:
   113  		walkBeforeAfter(&n.Elt, before, after)
   114  	case *ast.BasicLit:
   115  	case *ast.FuncLit:
   116  		walkBeforeAfter(&n.Type, before, after)
   117  		walkBeforeAfter(&n.Body, before, after)
   118  	case *ast.CompositeLit:
   119  		walkBeforeAfter(&n.Type, before, after)
   120  		walkBeforeAfter(&n.Elts, before, after)
   121  	case *ast.ParenExpr:
   122  		walkBeforeAfter(&n.X, before, after)
   123  	case *ast.SelectorExpr:
   124  		walkBeforeAfter(&n.X, before, after)
   125  	case *ast.IndexExpr:
   126  		walkBeforeAfter(&n.X, before, after)
   127  		walkBeforeAfter(&n.Index, before, after)
   128  	case *ast.IndexListExpr:
   129  		walkBeforeAfter(&n.X, before, after)
   130  		walkBeforeAfter(&n.Indices, before, after)
   131  	case *ast.SliceExpr:
   132  		walkBeforeAfter(&n.X, before, after)
   133  		if n.Low != nil {
   134  			walkBeforeAfter(&n.Low, before, after)
   135  		}
   136  		if n.High != nil {
   137  			walkBeforeAfter(&n.High, before, after)
   138  		}
   139  	case *ast.TypeAssertExpr:
   140  		walkBeforeAfter(&n.X, before, after)
   141  		walkBeforeAfter(&n.Type, before, after)
   142  	case *ast.CallExpr:
   143  		walkBeforeAfter(&n.Fun, before, after)
   144  		walkBeforeAfter(&n.Args, before, after)
   145  	case *ast.StarExpr:
   146  		walkBeforeAfter(&n.X, before, after)
   147  	case *ast.UnaryExpr:
   148  		walkBeforeAfter(&n.X, before, after)
   149  	case *ast.BinaryExpr:
   150  		walkBeforeAfter(&n.X, before, after)
   151  		walkBeforeAfter(&n.Y, before, after)
   152  	case *ast.KeyValueExpr:
   153  		walkBeforeAfter(&n.Key, before, after)
   154  		walkBeforeAfter(&n.Value, before, after)
   155  
   156  	case *ast.ArrayType:
   157  		walkBeforeAfter(&n.Len, before, after)
   158  		walkBeforeAfter(&n.Elt, before, after)
   159  	case *ast.StructType:
   160  		walkBeforeAfter(&n.Fields, before, after)
   161  	case *ast.FuncType:
   162  		if n.TypeParams != nil {
   163  			walkBeforeAfter(&n.TypeParams, before, after)
   164  		}
   165  		walkBeforeAfter(&n.Params, before, after)
   166  		if n.Results != nil {
   167  			walkBeforeAfter(&n.Results, before, after)
   168  		}
   169  	case *ast.InterfaceType:
   170  		walkBeforeAfter(&n.Methods, before, after)
   171  	case *ast.MapType:
   172  		walkBeforeAfter(&n.Key, before, after)
   173  		walkBeforeAfter(&n.Value, before, after)
   174  	case *ast.ChanType:
   175  		walkBeforeAfter(&n.Value, before, after)
   176  
   177  	case *ast.BadStmt:
   178  	case *ast.DeclStmt:
   179  		walkBeforeAfter(&n.Decl, before, after)
   180  	case *ast.EmptyStmt:
   181  	case *ast.LabeledStmt:
   182  		walkBeforeAfter(&n.Stmt, before, after)
   183  	case *ast.ExprStmt:
   184  		walkBeforeAfter(&n.X, before, after)
   185  	case *ast.SendStmt:
   186  		walkBeforeAfter(&n.Chan, before, after)
   187  		walkBeforeAfter(&n.Value, before, after)
   188  	case *ast.IncDecStmt:
   189  		walkBeforeAfter(&n.X, before, after)
   190  	case *ast.AssignStmt:
   191  		walkBeforeAfter(&n.Lhs, before, after)
   192  		walkBeforeAfter(&n.Rhs, before, after)
   193  	case *ast.GoStmt:
   194  		walkBeforeAfter(&n.Call, before, after)
   195  	case *ast.DeferStmt:
   196  		walkBeforeAfter(&n.Call, before, after)
   197  	case *ast.ReturnStmt:
   198  		walkBeforeAfter(&n.Results, before, after)
   199  	case *ast.BranchStmt:
   200  	case *ast.BlockStmt:
   201  		walkBeforeAfter(&n.List, before, after)
   202  	case *ast.IfStmt:
   203  		walkBeforeAfter(&n.Init, before, after)
   204  		walkBeforeAfter(&n.Cond, before, after)
   205  		walkBeforeAfter(&n.Body, before, after)
   206  		walkBeforeAfter(&n.Else, before, after)
   207  	case *ast.CaseClause:
   208  		walkBeforeAfter(&n.List, before, after)
   209  		walkBeforeAfter(&n.Body, before, after)
   210  	case *ast.SwitchStmt:
   211  		walkBeforeAfter(&n.Init, before, after)
   212  		walkBeforeAfter(&n.Tag, before, after)
   213  		walkBeforeAfter(&n.Body, before, after)
   214  	case *ast.TypeSwitchStmt:
   215  		walkBeforeAfter(&n.Init, before, after)
   216  		walkBeforeAfter(&n.Assign, before, after)
   217  		walkBeforeAfter(&n.Body, before, after)
   218  	case *ast.CommClause:
   219  		walkBeforeAfter(&n.Comm, before, after)
   220  		walkBeforeAfter(&n.Body, before, after)
   221  	case *ast.SelectStmt:
   222  		walkBeforeAfter(&n.Body, before, after)
   223  	case *ast.ForStmt:
   224  		walkBeforeAfter(&n.Init, before, after)
   225  		walkBeforeAfter(&n.Cond, before, after)
   226  		walkBeforeAfter(&n.Post, before, after)
   227  		walkBeforeAfter(&n.Body, before, after)
   228  	case *ast.RangeStmt:
   229  		walkBeforeAfter(&n.Key, before, after)
   230  		walkBeforeAfter(&n.Value, before, after)
   231  		walkBeforeAfter(&n.X, before, after)
   232  		walkBeforeAfter(&n.Body, before, after)
   233  
   234  	case *ast.ImportSpec:
   235  	case *ast.ValueSpec:
   236  		walkBeforeAfter(&n.Type, before, after)
   237  		walkBeforeAfter(&n.Values, before, after)
   238  		walkBeforeAfter(&n.Names, before, after)
   239  	case *ast.TypeSpec:
   240  		if n.TypeParams != nil {
   241  			walkBeforeAfter(&n.TypeParams, before, after)
   242  		}
   243  		walkBeforeAfter(&n.Type, before, after)
   244  
   245  	case *ast.BadDecl:
   246  	case *ast.GenDecl:
   247  		walkBeforeAfter(&n.Specs, before, after)
   248  	case *ast.FuncDecl:
   249  		if n.Recv != nil {
   250  			walkBeforeAfter(&n.Recv, before, after)
   251  		}
   252  		walkBeforeAfter(&n.Type, before, after)
   253  		if n.Body != nil {
   254  			walkBeforeAfter(&n.Body, before, after)
   255  		}
   256  
   257  	case *ast.File:
   258  		walkBeforeAfter(&n.Decls, before, after)
   259  
   260  	case *ast.Package:
   261  		walkBeforeAfter(&n.Files, before, after)
   262  
   263  	case []*ast.File:
   264  		for i := range n {
   265  			walkBeforeAfter(&n[i], before, after)
   266  		}
   267  	case []ast.Decl:
   268  		for i := range n {
   269  			walkBeforeAfter(&n[i], before, after)
   270  		}
   271  	case []ast.Expr:
   272  		for i := range n {
   273  			walkBeforeAfter(&n[i], before, after)
   274  		}
   275  	case []*ast.Ident:
   276  		for i := range n {
   277  			walkBeforeAfter(&n[i], before, after)
   278  		}
   279  	case []ast.Stmt:
   280  		for i := range n {
   281  			walkBeforeAfter(&n[i], before, after)
   282  		}
   283  	case []ast.Spec:
   284  		for i := range n {
   285  			walkBeforeAfter(&n[i], before, after)
   286  		}
   287  	}
   288  	after(x)
   289  }
   290  
   291  // imports reports whether f imports path.
   292  func imports(f *ast.File, path string) bool {
   293  	return importSpec(f, path) != nil
   294  }
   295  
   296  // importSpec returns the import spec if f imports path,
   297  // or nil otherwise.
   298  func importSpec(f *ast.File, path string) *ast.ImportSpec {
   299  	for _, s := range f.Imports {
   300  		if importPath(s) == path {
   301  			return s
   302  		}
   303  	}
   304  	return nil
   305  }
   306  
   307  // importPath returns the unquoted import path of s,
   308  // or "" if the path is not properly quoted.
   309  func importPath(s *ast.ImportSpec) string {
   310  	t, err := strconv.Unquote(s.Path.Value)
   311  	if err == nil {
   312  		return t
   313  	}
   314  	return ""
   315  }
   316  
   317  // declImports reports whether gen contains an import of path.
   318  func declImports(gen *ast.GenDecl, path string) bool {
   319  	if gen.Tok != token.IMPORT {
   320  		return false
   321  	}
   322  	for _, spec := range gen.Specs {
   323  		impspec := spec.(*ast.ImportSpec)
   324  		if importPath(impspec) == path {
   325  			return true
   326  		}
   327  	}
   328  	return false
   329  }
   330  
   331  // isTopName reports whether n is a top-level unresolved identifier with the given name.
   332  func isTopName(n ast.Expr, name string) bool {
   333  	id, ok := n.(*ast.Ident)
   334  	return ok && id.Name == name && id.Obj == nil
   335  }
   336  
   337  // renameTop renames all references to the top-level name old.
   338  // It reports whether it makes any changes.
   339  func renameTop(f *ast.File, old, new string) bool {
   340  	var fixed bool
   341  
   342  	// Rename any conflicting imports
   343  	// (assuming package name is last element of path).
   344  	for _, s := range f.Imports {
   345  		if s.Name != nil {
   346  			if s.Name.Name == old {
   347  				s.Name.Name = new
   348  				fixed = true
   349  			}
   350  		} else {
   351  			_, thisName := path.Split(importPath(s))
   352  			if thisName == old {
   353  				s.Name = ast.NewIdent(new)
   354  				fixed = true
   355  			}
   356  		}
   357  	}
   358  
   359  	// Rename any top-level declarations.
   360  	for _, d := range f.Decls {
   361  		switch d := d.(type) {
   362  		case *ast.FuncDecl:
   363  			if d.Recv == nil && d.Name.Name == old {
   364  				d.Name.Name = new
   365  				d.Name.Obj.Name = new
   366  				fixed = true
   367  			}
   368  		case *ast.GenDecl:
   369  			for _, s := range d.Specs {
   370  				switch s := s.(type) {
   371  				case *ast.TypeSpec:
   372  					if s.Name.Name == old {
   373  						s.Name.Name = new
   374  						s.Name.Obj.Name = new
   375  						fixed = true
   376  					}
   377  				case *ast.ValueSpec:
   378  					for _, n := range s.Names {
   379  						if n.Name == old {
   380  							n.Name = new
   381  							n.Obj.Name = new
   382  							fixed = true
   383  						}
   384  					}
   385  				}
   386  			}
   387  		}
   388  	}
   389  
   390  	// Rename top-level old to new, both unresolved names
   391  	// (probably defined in another file) and names that resolve
   392  	// to a declaration we renamed.
   393  	walk(f, func(n any) {
   394  		id, ok := n.(*ast.Ident)
   395  		if ok && isTopName(id, old) {
   396  			id.Name = new
   397  			fixed = true
   398  		}
   399  		if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
   400  			id.Name = id.Obj.Name
   401  			fixed = true
   402  		}
   403  	})
   404  
   405  	return fixed
   406  }
   407  
   408  // matchLen returns the length of the longest prefix shared by x and y.
   409  func matchLen(x, y string) int {
   410  	i := 0
   411  	for i < len(x) && i < len(y) && x[i] == y[i] {
   412  		i++
   413  	}
   414  	return i
   415  }
   416  
   417  // addImport adds the import path to the file f, if absent.
   418  func addImport(f *ast.File, ipath string) (added bool) {
   419  	if imports(f, ipath) {
   420  		return false
   421  	}
   422  
   423  	// Determine name of import.
   424  	// Assume added imports follow convention of using last element.
   425  	_, name := path.Split(ipath)
   426  
   427  	// Rename any conflicting top-level references from name to name_.
   428  	renameTop(f, name, name+"_")
   429  
   430  	newImport := &ast.ImportSpec{
   431  		Path: &ast.BasicLit{
   432  			Kind:  token.STRING,
   433  			Value: strconv.Quote(ipath),
   434  		},
   435  	}
   436  
   437  	// Find an import decl to add to.
   438  	var (
   439  		bestMatch  = -1
   440  		lastImport = -1
   441  		impDecl    *ast.GenDecl
   442  		impIndex   = -1
   443  	)
   444  	for i, decl := range f.Decls {
   445  		gen, ok := decl.(*ast.GenDecl)
   446  		if ok && gen.Tok == token.IMPORT {
   447  			lastImport = i
   448  			// Do not add to import "C", to avoid disrupting the
   449  			// association with its doc comment, breaking cgo.
   450  			if declImports(gen, "C") {
   451  				continue
   452  			}
   453  
   454  			// Compute longest shared prefix with imports in this block.
   455  			for j, spec := range gen.Specs {
   456  				impspec := spec.(*ast.ImportSpec)
   457  				n := matchLen(importPath(impspec), ipath)
   458  				if n > bestMatch {
   459  					bestMatch = n
   460  					impDecl = gen
   461  					impIndex = j
   462  				}
   463  			}
   464  		}
   465  	}
   466  
   467  	// If no import decl found, add one after the last import.
   468  	if impDecl == nil {
   469  		impDecl = &ast.GenDecl{
   470  			Tok: token.IMPORT,
   471  		}
   472  		f.Decls = append(f.Decls, nil)
   473  		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
   474  		f.Decls[lastImport+1] = impDecl
   475  	}
   476  
   477  	// Ensure the import decl has parentheses, if needed.
   478  	if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
   479  		impDecl.Lparen = impDecl.Pos()
   480  	}
   481  
   482  	insertAt := impIndex + 1
   483  	if insertAt == 0 {
   484  		insertAt = len(impDecl.Specs)
   485  	}
   486  	impDecl.Specs = append(impDecl.Specs, nil)
   487  	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
   488  	impDecl.Specs[insertAt] = newImport
   489  	if insertAt > 0 {
   490  		// Assign same position as the previous import,
   491  		// so that the sorter sees it as being in the same block.
   492  		prev := impDecl.Specs[insertAt-1]
   493  		newImport.Path.ValuePos = prev.Pos()
   494  		newImport.EndPos = prev.Pos()
   495  	}
   496  
   497  	f.Imports = append(f.Imports, newImport)
   498  	return true
   499  }
   500  
   501  // deleteImport deletes the import path from the file f, if present.
   502  func deleteImport(f *ast.File, path string) (deleted bool) {
   503  	oldImport := importSpec(f, path)
   504  
   505  	// Find the import node that imports path, if any.
   506  	for i, decl := range f.Decls {
   507  		gen, ok := decl.(*ast.GenDecl)
   508  		if !ok || gen.Tok != token.IMPORT {
   509  			continue
   510  		}
   511  		for j, spec := range gen.Specs {
   512  			impspec := spec.(*ast.ImportSpec)
   513  			if oldImport != impspec {
   514  				continue
   515  			}
   516  
   517  			// We found an import spec that imports path.
   518  			// Delete it.
   519  			deleted = true
   520  			copy(gen.Specs[j:], gen.Specs[j+1:])
   521  			gen.Specs = gen.Specs[:len(gen.Specs)-1]
   522  
   523  			// If this was the last import spec in this decl,
   524  			// delete the decl, too.
   525  			if len(gen.Specs) == 0 {
   526  				copy(f.Decls[i:], f.Decls[i+1:])
   527  				f.Decls = f.Decls[:len(f.Decls)-1]
   528  			} else if len(gen.Specs) == 1 {
   529  				gen.Lparen = token.NoPos // drop parens
   530  			}
   531  			if j > 0 {
   532  				// We deleted an entry but now there will be
   533  				// a blank line-sized hole where the import was.
   534  				// Close the hole by making the previous
   535  				// import appear to "end" where this one did.
   536  				gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
   537  			}
   538  			break
   539  		}
   540  	}
   541  
   542  	// Delete it from f.Imports.
   543  	for i, imp := range f.Imports {
   544  		if imp == oldImport {
   545  			copy(f.Imports[i:], f.Imports[i+1:])
   546  			f.Imports = f.Imports[:len(f.Imports)-1]
   547  			break
   548  		}
   549  	}
   550  
   551  	return
   552  }
   553  
   554  // rewriteImport rewrites any import of path oldPath to path newPath.
   555  func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
   556  	for _, imp := range f.Imports {
   557  		if importPath(imp) == oldPath {
   558  			rewrote = true
   559  			// record old End, because the default is to compute
   560  			// it using the length of imp.Path.Value.
   561  			imp.EndPos = imp.End()
   562  			imp.Path.Value = strconv.Quote(newPath)
   563  		}
   564  	}
   565  	return
   566  }
   567  

View as plain text