Source file src/go/types/struct.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package types
     6  
     7  import (
     8  	"go/ast"
     9  	"go/token"
    10  	. "internal/types/errors"
    11  	"strconv"
    12  )
    13  
    14  // ----------------------------------------------------------------------------
    15  // API
    16  
    17  // A Struct represents a struct type.
    18  type Struct struct {
    19  	fields []*Var   // fields != nil indicates the struct is set up (possibly with len(fields) == 0)
    20  	tags   []string // field tags; nil if there are no tags
    21  }
    22  
    23  // NewStruct returns a new struct with the given fields and corresponding field tags.
    24  // If a field with index i has a tag, tags[i] must be that tag, but len(tags) may be
    25  // only as long as required to hold the tag with the largest index i. Consequently,
    26  // if no field has a tag, tags may be nil.
    27  func NewStruct(fields []*Var, tags []string) *Struct {
    28  	var fset objset
    29  	for _, f := range fields {
    30  		if f.name != "_" && fset.insert(f) != nil {
    31  			panic("multiple fields with the same name")
    32  		}
    33  	}
    34  	if len(tags) > len(fields) {
    35  		panic("more tags than fields")
    36  	}
    37  	s := &Struct{fields: fields, tags: tags}
    38  	s.markComplete()
    39  	return s
    40  }
    41  
    42  // NumFields returns the number of fields in the struct (including blank and embedded fields).
    43  func (s *Struct) NumFields() int { return len(s.fields) }
    44  
    45  // Field returns the i'th field for 0 <= i < NumFields().
    46  func (s *Struct) Field(i int) *Var { return s.fields[i] }
    47  
    48  // Tag returns the i'th field tag for 0 <= i < NumFields().
    49  func (s *Struct) Tag(i int) string {
    50  	if i < len(s.tags) {
    51  		return s.tags[i]
    52  	}
    53  	return ""
    54  }
    55  
    56  func (t *Struct) Underlying() Type { return t }
    57  func (t *Struct) String() string   { return TypeString(t, nil) }
    58  
    59  // ----------------------------------------------------------------------------
    60  // Implementation
    61  
    62  func (s *Struct) markComplete() {
    63  	if s.fields == nil {
    64  		s.fields = make([]*Var, 0)
    65  	}
    66  }
    67  
    68  func (check *Checker) structType(styp *Struct, e *ast.StructType) {
    69  	list := e.Fields
    70  	if list == nil {
    71  		styp.markComplete()
    72  		return
    73  	}
    74  
    75  	// struct fields and tags
    76  	var fields []*Var
    77  	var tags []string
    78  
    79  	// for double-declaration checks
    80  	var fset objset
    81  
    82  	// current field typ and tag
    83  	var typ Type
    84  	var tag string
    85  	add := func(ident *ast.Ident, embedded bool, pos token.Pos) {
    86  		if tag != "" && tags == nil {
    87  			tags = make([]string, len(fields))
    88  		}
    89  		if tags != nil {
    90  			tags = append(tags, tag)
    91  		}
    92  
    93  		name := ident.Name
    94  		fld := NewField(pos, check.pkg, name, typ, embedded)
    95  		// spec: "Within a struct, non-blank field names must be unique."
    96  		if name == "_" || check.declareInSet(&fset, pos, fld) {
    97  			fields = append(fields, fld)
    98  			check.recordDef(ident, fld)
    99  		}
   100  	}
   101  
   102  	// addInvalid adds an embedded field of invalid type to the struct for
   103  	// fields with errors; this keeps the number of struct fields in sync
   104  	// with the source as long as the fields are _ or have different names
   105  	// (go.dev/issue/25627).
   106  	addInvalid := func(ident *ast.Ident, pos token.Pos) {
   107  		typ = Typ[Invalid]
   108  		tag = ""
   109  		add(ident, true, pos)
   110  	}
   111  
   112  	for _, f := range list.List {
   113  		typ = check.varType(f.Type)
   114  		tag = check.tag(f.Tag)
   115  		if len(f.Names) > 0 {
   116  			// named fields
   117  			for _, name := range f.Names {
   118  				add(name, false, name.Pos())
   119  			}
   120  		} else {
   121  			// embedded field
   122  			// spec: "An embedded type must be specified as a type name T or as a
   123  			// pointer to a non-interface type name *T, and T itself may not be a
   124  			// pointer type."
   125  			pos := f.Type.Pos() // position of type, for errors
   126  			name := embeddedFieldIdent(f.Type)
   127  			if name == nil {
   128  				check.errorf(f.Type, InvalidSyntaxTree, "embedded field type %s has no name", f.Type)
   129  				name = ast.NewIdent("_")
   130  				name.NamePos = pos
   131  				addInvalid(name, pos)
   132  				continue
   133  			}
   134  			add(name, true, name.Pos()) // struct{p.T} field has position of T
   135  
   136  			// Because we have a name, typ must be of the form T or *T, where T is the name
   137  			// of a (named or alias) type, and t (= deref(typ)) must be the type of T.
   138  			// We must delay this check to the end because we don't want to instantiate
   139  			// (via under(t)) a possibly incomplete type.
   140  
   141  			// for use in the closure below
   142  			embeddedTyp := typ
   143  			embeddedPos := f.Type
   144  
   145  			check.later(func() {
   146  				t, isPtr := deref(embeddedTyp)
   147  				switch u := under(t).(type) {
   148  				case *Basic:
   149  					if !isValid(t) {
   150  						// error was reported before
   151  						return
   152  					}
   153  					// unsafe.Pointer is treated like a regular pointer
   154  					if u.kind == UnsafePointer {
   155  						check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be unsafe.Pointer")
   156  					}
   157  				case *Pointer:
   158  					check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer")
   159  				case *Interface:
   160  					if isTypeParam(t) {
   161  						// The error code here is inconsistent with other error codes for
   162  						// invalid embedding, because this restriction may be relaxed in the
   163  						// future, and so it did not warrant a new error code.
   164  						check.error(embeddedPos, MisplacedTypeParam, "embedded field type cannot be a (pointer to a) type parameter")
   165  						break
   166  					}
   167  					if isPtr {
   168  						check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer to an interface")
   169  					}
   170  				}
   171  			}).describef(embeddedPos, "check embedded type %s", embeddedTyp)
   172  		}
   173  	}
   174  
   175  	styp.fields = fields
   176  	styp.tags = tags
   177  	styp.markComplete()
   178  }
   179  
   180  func embeddedFieldIdent(e ast.Expr) *ast.Ident {
   181  	switch e := e.(type) {
   182  	case *ast.Ident:
   183  		return e
   184  	case *ast.StarExpr:
   185  		// *T is valid, but **T is not
   186  		if _, ok := e.X.(*ast.StarExpr); !ok {
   187  			return embeddedFieldIdent(e.X)
   188  		}
   189  	case *ast.SelectorExpr:
   190  		return e.Sel
   191  	case *ast.IndexExpr:
   192  		return embeddedFieldIdent(e.X)
   193  	case *ast.IndexListExpr:
   194  		return embeddedFieldIdent(e.X)
   195  	}
   196  	return nil // invalid embedded field
   197  }
   198  
   199  func (check *Checker) declareInSet(oset *objset, pos token.Pos, obj Object) bool {
   200  	if alt := oset.insert(obj); alt != nil {
   201  		check.errorf(atPos(pos), DuplicateDecl, "%s redeclared", obj.Name())
   202  		check.reportAltDecl(alt)
   203  		return false
   204  	}
   205  	return true
   206  }
   207  
   208  func (check *Checker) tag(t *ast.BasicLit) string {
   209  	if t != nil {
   210  		if t.Kind == token.STRING {
   211  			if val, err := strconv.Unquote(t.Value); err == nil {
   212  				return val
   213  			}
   214  		}
   215  		check.errorf(t, InvalidSyntaxTree, "incorrect tag syntax: %q", t.Value)
   216  	}
   217  	return ""
   218  }
   219  

View as plain text