1  
     2  
     3  
     4  
     5  
     6  
     7  package types2
     8  
     9  import (
    10  	"cmd/compile/internal/syntax"
    11  )
    12  
    13  type substMap map[*TypeParam]Type
    14  
    15  
    16  
    17  func makeSubstMap(tpars []*TypeParam, targs []Type) substMap {
    18  	assert(len(tpars) == len(targs))
    19  	proj := make(substMap, len(tpars))
    20  	for i, tpar := range tpars {
    21  		proj[tpar] = targs[i]
    22  	}
    23  	return proj
    24  }
    25  
    26  
    27  
    28  func makeRenameMap(from, to []*TypeParam) substMap {
    29  	assert(len(from) == len(to))
    30  	proj := make(substMap, len(from))
    31  	for i, tpar := range from {
    32  		proj[tpar] = to[i]
    33  	}
    34  	return proj
    35  }
    36  
    37  func (m substMap) empty() bool {
    38  	return len(m) == 0
    39  }
    40  
    41  func (m substMap) lookup(tpar *TypeParam) Type {
    42  	if t := m[tpar]; t != nil {
    43  		return t
    44  	}
    45  	return tpar
    46  }
    47  
    48  
    49  
    50  
    51  
    52  
    53  
    54  
    55  func (check *Checker) subst(pos syntax.Pos, typ Type, smap substMap, expanding *Named, ctxt *Context) Type {
    56  	assert(expanding != nil || ctxt != nil)
    57  
    58  	if smap.empty() {
    59  		return typ
    60  	}
    61  
    62  	
    63  	switch t := typ.(type) {
    64  	case *Basic:
    65  		return typ 
    66  	case *TypeParam:
    67  		return smap.lookup(t)
    68  	}
    69  
    70  	
    71  	subst := subster{
    72  		pos:       pos,
    73  		smap:      smap,
    74  		check:     check,
    75  		expanding: expanding,
    76  		ctxt:      ctxt,
    77  	}
    78  	return subst.typ(typ)
    79  }
    80  
    81  type subster struct {
    82  	pos       syntax.Pos
    83  	smap      substMap
    84  	check     *Checker 
    85  	expanding *Named   
    86  	ctxt      *Context
    87  }
    88  
    89  func (subst *subster) typ(typ Type) Type {
    90  	switch t := typ.(type) {
    91  	case nil:
    92  		
    93  		panic("nil typ")
    94  
    95  	case *Basic:
    96  		
    97  
    98  	case *Alias:
    99  		
   100  		
   101  		orig := t.Origin()
   102  		n := orig.TypeParams().Len()
   103  		if n == 0 {
   104  			return t 
   105  		}
   106  
   107  		
   108  		if t.TypeArgs().Len() != n {
   109  			return Typ[Invalid] 
   110  		}
   111  
   112  		
   113  		
   114  		
   115  		
   116  		if targs := substList(t.TypeArgs().list(), subst.typ); targs != nil {
   117  			return subst.check.newAliasInstance(subst.pos, t.orig, targs, subst.expanding, subst.ctxt)
   118  		}
   119  
   120  	case *Array:
   121  		elem := subst.typOrNil(t.elem)
   122  		if elem != t.elem {
   123  			return &Array{len: t.len, elem: elem}
   124  		}
   125  
   126  	case *Slice:
   127  		elem := subst.typOrNil(t.elem)
   128  		if elem != t.elem {
   129  			return &Slice{elem: elem}
   130  		}
   131  
   132  	case *Struct:
   133  		if fields := substList(t.fields, subst.var_); fields != nil {
   134  			s := &Struct{fields: fields, tags: t.tags}
   135  			s.markComplete()
   136  			return s
   137  		}
   138  
   139  	case *Pointer:
   140  		base := subst.typ(t.base)
   141  		if base != t.base {
   142  			return &Pointer{base: base}
   143  		}
   144  
   145  	case *Tuple:
   146  		return subst.tuple(t)
   147  
   148  	case *Signature:
   149  		
   150  		
   151  		
   152  		
   153  		
   154  		
   155  		
   156  		
   157  		
   158  		
   159  		
   160  		
   161  		
   162  		recv := t.recv
   163  
   164  		params := subst.tuple(t.params)
   165  		results := subst.tuple(t.results)
   166  		if params != t.params || results != t.results {
   167  			return &Signature{
   168  				rparams: t.rparams,
   169  				
   170  				tparams: t.tparams,
   171  				
   172  				recv:     recv,
   173  				params:   params,
   174  				results:  results,
   175  				variadic: t.variadic,
   176  			}
   177  		}
   178  
   179  	case *Union:
   180  		if terms := substList(t.terms, subst.term); terms != nil {
   181  			
   182  			
   183  			
   184  			return &Union{terms}
   185  		}
   186  
   187  	case *Interface:
   188  		methods := substList(t.methods, subst.func_)
   189  		embeddeds := substList(t.embeddeds, subst.typ)
   190  		if methods != nil || embeddeds != nil {
   191  			if methods == nil {
   192  				methods = t.methods
   193  			}
   194  			if embeddeds == nil {
   195  				embeddeds = t.embeddeds
   196  			}
   197  			iface := subst.check.newInterface()
   198  			iface.embeddeds = embeddeds
   199  			iface.embedPos = t.embedPos
   200  			iface.implicit = t.implicit
   201  			assert(t.complete) 
   202  			iface.complete = t.complete
   203  			
   204  			
   205  			
   206  			
   207  			
   208  			
   209  			
   210  			
   211  			
   212  			
   213  			
   214  			
   215  			
   216  			iface.methods, _ = replaceRecvType(methods, t, iface)
   217  
   218  			
   219  			if subst.check == nil { 
   220  				iface.typeSet()
   221  			}
   222  			return iface
   223  		}
   224  
   225  	case *Map:
   226  		key := subst.typ(t.key)
   227  		elem := subst.typ(t.elem)
   228  		if key != t.key || elem != t.elem {
   229  			return &Map{key: key, elem: elem}
   230  		}
   231  
   232  	case *Chan:
   233  		elem := subst.typ(t.elem)
   234  		if elem != t.elem {
   235  			return &Chan{dir: t.dir, elem: elem}
   236  		}
   237  
   238  	case *Named:
   239  		
   240  		
   241  		
   242  		
   243  		
   244  		orig := t.Origin()
   245  		n := orig.TypeParams().Len()
   246  		if n == 0 {
   247  			return t 
   248  		}
   249  
   250  		if t.TypeArgs().Len() != n {
   251  			return Typ[Invalid] 
   252  		}
   253  
   254  		
   255  		
   256  		
   257  		
   258  		if targs := substList(t.TypeArgs().list(), subst.typ); targs != nil {
   259  			
   260  			
   261  			
   262  			
   263  			return subst.check.instance(subst.pos, orig, targs, subst.expanding, subst.ctxt)
   264  		}
   265  
   266  	case *TypeParam:
   267  		return subst.smap.lookup(t)
   268  
   269  	default:
   270  		panic("unreachable")
   271  	}
   272  
   273  	return typ
   274  }
   275  
   276  
   277  
   278  
   279  func (subst *subster) typOrNil(typ Type) Type {
   280  	if typ == nil {
   281  		return Typ[Invalid]
   282  	}
   283  	return subst.typ(typ)
   284  }
   285  
   286  func (subst *subster) var_(v *Var) *Var {
   287  	if v != nil {
   288  		if typ := subst.typ(v.typ); typ != v.typ {
   289  			return cloneVar(v, typ)
   290  		}
   291  	}
   292  	return v
   293  }
   294  
   295  func cloneVar(v *Var, typ Type) *Var {
   296  	copy := *v
   297  	copy.typ = typ
   298  	copy.origin = v.Origin()
   299  	return ©
   300  }
   301  
   302  func (subst *subster) tuple(t *Tuple) *Tuple {
   303  	if t != nil {
   304  		if vars := substList(t.vars, subst.var_); vars != nil {
   305  			return &Tuple{vars: vars}
   306  		}
   307  	}
   308  	return t
   309  }
   310  
   311  
   312  
   313  
   314  
   315  func substList[T comparable](in []T, subst func(T) T) (out []T) {
   316  	for i, t := range in {
   317  		if u := subst(t); u != t {
   318  			if out == nil {
   319  				
   320  				out = make([]T, len(in))
   321  				copy(out, in)
   322  			}
   323  			out[i] = u
   324  		}
   325  	}
   326  	return
   327  }
   328  
   329  func (subst *subster) func_(f *Func) *Func {
   330  	if f != nil {
   331  		if typ := subst.typ(f.typ); typ != f.typ {
   332  			return cloneFunc(f, typ)
   333  		}
   334  	}
   335  	return f
   336  }
   337  
   338  func cloneFunc(f *Func, typ Type) *Func {
   339  	copy := *f
   340  	copy.typ = typ
   341  	copy.origin = f.Origin()
   342  	return ©
   343  }
   344  
   345  func (subst *subster) term(t *Term) *Term {
   346  	if typ := subst.typ(t.typ); typ != t.typ {
   347  		return NewTerm(t.tilde, typ)
   348  	}
   349  	return t
   350  }
   351  
   352  
   353  
   354  
   355  
   356  
   357  
   358  func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
   359  	out = in
   360  	for i, method := range in {
   361  		sig := method.Signature()
   362  		if sig.recv != nil && sig.recv.Type() == old {
   363  			if !copied {
   364  				
   365  				
   366  				
   367  				out = make([]*Func, len(in))
   368  				copy(out, in)
   369  				copied = true
   370  			}
   371  			newsig := *sig
   372  			newsig.recv = cloneVar(sig.recv, new)
   373  			out[i] = cloneFunc(method, &newsig)
   374  		}
   375  	}
   376  	return
   377  }
   378  
View as plain text