Source file src/cmd/gofmt/gofmt_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"internal/diff"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"testing"
    15  	"text/scanner"
    16  )
    17  
    18  var update = flag.Bool("update", false, "update .golden files")
    19  
    20  // gofmtFlags looks for a comment of the form
    21  //
    22  //	//gofmt flags
    23  //
    24  // within the first maxLines lines of the given file,
    25  // and returns the flags string, if any. Otherwise it
    26  // returns the empty string.
    27  func gofmtFlags(filename string, maxLines int) string {
    28  	f, err := os.Open(filename)
    29  	if err != nil {
    30  		return "" // ignore errors - they will be found later
    31  	}
    32  	defer f.Close()
    33  
    34  	// initialize scanner
    35  	var s scanner.Scanner
    36  	s.Init(f)
    37  	s.Error = func(*scanner.Scanner, string) {}       // ignore errors
    38  	s.Mode = scanner.GoTokens &^ scanner.SkipComments // want comments
    39  
    40  	// look for //gofmt comment
    41  	for s.Line <= maxLines {
    42  		switch s.Scan() {
    43  		case scanner.Comment:
    44  			const prefix = "//gofmt "
    45  			if t := s.TokenText(); strings.HasPrefix(t, prefix) {
    46  				return strings.TrimSpace(t[len(prefix):])
    47  			}
    48  		case scanner.EOF:
    49  			return ""
    50  		}
    51  	}
    52  
    53  	return ""
    54  }
    55  
    56  // Reset global variables for all flags to their default value.
    57  func resetFlags() {
    58  	*list = false
    59  	*write = false
    60  	*rewriteRule = ""
    61  	*simplifyAST = false
    62  	*doDiff = false
    63  	*allErrors = false
    64  	*cpuprofile = ""
    65  }
    66  
    67  func runTest(t *testing.T, in, out string) {
    68  	resetFlags()
    69  	info, err := os.Lstat(in)
    70  	if err != nil {
    71  		t.Error(err)
    72  		return
    73  	}
    74  	for _, flag := range strings.Split(gofmtFlags(in, 20), " ") {
    75  		elts := strings.SplitN(flag, "=", 2)
    76  		name := elts[0]
    77  		value := ""
    78  		if len(elts) == 2 {
    79  			value = elts[1]
    80  		}
    81  		switch name {
    82  		case "":
    83  			// no flags
    84  		case "-r":
    85  			*rewriteRule = value
    86  		case "-s":
    87  			*simplifyAST = true
    88  		case "-stdin":
    89  			// fake flag - pretend input is from stdin
    90  			info = nil
    91  		default:
    92  			t.Errorf("unrecognized flag name: %s", name)
    93  		}
    94  	}
    95  
    96  	initParserMode()
    97  	initRewrite()
    98  
    99  	const maxWeight = 2 << 20
   100  	var buf, errBuf bytes.Buffer
   101  	s := newSequencer(maxWeight, &buf, &errBuf)
   102  	s.Add(fileWeight(in, info), func(r *reporter) error {
   103  		return processFile(in, info, nil, r)
   104  	})
   105  	if errBuf.Len() > 0 {
   106  		t.Logf("%q", errBuf.Bytes())
   107  	}
   108  	if s.GetExitCode() != 0 {
   109  		t.Fail()
   110  	}
   111  
   112  	expected, err := os.ReadFile(out)
   113  	if err != nil {
   114  		t.Error(err)
   115  		return
   116  	}
   117  
   118  	if got := buf.Bytes(); !bytes.Equal(got, expected) {
   119  		if *update {
   120  			if in != out {
   121  				if err := os.WriteFile(out, got, 0666); err != nil {
   122  					t.Error(err)
   123  				}
   124  				return
   125  			}
   126  			// in == out: don't accidentally destroy input
   127  			t.Errorf("WARNING: -update did not rewrite input file %s", in)
   128  		}
   129  
   130  		t.Errorf("(gofmt %s) != %s (see %s.gofmt)\n%s", in, out, in,
   131  			diff.Diff("expected", expected, "got", got))
   132  		if err := os.WriteFile(in+".gofmt", got, 0666); err != nil {
   133  			t.Error(err)
   134  		}
   135  	}
   136  }
   137  
   138  // TestRewrite processes testdata/*.input files and compares them to the
   139  // corresponding testdata/*.golden files. The gofmt flags used to process
   140  // a file must be provided via a comment of the form
   141  //
   142  //	//gofmt flags
   143  //
   144  // in the processed file within the first 20 lines, if any.
   145  func TestRewrite(t *testing.T) {
   146  	// determine input files
   147  	match, err := filepath.Glob("testdata/*.input")
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  
   152  	// add larger examples
   153  	match = append(match, "gofmt.go", "gofmt_test.go")
   154  
   155  	for _, in := range match {
   156  		name := filepath.Base(in)
   157  		t.Run(name, func(t *testing.T) {
   158  			out := in // for files where input and output are identical
   159  			if strings.HasSuffix(in, ".input") {
   160  				out = in[:len(in)-len(".input")] + ".golden"
   161  			}
   162  			runTest(t, in, out)
   163  			if in != out && !t.Failed() {
   164  				// Check idempotence.
   165  				runTest(t, out, out)
   166  			}
   167  		})
   168  	}
   169  }
   170  
   171  // TestDiff runs gofmt with the -d flag on the input files and checks that the
   172  // expected exit code is set.
   173  func TestDiff(t *testing.T) {
   174  	tests := []struct {
   175  		in       string
   176  		exitCode int
   177  	}{
   178  		{in: "testdata/exitcode.input", exitCode: 1},
   179  		{in: "testdata/exitcode.golden", exitCode: 0},
   180  	}
   181  
   182  	for _, tt := range tests {
   183  		resetFlags()
   184  		*doDiff = true
   185  
   186  		initParserMode()
   187  		initRewrite()
   188  
   189  		info, err := os.Lstat(tt.in)
   190  		if err != nil {
   191  			t.Error(err)
   192  			return
   193  		}
   194  
   195  		const maxWeight = 2 << 20
   196  		var buf, errBuf bytes.Buffer
   197  		s := newSequencer(maxWeight, &buf, &errBuf)
   198  		s.Add(fileWeight(tt.in, info), func(r *reporter) error {
   199  			return processFile(tt.in, info, nil, r)
   200  		})
   201  		if errBuf.Len() > 0 {
   202  			t.Logf("%q", errBuf.Bytes())
   203  		}
   204  
   205  		if s.GetExitCode() != tt.exitCode {
   206  			t.Errorf("%s: expected exit code %d, got %d", tt.in, tt.exitCode, s.GetExitCode())
   207  		}
   208  	}
   209  }
   210  
   211  // Test case for issue 3961.
   212  func TestCRLF(t *testing.T) {
   213  	const input = "testdata/crlf.input"   // must contain CR/LF's
   214  	const golden = "testdata/crlf.golden" // must not contain any CR's
   215  
   216  	data, err := os.ReadFile(input)
   217  	if err != nil {
   218  		t.Error(err)
   219  	}
   220  	if !bytes.Contains(data, []byte("\r\n")) {
   221  		t.Errorf("%s contains no CR/LF's", input)
   222  	}
   223  
   224  	data, err = os.ReadFile(golden)
   225  	if err != nil {
   226  		t.Error(err)
   227  	}
   228  	if bytes.Contains(data, []byte("\r")) {
   229  		t.Errorf("%s contains CR's", golden)
   230  	}
   231  }
   232  
   233  func TestBackupFile(t *testing.T) {
   234  	dir, err := os.MkdirTemp("", "gofmt_test")
   235  	if err != nil {
   236  		t.Fatal(err)
   237  	}
   238  	defer os.RemoveAll(dir)
   239  	name, err := backupFile(filepath.Join(dir, "foo.go"), []byte("  package main"), 0644)
   240  	if err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	t.Logf("Created: %s", name)
   244  }
   245  

View as plain text