Black Lives Matter. Support the Equal Justice Initiative.

Source file src/go/types/infer.go

Documentation: go/types

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file implements type parameter inference given
     6  // a list of concrete arguments and a parameter list.
     7  
     8  package types
     9  
    10  import (
    11  	"go/token"
    12  	"strings"
    13  )
    14  
    15  // infer attempts to infer the complete set of type arguments for generic function instantiation/call
    16  // based on the given type parameters tparams, type arguments targs, function parameters params, and
    17  // function arguments args, if any. There must be at least one type parameter, no more type arguments
    18  // than type parameters, and params and args must match in number (incl. zero).
    19  // If successful, infer returns the complete list of type arguments, one for each type parameter.
    20  // Otherwise the result is nil and appropriate errors will be reported unless report is set to false.
    21  //
    22  // Inference proceeds in 3 steps:
    23  //
    24  //   1) Start with given type arguments.
    25  //   2) Infer type arguments from typed function arguments.
    26  //   3) Infer type arguments from untyped function arguments.
    27  //
    28  // Constraint type inference is used after each step to expand the set of type arguments.
    29  //
    30  func (check *Checker) infer(posn positioner, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
    31  	if debug {
    32  		defer func() {
    33  			assert(result == nil || len(result) == len(tparams))
    34  			for _, targ := range result {
    35  				assert(targ != nil)
    36  			}
    37  			//check.dump("### inferred targs = %s", result)
    38  		}()
    39  	}
    40  
    41  	// There must be at least one type parameter, and no more type arguments than type parameters.
    42  	n := len(tparams)
    43  	assert(n > 0 && len(targs) <= n)
    44  
    45  	// Function parameters and arguments must match in number.
    46  	assert(params.Len() == len(args))
    47  
    48  	// --- 0 ---
    49  	// If we already have all type arguments, we're done.
    50  	if len(targs) == n {
    51  		return targs
    52  	}
    53  	// len(targs) < n
    54  
    55  	// --- 1 ---
    56  	// Explicitly provided type arguments take precedence over any inferred types;
    57  	// and types inferred via constraint type inference take precedence over types
    58  	// inferred from function arguments.
    59  	// If we have type arguments, see how far we get with constraint type inference.
    60  	if len(targs) > 0 {
    61  		var index int
    62  		targs, index = check.inferB(tparams, targs, report)
    63  		if targs == nil || index < 0 {
    64  			return targs
    65  		}
    66  	}
    67  
    68  	// Continue with the type arguments we have now. Avoid matching generic
    69  	// parameters that already have type arguments against function arguments:
    70  	// It may fail because matching uses type identity while parameter passing
    71  	// uses assignment rules. Instantiate the parameter list with the type
    72  	// arguments we have, and continue with that parameter list.
    73  
    74  	// First, make sure we have a "full" list of type arguments, so of which
    75  	// may be nil (unknown).
    76  	if len(targs) < n {
    77  		targs2 := make([]Type, n)
    78  		copy(targs2, targs)
    79  		targs = targs2
    80  	}
    81  	// len(targs) == n
    82  
    83  	// Substitute type arguments for their respective type parameters in params,
    84  	// if any. Note that nil targs entries are ignored by check.subst.
    85  	// TODO(gri) Can we avoid this (we're setting known type argumemts below,
    86  	//           but that doesn't impact the isParameterized check for now).
    87  	if params.Len() > 0 {
    88  		smap := makeSubstMap(tparams, targs)
    89  		params = check.subst(token.NoPos, params, smap).(*Tuple)
    90  	}
    91  
    92  	// --- 2 ---
    93  	// Unify parameter and argument types for generic parameters with typed arguments
    94  	// and collect the indices of generic parameters with untyped arguments.
    95  	// Terminology: generic parameter = function parameter with a type-parameterized type
    96  	u := newUnifier(check, false)
    97  	u.x.init(tparams)
    98  
    99  	// Set the type arguments which we know already.
   100  	for i, targ := range targs {
   101  		if targ != nil {
   102  			u.x.set(i, targ)
   103  		}
   104  	}
   105  
   106  	errorf := func(kind string, tpar, targ Type, arg *operand) {
   107  		if !report {
   108  			return
   109  		}
   110  		// provide a better error message if we can
   111  		targs, index := u.x.types()
   112  		if index == 0 {
   113  			// The first type parameter couldn't be inferred.
   114  			// If none of them could be inferred, don't try
   115  			// to provide the inferred type in the error msg.
   116  			allFailed := true
   117  			for _, targ := range targs {
   118  				if targ != nil {
   119  					allFailed = false
   120  					break
   121  				}
   122  			}
   123  			if allFailed {
   124  				check.errorf(arg, _Todo, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeNamesString(tparams))
   125  				return
   126  			}
   127  		}
   128  		smap := makeSubstMap(tparams, targs)
   129  		// TODO(rFindley): pass a positioner here, rather than arg.Pos().
   130  		inferred := check.subst(arg.Pos(), tpar, smap)
   131  		if inferred != tpar {
   132  			check.errorf(arg, _Todo, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
   133  		} else {
   134  			check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar)
   135  		}
   136  	}
   137  
   138  	// indices of the generic parameters with untyped arguments - save for later
   139  	var indices []int
   140  	for i, arg := range args {
   141  		par := params.At(i)
   142  		// If we permit bidirectional unification, this conditional code needs to be
   143  		// executed even if par.typ is not parameterized since the argument may be a
   144  		// generic function (for which we want to infer its type arguments).
   145  		if isParameterized(tparams, par.typ) {
   146  			if arg.mode == invalid {
   147  				// An error was reported earlier. Ignore this targ
   148  				// and continue, we may still be able to infer all
   149  				// targs resulting in fewer follon-on errors.
   150  				continue
   151  			}
   152  			if targ := arg.typ; isTyped(targ) {
   153  				// If we permit bidirectional unification, and targ is
   154  				// a generic function, we need to initialize u.y with
   155  				// the respective type parameters of targ.
   156  				if !u.unify(par.typ, targ) {
   157  					errorf("type", par.typ, targ, arg)
   158  					return nil
   159  				}
   160  			} else {
   161  				indices = append(indices, i)
   162  			}
   163  		}
   164  	}
   165  
   166  	// If we've got all type arguments, we're done.
   167  	var index int
   168  	targs, index = u.x.types()
   169  	if index < 0 {
   170  		return targs
   171  	}
   172  
   173  	// See how far we get with constraint type inference.
   174  	// Note that even if we don't have any type arguments, constraint type inference
   175  	// may produce results for constraints that explicitly specify a type.
   176  	targs, index = check.inferB(tparams, targs, report)
   177  	if targs == nil || index < 0 {
   178  		return targs
   179  	}
   180  
   181  	// --- 3 ---
   182  	// Use any untyped arguments to infer additional type arguments.
   183  	// Some generic parameters with untyped arguments may have been given
   184  	// a type by now, we can ignore them.
   185  	for _, i := range indices {
   186  		par := params.At(i)
   187  		// Since untyped types are all basic (i.e., non-composite) types, an
   188  		// untyped argument will never match a composite parameter type; the
   189  		// only parameter type it can possibly match against is a *TypeParam.
   190  		// Thus, only consider untyped arguments for generic parameters that
   191  		// are not of composite types and which don't have a type inferred yet.
   192  		if tpar, _ := par.typ.(*_TypeParam); tpar != nil && targs[tpar.index] == nil {
   193  			arg := args[i]
   194  			targ := Default(arg.typ)
   195  			// The default type for an untyped nil is untyped nil. We must not
   196  			// infer an untyped nil type as type parameter type. Ignore untyped
   197  			// nil by making sure all default argument types are typed.
   198  			if isTyped(targ) && !u.unify(par.typ, targ) {
   199  				errorf("default type", par.typ, targ, arg)
   200  				return nil
   201  			}
   202  		}
   203  	}
   204  
   205  	// If we've got all type arguments, we're done.
   206  	targs, index = u.x.types()
   207  	if index < 0 {
   208  		return targs
   209  	}
   210  
   211  	// Again, follow up with constraint type inference.
   212  	targs, index = check.inferB(tparams, targs, report)
   213  	if targs == nil || index < 0 {
   214  		return targs
   215  	}
   216  
   217  	// At least one type argument couldn't be inferred.
   218  	assert(index >= 0 && targs[index] == nil)
   219  	tpar := tparams[index]
   220  	if report {
   221  		check.errorf(posn, _Todo, "cannot infer %s (%v) (%v)", tpar.name, tpar.pos, targs)
   222  	}
   223  	return nil
   224  }
   225  
   226  // typeNamesString produces a string containing all the
   227  // type names in list suitable for human consumption.
   228  func typeNamesString(list []*TypeName) string {
   229  	// common cases
   230  	n := len(list)
   231  	switch n {
   232  	case 0:
   233  		return ""
   234  	case 1:
   235  		return list[0].name
   236  	case 2:
   237  		return list[0].name + " and " + list[1].name
   238  	}
   239  
   240  	// general case (n > 2)
   241  	var b strings.Builder
   242  	for i, tname := range list[:n-1] {
   243  		if i > 0 {
   244  			b.WriteString(", ")
   245  		}
   246  		b.WriteString(tname.name)
   247  	}
   248  	b.WriteString(", and ")
   249  	b.WriteString(list[n-1].name)
   250  	return b.String()
   251  }
   252  
   253  // IsParameterized reports whether typ contains any of the type parameters of tparams.
   254  func isParameterized(tparams []*TypeName, typ Type) bool {
   255  	w := tpWalker{
   256  		seen:    make(map[Type]bool),
   257  		tparams: tparams,
   258  	}
   259  	return w.isParameterized(typ)
   260  }
   261  
   262  type tpWalker struct {
   263  	seen    map[Type]bool
   264  	tparams []*TypeName
   265  }
   266  
   267  func (w *tpWalker) isParameterized(typ Type) (res bool) {
   268  	// detect cycles
   269  	if x, ok := w.seen[typ]; ok {
   270  		return x
   271  	}
   272  	w.seen[typ] = false
   273  	defer func() {
   274  		w.seen[typ] = res
   275  	}()
   276  
   277  	switch t := typ.(type) {
   278  	case nil, *Basic: // TODO(gri) should nil be handled here?
   279  		break
   280  
   281  	case *Array:
   282  		return w.isParameterized(t.elem)
   283  
   284  	case *Slice:
   285  		return w.isParameterized(t.elem)
   286  
   287  	case *Struct:
   288  		for _, fld := range t.fields {
   289  			if w.isParameterized(fld.typ) {
   290  				return true
   291  			}
   292  		}
   293  
   294  	case *Pointer:
   295  		return w.isParameterized(t.base)
   296  
   297  	case *Tuple:
   298  		n := t.Len()
   299  		for i := 0; i < n; i++ {
   300  			if w.isParameterized(t.At(i).typ) {
   301  				return true
   302  			}
   303  		}
   304  
   305  	case *_Sum:
   306  		return w.isParameterizedList(t.types)
   307  
   308  	case *Signature:
   309  		// t.tparams may not be nil if we are looking at a signature
   310  		// of a generic function type (or an interface method) that is
   311  		// part of the type we're testing. We don't care about these type
   312  		// parameters.
   313  		// Similarly, the receiver of a method may declare (rather then
   314  		// use) type parameters, we don't care about those either.
   315  		// Thus, we only need to look at the input and result parameters.
   316  		return w.isParameterized(t.params) || w.isParameterized(t.results)
   317  
   318  	case *Interface:
   319  		if t.allMethods != nil {
   320  			// TODO(rFindley) at some point we should enforce completeness here
   321  			for _, m := range t.allMethods {
   322  				if w.isParameterized(m.typ) {
   323  					return true
   324  				}
   325  			}
   326  			return w.isParameterizedList(unpackType(t.allTypes))
   327  		}
   328  
   329  		return t.iterate(func(t *Interface) bool {
   330  			for _, m := range t.methods {
   331  				if w.isParameterized(m.typ) {
   332  					return true
   333  				}
   334  			}
   335  			return w.isParameterizedList(unpackType(t.types))
   336  		}, nil)
   337  
   338  	case *Map:
   339  		return w.isParameterized(t.key) || w.isParameterized(t.elem)
   340  
   341  	case *Chan:
   342  		return w.isParameterized(t.elem)
   343  
   344  	case *Named:
   345  		return w.isParameterizedList(t.targs)
   346  
   347  	case *_TypeParam:
   348  		// t must be one of w.tparams
   349  		return t.index < len(w.tparams) && w.tparams[t.index].typ == t
   350  
   351  	case *instance:
   352  		return w.isParameterizedList(t.targs)
   353  
   354  	default:
   355  		unreachable()
   356  	}
   357  
   358  	return false
   359  }
   360  
   361  func (w *tpWalker) isParameterizedList(list []Type) bool {
   362  	for _, t := range list {
   363  		if w.isParameterized(t) {
   364  			return true
   365  		}
   366  	}
   367  	return false
   368  }
   369  
   370  // inferB returns the list of actual type arguments inferred from the type parameters'
   371  // bounds and an initial set of type arguments. If type inference is impossible because
   372  // unification fails, an error is reported if report is set to true, the resulting types
   373  // list is nil, and index is 0.
   374  // Otherwise, types is the list of inferred type arguments, and index is the index of the
   375  // first type argument in that list that couldn't be inferred (and thus is nil). If all
   376  // type arguments were inferred successfully, index is < 0. The number of type arguments
   377  // provided may be less than the number of type parameters, but there must be at least one.
   378  func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
   379  	assert(len(tparams) >= len(targs) && len(targs) > 0)
   380  
   381  	// Setup bidirectional unification between those structural bounds
   382  	// and the corresponding type arguments (which may be nil!).
   383  	u := newUnifier(check, false)
   384  	u.x.init(tparams)
   385  	u.y = u.x // type parameters between LHS and RHS of unification are identical
   386  
   387  	// Set the type arguments which we know already.
   388  	for i, targ := range targs {
   389  		if targ != nil {
   390  			u.x.set(i, targ)
   391  		}
   392  	}
   393  
   394  	// Unify type parameters with their structural constraints, if any.
   395  	for _, tpar := range tparams {
   396  		typ := tpar.typ.(*_TypeParam)
   397  		sbound := check.structuralType(typ.bound)
   398  		if sbound != nil {
   399  			if !u.unify(typ, sbound) {
   400  				if report {
   401  					check.errorf(tpar, _Todo, "%s does not match %s", tpar, sbound)
   402  				}
   403  				return nil, 0
   404  			}
   405  		}
   406  	}
   407  
   408  	// u.x.types() now contains the incoming type arguments plus any additional type
   409  	// arguments for which there were structural constraints. The newly inferred non-
   410  	// nil entries may still contain references to other type parameters. For instance,
   411  	// for [A any, B interface{type []C}, C interface{type *A}], if A == int
   412  	// was given, unification produced the type list [int, []C, *A]. We eliminate the
   413  	// remaining type parameters by substituting the type parameters in this type list
   414  	// until nothing changes anymore.
   415  	types, _ = u.x.types()
   416  	if debug {
   417  		for i, targ := range targs {
   418  			assert(targ == nil || types[i] == targ)
   419  		}
   420  	}
   421  
   422  	// dirty tracks the indices of all types that may still contain type parameters.
   423  	// We know that nil type entries and entries corresponding to provided (non-nil)
   424  	// type arguments are clean, so exclude them from the start.
   425  	var dirty []int
   426  	for i, typ := range types {
   427  		if typ != nil && (i >= len(targs) || targs[i] == nil) {
   428  			dirty = append(dirty, i)
   429  		}
   430  	}
   431  
   432  	for len(dirty) > 0 {
   433  		// TODO(gri) Instead of creating a new substMap for each iteration,
   434  		// provide an update operation for substMaps and only change when
   435  		// needed. Optimization.
   436  		smap := makeSubstMap(tparams, types)
   437  		n := 0
   438  		for _, index := range dirty {
   439  			t0 := types[index]
   440  			if t1 := check.subst(token.NoPos, t0, smap); t1 != t0 {
   441  				types[index] = t1
   442  				dirty[n] = index
   443  				n++
   444  			}
   445  		}
   446  		dirty = dirty[:n]
   447  	}
   448  
   449  	// Once nothing changes anymore, we may still have type parameters left;
   450  	// e.g., a structural constraint *P may match a type parameter Q but we
   451  	// don't have any type arguments to fill in for *P or Q (issue #45548).
   452  	// Don't let such inferences escape, instead nil them out.
   453  	for i, typ := range types {
   454  		if typ != nil && isParameterized(tparams, typ) {
   455  			types[i] = nil
   456  		}
   457  	}
   458  
   459  	// update index
   460  	index = -1
   461  	for i, typ := range types {
   462  		if typ == nil {
   463  			index = i
   464  			break
   465  		}
   466  	}
   467  
   468  	return
   469  }
   470  
   471  // structuralType returns the structural type of a constraint, if any.
   472  func (check *Checker) structuralType(constraint Type) Type {
   473  	if iface, _ := under(constraint).(*Interface); iface != nil {
   474  		check.completeInterface(token.NoPos, iface)
   475  		types := unpackType(iface.allTypes)
   476  		if len(types) == 1 {
   477  			return types[0]
   478  		}
   479  		return nil
   480  	}
   481  	return constraint
   482  }
   483  

View as plain text