Source file src/cmd/fix/main.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"go/version"
    17  	"internal/diff"
    18  	"io"
    19  	"io/fs"
    20  	"os"
    21  	"path/filepath"
    22  	"sort"
    23  	"strings"
    24  
    25  	"cmd/internal/telemetry/counter"
    26  )
    27  
    28  var (
    29  	fset     = token.NewFileSet()
    30  	exitCode = 0
    31  )
    32  
    33  var allowedRewrites = flag.String("r", "",
    34  	"restrict the rewrites to this comma-separated list")
    35  
    36  var forceRewrites = flag.String("force", "",
    37  	"force these fixes to run even if the code looks updated")
    38  
    39  var allowed, force map[string]bool
    40  
    41  var (
    42  	doDiff    = flag.Bool("diff", false, "display diffs instead of rewriting files")
    43  	goVersion = flag.String("go", "", "go language version for files")
    44  )
    45  
    46  // enable for debugging fix failures
    47  const debug = false // display incorrectly reformatted source and exit
    48  
    49  func usage() {
    50  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    51  	flag.PrintDefaults()
    52  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    53  	sort.Sort(byName(fixes))
    54  	for _, f := range fixes {
    55  		if f.disabled {
    56  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    57  		} else {
    58  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    59  		}
    60  		desc := strings.TrimSpace(f.desc)
    61  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    62  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    63  	}
    64  	os.Exit(2)
    65  }
    66  
    67  func main() {
    68  	counter.Open()
    69  	flag.Usage = usage
    70  	flag.Parse()
    71  	counter.Inc("fix/invocations")
    72  	counter.CountFlags("fix/flag:", *flag.CommandLine)
    73  
    74  	if !version.IsValid(*goVersion) {
    75  		report(fmt.Errorf("invalid -go=%s", *goVersion))
    76  		os.Exit(exitCode)
    77  	}
    78  
    79  	sort.Sort(byDate(fixes))
    80  
    81  	if *allowedRewrites != "" {
    82  		allowed = make(map[string]bool)
    83  		for _, f := range strings.Split(*allowedRewrites, ",") {
    84  			allowed[f] = true
    85  		}
    86  	}
    87  
    88  	if *forceRewrites != "" {
    89  		force = make(map[string]bool)
    90  		for _, f := range strings.Split(*forceRewrites, ",") {
    91  			force[f] = true
    92  		}
    93  	}
    94  
    95  	if flag.NArg() == 0 {
    96  		if err := processFile("standard input", true); err != nil {
    97  			report(err)
    98  		}
    99  		os.Exit(exitCode)
   100  	}
   101  
   102  	for i := 0; i < flag.NArg(); i++ {
   103  		path := flag.Arg(i)
   104  		switch dir, err := os.Stat(path); {
   105  		case err != nil:
   106  			report(err)
   107  		case dir.IsDir():
   108  			walkDir(path)
   109  		default:
   110  			if err := processFile(path, false); err != nil {
   111  				report(err)
   112  			}
   113  		}
   114  	}
   115  
   116  	os.Exit(exitCode)
   117  }
   118  
   119  const parserMode = parser.ParseComments
   120  
   121  func gofmtFile(f *ast.File) ([]byte, error) {
   122  	var buf bytes.Buffer
   123  	if err := format.Node(&buf, fset, f); err != nil {
   124  		return nil, err
   125  	}
   126  	return buf.Bytes(), nil
   127  }
   128  
   129  func processFile(filename string, useStdin bool) error {
   130  	var f *os.File
   131  	var err error
   132  	var fixlog strings.Builder
   133  
   134  	if useStdin {
   135  		f = os.Stdin
   136  	} else {
   137  		f, err = os.Open(filename)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		defer f.Close()
   142  	}
   143  
   144  	src, err := io.ReadAll(f)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	// Make sure file is in canonical format.
   155  	// This "fmt" pseudo-fix cannot be disabled.
   156  	newSrc, err := gofmtFile(file)
   157  	if err != nil {
   158  		return err
   159  	}
   160  	if !bytes.Equal(newSrc, src) {
   161  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   162  		if err != nil {
   163  			return err
   164  		}
   165  		file = newFile
   166  		fmt.Fprintf(&fixlog, " fmt")
   167  	}
   168  
   169  	// Apply all fixes to file.
   170  	newFile := file
   171  	fixed := false
   172  	for _, fix := range fixes {
   173  		if allowed != nil && !allowed[fix.name] {
   174  			continue
   175  		}
   176  		if fix.disabled && !force[fix.name] {
   177  			continue
   178  		}
   179  		if fix.f(newFile) {
   180  			fixed = true
   181  			fmt.Fprintf(&fixlog, " %s", fix.name)
   182  
   183  			// AST changed.
   184  			// Print and parse, to update any missing scoping
   185  			// or position information for subsequent fixers.
   186  			newSrc, err := gofmtFile(newFile)
   187  			if err != nil {
   188  				return err
   189  			}
   190  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   191  			if err != nil {
   192  				if debug {
   193  					fmt.Printf("%s", newSrc)
   194  					report(err)
   195  					os.Exit(exitCode)
   196  				}
   197  				return err
   198  			}
   199  		}
   200  	}
   201  	if !fixed {
   202  		return nil
   203  	}
   204  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   205  
   206  	// Print AST.  We did that after each fix, so this appears
   207  	// redundant, but it is necessary to generate gofmt-compatible
   208  	// source code in a few cases. The official gofmt style is the
   209  	// output of the printer run on a standard AST generated by the parser,
   210  	// but the source we generated inside the loop above is the
   211  	// output of the printer run on a mangled AST generated by a fixer.
   212  	newSrc, err = gofmtFile(newFile)
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	if *doDiff {
   218  		os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
   219  		return nil
   220  	}
   221  
   222  	if useStdin {
   223  		os.Stdout.Write(newSrc)
   224  		return nil
   225  	}
   226  
   227  	return os.WriteFile(f.Name(), newSrc, 0)
   228  }
   229  
   230  func gofmt(n any) string {
   231  	var gofmtBuf strings.Builder
   232  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   233  		return "<" + err.Error() + ">"
   234  	}
   235  	return gofmtBuf.String()
   236  }
   237  
   238  func report(err error) {
   239  	scanner.PrintError(os.Stderr, err)
   240  	exitCode = 2
   241  }
   242  
   243  func walkDir(path string) {
   244  	filepath.WalkDir(path, visitFile)
   245  }
   246  
   247  func visitFile(path string, f fs.DirEntry, err error) error {
   248  	if err == nil && isGoFile(f) {
   249  		err = processFile(path, false)
   250  	}
   251  	if err != nil {
   252  		report(err)
   253  	}
   254  	return nil
   255  }
   256  
   257  func isGoFile(f fs.DirEntry) bool {
   258  	// ignore non-Go files
   259  	name := f.Name()
   260  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   261  }
   262  

View as plain text