Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/slicescontains.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/inspect"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/analysis/analyzerutil"
    18  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    19  	"golang.org/x/tools/internal/astutil"
    20  	"golang.org/x/tools/internal/refactor"
    21  	"golang.org/x/tools/internal/typeparams"
    22  	"golang.org/x/tools/internal/typesinternal"
    23  	"golang.org/x/tools/internal/typesinternal/typeindex"
    24  	"golang.org/x/tools/internal/versions"
    25  )
    26  
    27  var SlicesContainsAnalyzer = &analysis.Analyzer{
    28  	Name: "slicescontains",
    29  	Doc:  analyzerutil.MustExtractDoc(doc, "slicescontains"),
    30  	Requires: []*analysis.Analyzer{
    31  		inspect.Analyzer,
    32  		typeindexanalyzer.Analyzer,
    33  	},
    34  	Run: slicescontains,
    35  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#slicescontains",
    36  }
    37  
    38  // The slicescontains pass identifies loops that can be replaced by a
    39  // call to slices.Contains{,Func}. For example:
    40  //
    41  //	for i, elem := range s {
    42  //		if elem == needle {
    43  //			...
    44  //			break
    45  //		}
    46  //	}
    47  //
    48  // =>
    49  //
    50  //	if slices.Contains(s, needle) { ... }
    51  //
    52  // Variants:
    53  //   - if the if-condition is f(elem), the replacement
    54  //     uses slices.ContainsFunc(s, f).
    55  //   - if the if-body is "return true" and the fallthrough
    56  //     statement is "return false" (or vice versa), the
    57  //     loop becomes "return [!]slices.Contains(...)".
    58  //   - if the if-body is "found = true" and the previous
    59  //     statement is "found = false" (or vice versa), the
    60  //     loop becomes "found = [!]slices.Contains(...)".
    61  //
    62  // It rejects candidates whose needle/predicate expression from the if-statement
    63  // has side effects to avoid changes in program behavior.
    64  func slicescontains(pass *analysis.Pass) (any, error) {
    65  	// Skip the analyzer in packages where its
    66  	// fixes would create an import cycle.
    67  	if within(pass, "slices", "runtime") {
    68  		return nil, nil
    69  	}
    70  
    71  	var (
    72  		index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
    73  		info  = pass.TypesInfo
    74  	)
    75  
    76  	// check is called for each RangeStmt of this form:
    77  	//   for i, elem := range s { if cond { ... } }
    78  	check := func(file *ast.File, curRange inspector.Cursor) {
    79  		rng := curRange.Node().(*ast.RangeStmt)
    80  		ifStmt := rng.Body.List[0].(*ast.IfStmt)
    81  
    82  		// isSliceElem reports whether e denotes the
    83  		// current slice element (elem or s[i]).
    84  		isSliceElem := func(e ast.Expr) bool {
    85  			if rng.Value != nil && astutil.EqualSyntax(e, rng.Value) {
    86  				return true // "elem"
    87  			}
    88  			if x, ok := e.(*ast.IndexExpr); ok &&
    89  				astutil.EqualSyntax(x.X, rng.X) &&
    90  				astutil.EqualSyntax(x.Index, rng.Key) {
    91  				return true // "s[i]"
    92  			}
    93  			return false
    94  		}
    95  
    96  		// Examine the condition for one of these forms:
    97  		//
    98  		// - if elem or s[i] == needle  { ... } => Contains
    99  		// - if predicate(s[i] or elem) { ... } => ContainsFunc
   100  		var (
   101  			funcName string   // "Contains" or "ContainsFunc"
   102  			arg2     ast.Expr // second argument to func (needle or predicate)
   103  		)
   104  		switch cond := ifStmt.Cond.(type) {
   105  		case *ast.BinaryExpr:
   106  			if cond.Op == token.EQL {
   107  				var elem ast.Expr
   108  				if isSliceElem(cond.X) {
   109  					funcName = "Contains"
   110  					elem = cond.X
   111  					arg2 = cond.Y // "if elem == needle"
   112  				} else if isSliceElem(cond.Y) {
   113  					funcName = "Contains"
   114  					elem = cond.Y
   115  					arg2 = cond.X // "if needle == elem"
   116  				}
   117  
   118  				// Reject if elem and needle have different types.
   119  				if elem != nil {
   120  					tElem := info.TypeOf(elem)
   121  					tNeedle := info.TypeOf(arg2)
   122  					if !types.Identical(tElem, tNeedle) {
   123  						// Avoid ill-typed slices.Contains([]error, any).
   124  						if !types.AssignableTo(tNeedle, tElem) {
   125  							return
   126  						}
   127  						// TODO(adonovan): relax this check to allow
   128  						//   slices.Contains([]error, error(any)),
   129  						// inserting an explicit widening conversion
   130  						// around the needle.
   131  						return
   132  					}
   133  				}
   134  			}
   135  
   136  		case *ast.CallExpr:
   137  			if len(cond.Args) == 1 &&
   138  				isSliceElem(cond.Args[0]) &&
   139  				typeutil.Callee(info, cond) != nil { // not a conversion
   140  
   141  				// Attempt to get signature
   142  				sig, isSignature := info.TypeOf(cond.Fun).(*types.Signature)
   143  				if isSignature {
   144  					// skip variadic functions
   145  					if sig.Variadic() {
   146  						return
   147  					}
   148  
   149  					// Slice element type must match function parameter type.
   150  					var (
   151  						tElem  = typeparams.CoreType(info.TypeOf(rng.X)).(*types.Slice).Elem()
   152  						tParam = sig.Params().At(0).Type()
   153  					)
   154  					if !types.Identical(tElem, tParam) {
   155  						return
   156  					}
   157  				}
   158  
   159  				funcName = "ContainsFunc"
   160  				arg2 = cond.Fun // "if predicate(elem)"
   161  			}
   162  		}
   163  		if funcName == "" {
   164  			return // not a candidate for Contains{,Func}
   165  		}
   166  
   167  		// body is the "true" body.
   168  		body := ifStmt.Body
   169  		if len(body.List) == 0 {
   170  			// (We could perhaps delete the loop entirely.)
   171  			return
   172  		}
   173  
   174  		// Reject if needle/predicate expression has side effects.
   175  		if !typesinternal.NoEffects(info, arg2) {
   176  			return
   177  		}
   178  
   179  		// Reject if the body, needle or predicate references either range variable.
   180  		usesRangeVar := func(n ast.Node) bool {
   181  			cur, ok := curRange.FindNode(n)
   182  			if !ok {
   183  				panic(fmt.Sprintf("FindNode(%T) failed", n))
   184  			}
   185  			return uses(index, cur, info.Defs[rng.Key.(*ast.Ident)]) ||
   186  				rng.Value != nil && uses(index, cur, info.Defs[rng.Value.(*ast.Ident)])
   187  		}
   188  		if usesRangeVar(body) {
   189  			// Body uses range var "i" or "elem".
   190  			//
   191  			// (The check for "i" could be relaxed when we
   192  			// generalize this to support slices.Index;
   193  			// and the check for "elem" could be relaxed
   194  			// if "elem" can safely be replaced in the
   195  			// body by "needle".)
   196  			return
   197  		}
   198  		if usesRangeVar(arg2) {
   199  			return
   200  		}
   201  
   202  		// Prepare slices.Contains{,Func} call.
   203  		prefix, importEdits := refactor.AddImport(info, file, "slices", "slices", funcName, rng.Pos())
   204  		contains := fmt.Sprintf("%s%s(%s, %s)",
   205  			prefix,
   206  			funcName,
   207  			astutil.Format(pass.Fset, rng.X),
   208  			astutil.Format(pass.Fset, arg2))
   209  
   210  		report := func(edits []analysis.TextEdit) {
   211  			pass.Report(analysis.Diagnostic{
   212  				Pos:     rng.Pos(),
   213  				End:     rng.End(),
   214  				Message: fmt.Sprintf("Loop can be simplified using slices.%s", funcName),
   215  				SuggestedFixes: []analysis.SuggestedFix{{
   216  					Message:   "Replace loop by call to slices." + funcName,
   217  					TextEdits: append(edits, importEdits...),
   218  				}},
   219  			})
   220  		}
   221  
   222  		// Last statement of body must return/break out of the loop.
   223  		//
   224  		// TODO(adonovan): opt:consider avoiding FindNode with new API of form:
   225  		//    curRange.Get(edge.RangeStmt_Body, -1).
   226  		//             Get(edge.BodyStmt_List, 0).
   227  		//             Get(edge.IfStmt_Body)
   228  		curBody, _ := curRange.FindNode(body)
   229  		curLastStmt, _ := curBody.LastChild()
   230  
   231  		// Reject if any statement in the body except the
   232  		// last has a free continuation (continue or break)
   233  		// that might affected by melting down the loop.
   234  		//
   235  		// TODO(adonovan): relax check by analyzing branch target.
   236  		for curBodyStmt := range curBody.Children() {
   237  			if curBodyStmt != curLastStmt {
   238  				for range curBodyStmt.Preorder((*ast.BranchStmt)(nil), (*ast.ReturnStmt)(nil)) {
   239  					return
   240  				}
   241  			}
   242  		}
   243  
   244  		switch lastStmt := curLastStmt.Node().(type) {
   245  		case *ast.ReturnStmt:
   246  			// Have: for ... range seq { if ... { stmts; return x } }
   247  
   248  			// Special case:
   249  			// body={ return true } next="return false"   (or negation)
   250  			// => return [!]slices.Contains(...)
   251  			if curNext, ok := curRange.NextSibling(); ok {
   252  				nextStmt := curNext.Node().(ast.Stmt)
   253  				tval := isReturnTrueOrFalse(info, lastStmt)
   254  				fval := isReturnTrueOrFalse(info, nextStmt)
   255  				if len(body.List) == 1 && tval*fval < 0 {
   256  					//    for ... { if ... { return true/false } }
   257  					// => return [!]slices.Contains(...)
   258  					report([]analysis.TextEdit{
   259  						// Delete the range statement and following space.
   260  						{
   261  							Pos: rng.Pos(),
   262  							End: nextStmt.Pos(),
   263  						},
   264  						// Change return to [!]slices.Contains(...).
   265  						{
   266  							Pos: nextStmt.Pos(),
   267  							End: nextStmt.End(),
   268  							NewText: fmt.Appendf(nil, "return %s%s",
   269  								cond(tval > 0, "", "!"),
   270  								contains),
   271  						},
   272  					})
   273  					return
   274  				}
   275  			}
   276  
   277  			// General case:
   278  			// => if slices.Contains(...) { stmts; return x }
   279  			report([]analysis.TextEdit{
   280  				// Replace "for ... { if ... " with "if slices.Contains(...)".
   281  				{
   282  					Pos:     rng.Pos(),
   283  					End:     ifStmt.Body.Pos(),
   284  					NewText: fmt.Appendf(nil, "if %s ", contains),
   285  				},
   286  				// Delete '}' of range statement and preceding space.
   287  				{
   288  					Pos: ifStmt.Body.End(),
   289  					End: rng.End(),
   290  				},
   291  			})
   292  			return
   293  
   294  		case *ast.BranchStmt:
   295  			if lastStmt.Tok == token.BREAK && lastStmt.Label == nil { // unlabeled break
   296  				// Have: for ... { if ... { stmts; break } }
   297  
   298  				var prevStmt ast.Stmt // previous statement to range (if any)
   299  				if curPrev, ok := curRange.PrevSibling(); ok {
   300  					// If the RangeStmt's previous sibling is a Stmt,
   301  					// the RangeStmt must be among the Body list of
   302  					// a BlockStmt, CauseClause, or CommClause.
   303  					// In all cases, the prevStmt is the immediate
   304  					// predecessor of the RangeStmt during execution.
   305  					//
   306  					// (This is not true for Stmts in general;
   307  					// see [Cursor.Children] and #71074.)
   308  					prevStmt, _ = curPrev.Node().(ast.Stmt)
   309  				}
   310  
   311  				// Special case:
   312  				// prev="lhs = false" body={ lhs = true; break }
   313  				// => lhs = slices.Contains(...) (or its negation)
   314  				if assign, ok := body.List[0].(*ast.AssignStmt); ok &&
   315  					len(body.List) == 2 &&
   316  					assign.Tok == token.ASSIGN &&
   317  					len(assign.Lhs) == 1 &&
   318  					len(assign.Rhs) == 1 {
   319  
   320  					// Have: body={ lhs = rhs; break }
   321  					if prevAssign, ok := prevStmt.(*ast.AssignStmt); ok &&
   322  						len(prevAssign.Lhs) == 1 &&
   323  						len(prevAssign.Rhs) == 1 &&
   324  						astutil.EqualSyntax(prevAssign.Lhs[0], assign.Lhs[0]) &&
   325  						isTrueOrFalse(info, assign.Rhs[0]) ==
   326  							-isTrueOrFalse(info, prevAssign.Rhs[0]) {
   327  
   328  						// Have:
   329  						//    lhs = false
   330  						//    for ... { if ... { lhs = true; break } }
   331  						//  =>
   332  						//    lhs = slices.Contains(...)
   333  						//
   334  						// TODO(adonovan):
   335  						// - support "var lhs bool = false" and variants.
   336  						// - allow the break to be omitted.
   337  						neg := cond(isTrueOrFalse(info, assign.Rhs[0]) < 0, "!", "")
   338  						report([]analysis.TextEdit{
   339  							// Replace "rhs" of previous assignment by [!]slices.Contains(...)
   340  							{
   341  								Pos:     prevAssign.Rhs[0].Pos(),
   342  								End:     prevAssign.Rhs[0].End(),
   343  								NewText: []byte(neg + contains),
   344  							},
   345  							// Delete the loop and preceding space.
   346  							{
   347  								Pos: prevAssign.Rhs[0].End(),
   348  								End: rng.End(),
   349  							},
   350  						})
   351  						return
   352  					}
   353  				}
   354  
   355  				// General case:
   356  				//    for ... { if ...        { stmts; break } }
   357  				// => if slices.Contains(...) { stmts        }
   358  				report([]analysis.TextEdit{
   359  					// Replace "for ... { if ... " with "if slices.Contains(...)".
   360  					{
   361  						Pos:     rng.Pos(),
   362  						End:     ifStmt.Body.Pos(),
   363  						NewText: fmt.Appendf(nil, "if %s ", contains),
   364  					},
   365  					// Delete break statement and preceding space.
   366  					{
   367  						Pos: func() token.Pos {
   368  							if len(body.List) > 1 {
   369  								beforeBreak, _ := curLastStmt.PrevSibling()
   370  								return beforeBreak.Node().End()
   371  							}
   372  							return lastStmt.Pos()
   373  						}(),
   374  						End: lastStmt.End(),
   375  					},
   376  					// Delete '}' of range statement and preceding space.
   377  					{
   378  						Pos: ifStmt.Body.End(),
   379  						End: rng.End(),
   380  					},
   381  				})
   382  				return
   383  			}
   384  		}
   385  	}
   386  
   387  	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
   388  		file := curFile.Node().(*ast.File)
   389  
   390  		for curRange := range curFile.Preorder((*ast.RangeStmt)(nil)) {
   391  			rng := curRange.Node().(*ast.RangeStmt)
   392  
   393  			if is[*ast.Ident](rng.Key) &&
   394  				rng.Tok == token.DEFINE &&
   395  				len(rng.Body.List) == 1 &&
   396  				is[*types.Slice](typeparams.CoreType(info.TypeOf(rng.X))) {
   397  
   398  				// Have:
   399  				// - for _, elem := range s { S }
   400  				// - for i       := range s { S }
   401  
   402  				if ifStmt, ok := rng.Body.List[0].(*ast.IfStmt); ok &&
   403  					ifStmt.Init == nil && ifStmt.Else == nil {
   404  
   405  					// Have: for i, elem := range s { if cond { ... } }
   406  					check(file, curRange)
   407  				}
   408  			}
   409  		}
   410  	}
   411  	return nil, nil
   412  }
   413  
   414  // -- helpers --
   415  
   416  // isReturnTrueOrFalse returns nonzero if stmt returns true (+1) or false (-1).
   417  func isReturnTrueOrFalse(info *types.Info, stmt ast.Stmt) int {
   418  	if ret, ok := stmt.(*ast.ReturnStmt); ok && len(ret.Results) == 1 {
   419  		return isTrueOrFalse(info, ret.Results[0])
   420  	}
   421  	return 0
   422  }
   423  
   424  // isTrueOrFalse returns nonzero if expr is literally true (+1) or false (-1).
   425  func isTrueOrFalse(info *types.Info, expr ast.Expr) int {
   426  	if id, ok := expr.(*ast.Ident); ok {
   427  		switch info.Uses[id] {
   428  		case builtinTrue:
   429  			return +1
   430  		case builtinFalse:
   431  			return -1
   432  		}
   433  	}
   434  	return 0
   435  }
   436  

View as plain text