Source file
src/cmd/gofmt/gofmt_test.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "internal/diff"
11 "os"
12 "path/filepath"
13 "strings"
14 "testing"
15 "text/scanner"
16 )
17
// update holds the value of the -update test flag. When it is set and the
// formatted output differs from the expected file, runTest overwrites the
// expected (.golden) file with the actual output instead of only reporting
// the mismatch; it never overwrites a file that is also the test input.
var update = flag.Bool(
	"update",
	false,
	"update .golden files",
)
19
20
21
22
23
24
25
26
// gofmtFlags looks for a comment of the form
//
//	//gofmt flags
//
// whose first character lies within the first maxLines lines of filename,
// and returns the trimmed flags string. It returns "" if no such comment is
// found, if the file cannot be opened, or if maxLines < 1.
func gofmtFlags(filename string, maxLines int) string {
	f, err := os.Open(filename)
	if err != nil {
		// Unreadable files simply carry no flags; callers surface
		// I/O problems through their own file operations.
		return ""
	}
	defer f.Close()

	var s scanner.Scanner
	s.Init(f)
	s.Error = func(*scanner.Scanner, string) {}       // tolerate malformed input
	s.Mode = scanner.GoTokens &^ scanner.SkipComments // comments must be returned as tokens

	const prefix = "//gofmt "
	for {
		tok := s.Scan()
		// s.Line is the start line of the token just scanned, so the limit
		// is checked after Scan: a token past maxLines ends the search even
		// if the previous token was still inside the limit.
		if tok == scanner.EOF || s.Line > maxLines {
			return ""
		}
		if tok != scanner.Comment {
			continue
		}
		if t := s.TokenText(); strings.HasPrefix(t, prefix) {
			return strings.TrimSpace(t[len(prefix):])
		}
	}
}
55
56
57 func resetFlags() {
58 *list = false
59 *write = false
60 *rewriteRule = ""
61 *simplifyAST = false
62 *doDiff = false
63 *allErrors = false
64 *cpuprofile = ""
65 }
66
// runTest formats the file in with gofmt and compares the output with the
// contents of out. Flags are taken from a "//gofmt ..." comment in the first
// 20 lines of in (see gofmtFlags); only -r, -s and -stdin are understood.
//
// On a mismatch there are two outcomes:
//   - If -update is set and in != out, out is overwritten with the actual
//     output and the test is not failed.
//   - Otherwise the test fails with a diff, and the actual output is saved
//     to in+".gofmt" for inspection.
func runTest(t *testing.T, in, out string) {
	resetFlags()
	info, err := os.Lstat(in)
	if err != nil {
		t.Error(err)
		return
	}

	// arg rather than flag: the latter would shadow the imported flag package.
	for _, arg := range strings.Fields(gofmtFlags(in, 20)) {
		name, value, _ := strings.Cut(arg, "=")
		switch name {
		case "-r":
			*rewriteRule = value
		case "-s":
			*simplifyAST = true
		case "-stdin":
			// Test-only pseudo flag. NOTE(review): presumably a nil info makes
			// processFile treat the input as if read from stdin — confirm in gofmt.go.
			info = nil
		default:
			t.Errorf("unrecognized flag name: %s", name)
		}
	}

	initParserMode()
	initRewrite()

	const maxWeight = 2 << 20
	var buf, errBuf bytes.Buffer
	s := newSequencer(maxWeight, &buf, &errBuf)
	s.Add(fileWeight(in, info), func(r *reporter) error {
		return processFile(in, info, nil, r)
	})
	// Wait for the sequencer before looking at errBuf: the task given to Add
	// may run on another goroutine (NOTE(review): per gofmt.go's sequencer),
	// so inspecting errBuf earlier could race with it and miss its output.
	exitCode := s.GetExitCode()
	if errBuf.Len() > 0 {
		t.Logf("%q", errBuf.Bytes())
	}
	if exitCode != 0 {
		t.Fail()
	}

	expected, err := os.ReadFile(out)
	if err != nil {
		t.Error(err)
		return
	}

	got := buf.Bytes()
	if bytes.Equal(got, expected) {
		return
	}
	if *update {
		if in != out {
			if err := os.WriteFile(out, got, 0666); err != nil {
				t.Error(err)
			}
			return
		}
		// -update never overwrites a file that is also the test input.
		t.Errorf("WARNING: -update did not rewrite input file %s", in)
	}

	t.Errorf("(gofmt %s) != %s (see %s.gofmt)\n%s", in, out, in,
		diff.Diff("expected", expected, "got", got))
	if err := os.WriteFile(in+".gofmt", got, 0666); err != nil {
		t.Error(err)
	}
}
137
138
139
140
141
142
143
144
// TestRewrite runs gofmt over every testdata/*.input file and compares the
// result with the corresponding testdata/*.golden file. If that succeeds, the
// golden file itself is run through gofmt and must come out unchanged.
// gofmt.go and gofmt_test.go are also checked against themselves, so the
// command's own sources must already be gofmt-clean.
func TestRewrite(t *testing.T) {
	inputs, err := filepath.Glob("testdata/*.input")
	if err != nil {
		t.Fatal(err)
	}
	// The package's own sources serve as their own expected output.
	inputs = append(inputs, "gofmt.go", "gofmt_test.go")

	for _, in := range inputs {
		t.Run(filepath.Base(in), func(t *testing.T) {
			out := in
			if stem := strings.TrimSuffix(in, ".input"); stem != in {
				out = stem + ".golden"
			}
			runTest(t, in, out)

			if out == in || t.Failed() {
				return
			}
			// Reformatting the golden output must be a no-op.
			runTest(t, out, out)
		})
	}
}
170
171
172
// TestDiff runs gofmt in diff mode (*doDiff = true) and checks the exit code:
// 1 for testdata/exitcode.input and 0 for testdata/exitcode.golden.
// Each file runs as its own subtest, so a failure on one case (for example a
// missing file) does not hide the result of the other.
func TestDiff(t *testing.T) {
	tests := []struct {
		in       string
		exitCode int
	}{
		{in: "testdata/exitcode.input", exitCode: 1},
		{in: "testdata/exitcode.golden", exitCode: 0},
	}

	for _, tt := range tests {
		t.Run(filepath.Base(tt.in), func(t *testing.T) {
			resetFlags()
			*doDiff = true

			initParserMode()
			initRewrite()

			info, err := os.Lstat(tt.in)
			if err != nil {
				t.Fatal(err) // stops only this case
			}

			const maxWeight = 2 << 20
			var buf, errBuf bytes.Buffer
			s := newSequencer(maxWeight, &buf, &errBuf)
			s.Add(fileWeight(tt.in, info), func(r *reporter) error {
				return processFile(tt.in, info, nil, r)
			})
			// Wait for the sequencer before reading errBuf; the task given to
			// Add may run on another goroutine (NOTE(review): per gofmt.go).
			exitCode := s.GetExitCode()
			if errBuf.Len() > 0 {
				t.Logf("%q", errBuf.Bytes())
			}

			if exitCode != tt.exitCode {
				t.Errorf("%s: expected exit code %d, got %d", tt.in, tt.exitCode, exitCode)
			}
		})
	}
}
210
211
// TestCRLF checks the CR/LF test fixtures themselves:
//   - testdata/crlf.input must contain at least one "\r\n" line ending;
//   - testdata/crlf.golden must contain no '\r' at all.
//
// NOTE(review): presumably this guards against tooling normalizing the line
// endings of these files — confirm.
func TestCRLF(t *testing.T) {
	const input = "testdata/crlf.input"
	const golden = "testdata/crlf.golden"

	// The contents are only inspected when the read succeeded; otherwise a
	// read error would be followed by a misleading report about nil data.
	data, err := os.ReadFile(input)
	if err != nil {
		t.Error(err)
	} else if !bytes.Contains(data, []byte("\r\n")) {
		t.Errorf("%s contains no CR/LF's", input)
	}

	data, err = os.ReadFile(golden)
	if err != nil {
		t.Error(err)
	} else if bytes.Contains(data, []byte("\r")) {
		t.Errorf("%s contains CR's", golden)
	}
}
232
// TestBackupFile calls backupFile for the path foo.go inside a fresh, empty
// temporary directory (foo.go itself is never created) and logs the name of
// the file backupFile reports having created.
func TestBackupFile(t *testing.T) {
	// t.TempDir is removed automatically when the test ends, and cleanup
	// failures are reported instead of being silently ignored.
	dir := t.TempDir()
	name, err := backupFile(filepath.Join(dir, "foo.go"), []byte(" package main"), 0644)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Created: %s", name)
}
245
View as plain text