Source file
src/cmd/fix/main.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "go/version"
17 "internal/diff"
18 "io"
19 "io/fs"
20 "os"
21 "path/filepath"
22 "sort"
23 "strings"
24
25 "cmd/internal/telemetry/counter"
26 )
27
// Process-wide state.
var (
	// fset records position information for every file this command parses.
	fset = token.NewFileSet()

	// exitCode is the status passed to os.Exit; report raises it to 2.
	exitCode = 0
)

// Command-line flags.
var (
	allowedRewrites = flag.String("r", "",
		"restrict the rewrites to this comma-separated list")
	forceRewrites = flag.String("force", "",
		"force these fixes to run even if the code looks updated")
	doDiff    = flag.Bool("diff", false, "display diffs instead of rewriting files")
	goVersion = flag.String("go", "", "go language version for files")
)

// allowed and force are the sets of fix names given to -r and -force.
// main builds each one only when its flag is non-empty, so a nil map
// means "no restriction" / "nothing forced".
var allowed, force map[string]bool

// debug, when true, makes processFile print the source that failed to
// re-parse after a fix, report the error, and exit immediately.
const debug = false
48
49 func usage() {
50 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
51 flag.PrintDefaults()
52 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
53 sort.Sort(byName(fixes))
54 for _, f := range fixes {
55 if f.disabled {
56 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
57 } else {
58 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
59 }
60 desc := strings.TrimSpace(f.desc)
61 desc = strings.ReplaceAll(desc, "\n", "\n\t")
62 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
63 }
64 os.Exit(2)
65 }
66
67 func main() {
68 counter.Open()
69 flag.Usage = usage
70 flag.Parse()
71 counter.Inc("fix/invocations")
72 counter.CountFlags("fix/flag:", *flag.CommandLine)
73
74 if !version.IsValid(*goVersion) {
75 report(fmt.Errorf("invalid -go=%s", *goVersion))
76 os.Exit(exitCode)
77 }
78
79 sort.Sort(byDate(fixes))
80
81 if *allowedRewrites != "" {
82 allowed = make(map[string]bool)
83 for _, f := range strings.Split(*allowedRewrites, ",") {
84 allowed[f] = true
85 }
86 }
87
88 if *forceRewrites != "" {
89 force = make(map[string]bool)
90 for _, f := range strings.Split(*forceRewrites, ",") {
91 force[f] = true
92 }
93 }
94
95 if flag.NArg() == 0 {
96 if err := processFile("standard input", true); err != nil {
97 report(err)
98 }
99 os.Exit(exitCode)
100 }
101
102 for i := 0; i < flag.NArg(); i++ {
103 path := flag.Arg(i)
104 switch dir, err := os.Stat(path); {
105 case err != nil:
106 report(err)
107 case dir.IsDir():
108 walkDir(path)
109 default:
110 if err := processFile(path, false); err != nil {
111 report(err)
112 }
113 }
114 }
115
116 os.Exit(exitCode)
117 }
118
// parserMode is the go/parser mode used for every parse; comments are
// retained in the AST so reprinted source keeps them.
const parserMode parser.Mode = parser.ParseComments
120
121 func gofmtFile(f *ast.File) ([]byte, error) {
122 var buf bytes.Buffer
123 if err := format.Node(&buf, fset, f); err != nil {
124 return nil, err
125 }
126 return buf.Bytes(), nil
127 }
128
129 func processFile(filename string, useStdin bool) error {
130 var f *os.File
131 var err error
132 var fixlog strings.Builder
133
134 if useStdin {
135 f = os.Stdin
136 } else {
137 f, err = os.Open(filename)
138 if err != nil {
139 return err
140 }
141 defer f.Close()
142 }
143
144 src, err := io.ReadAll(f)
145 if err != nil {
146 return err
147 }
148
149 file, err := parser.ParseFile(fset, filename, src, parserMode)
150 if err != nil {
151 return err
152 }
153
154
155
156 newSrc, err := gofmtFile(file)
157 if err != nil {
158 return err
159 }
160 if !bytes.Equal(newSrc, src) {
161 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
162 if err != nil {
163 return err
164 }
165 file = newFile
166 fmt.Fprintf(&fixlog, " fmt")
167 }
168
169
170 newFile := file
171 fixed := false
172 for _, fix := range fixes {
173 if allowed != nil && !allowed[fix.name] {
174 continue
175 }
176 if fix.disabled && !force[fix.name] {
177 continue
178 }
179 if fix.f(newFile) {
180 fixed = true
181 fmt.Fprintf(&fixlog, " %s", fix.name)
182
183
184
185
186 newSrc, err := gofmtFile(newFile)
187 if err != nil {
188 return err
189 }
190 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
191 if err != nil {
192 if debug {
193 fmt.Printf("%s", newSrc)
194 report(err)
195 os.Exit(exitCode)
196 }
197 return err
198 }
199 }
200 }
201 if !fixed {
202 return nil
203 }
204 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
205
206
207
208
209
210
211
212 newSrc, err = gofmtFile(newFile)
213 if err != nil {
214 return err
215 }
216
217 if *doDiff {
218 os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
219 return nil
220 }
221
222 if useStdin {
223 os.Stdout.Write(newSrc)
224 return nil
225 }
226
227 return os.WriteFile(f.Name(), newSrc, 0)
228 }
229
230 func gofmt(n any) string {
231 var gofmtBuf strings.Builder
232 if err := format.Node(&gofmtBuf, fset, n); err != nil {
233 return "<" + err.Error() + ">"
234 }
235 return gofmtBuf.String()
236 }
237
238 func report(err error) {
239 scanner.PrintError(os.Stderr, err)
240 exitCode = 2
241 }
242
243 func walkDir(path string) {
244 filepath.WalkDir(path, visitFile)
245 }
246
247 func visitFile(path string, f fs.DirEntry, err error) error {
248 if err == nil && isGoFile(f) {
249 err = processFile(path, false)
250 }
251 if err != nil {
252 report(err)
253 }
254 return nil
255 }
256
// isGoFile reports whether f is a regular (non-directory) entry whose name
// ends in ".go" and does not start with "." (hidden files are skipped).
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
262
View as plain text