Source file src/iter/pull_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package iter_test
     6  
     7  import (
     8  	"fmt"
     9  	. "iter"
    10  	"runtime"
    11  	"testing"
    12  )
    13  
    14  func count(n int) Seq[int] {
    15  	return func(yield func(int) bool) {
    16  		for i := range n {
    17  			if !yield(i) {
    18  				break
    19  			}
    20  		}
    21  	}
    22  }
    23  
    24  func squares(n int) Seq2[int, int64] {
    25  	return func(yield func(int, int64) bool) {
    26  		for i := range n {
    27  			if !yield(i, int64(i)*int64(i)) {
    28  				break
    29  			}
    30  		}
    31  	}
    32  }
    33  
    34  func TestPull(t *testing.T) {
    35  	for end := 0; end <= 3; end++ {
    36  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    37  			ng := stableNumGoroutine()
    38  			wantNG := func(want int) {
    39  				if xg := runtime.NumGoroutine() - ng; xg != want {
    40  					t.Helper()
    41  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    42  				}
    43  			}
    44  			wantNG(0)
    45  			next, stop := Pull(count(3))
    46  			wantNG(1)
    47  			for i := range end {
    48  				v, ok := next()
    49  				if v != i || ok != true {
    50  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, i, true)
    51  				}
    52  				wantNG(1)
    53  			}
    54  			wantNG(1)
    55  			if end < 3 {
    56  				stop()
    57  				wantNG(0)
    58  			}
    59  			for range 2 {
    60  				v, ok := next()
    61  				if v != 0 || ok != false {
    62  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, 0, false)
    63  				}
    64  				wantNG(0)
    65  			}
    66  			wantNG(0)
    67  
    68  			stop()
    69  			stop()
    70  			stop()
    71  			wantNG(0)
    72  		})
    73  	}
    74  }
    75  
    76  func TestPull2(t *testing.T) {
    77  	for end := 0; end <= 3; end++ {
    78  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    79  			ng := stableNumGoroutine()
    80  			wantNG := func(want int) {
    81  				if xg := runtime.NumGoroutine() - ng; xg != want {
    82  					t.Helper()
    83  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    84  				}
    85  			}
    86  			wantNG(0)
    87  			next, stop := Pull2(squares(3))
    88  			wantNG(1)
    89  			for i := range end {
    90  				k, v, ok := next()
    91  				if k != i || v != int64(i*i) || ok != true {
    92  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, i, i*i, true)
    93  				}
    94  				wantNG(1)
    95  			}
    96  			wantNG(1)
    97  			if end < 3 {
    98  				stop()
    99  				wantNG(0)
   100  			}
   101  			for range 2 {
   102  				k, v, ok := next()
   103  				if v != 0 || ok != false {
   104  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, 0, 0, false)
   105  				}
   106  				wantNG(0)
   107  			}
   108  			wantNG(0)
   109  
   110  			stop()
   111  			stop()
   112  			stop()
   113  			wantNG(0)
   114  		})
   115  	}
   116  }
   117  
   118  // stableNumGoroutine is like NumGoroutine but tries to ensure stability of
   119  // the value by letting any exiting goroutines finish exiting.
   120  func stableNumGoroutine() int {
   121  	// The idea behind stablizing the value of NumGoroutine is to
   122  	// see the same value enough times in a row in between calls to
   123  	// runtime.Gosched. With GOMAXPROCS=1, we're trying to make sure
   124  	// that other goroutines run, so that they reach a stable point.
   125  	// It's not guaranteed, because it is still possible for a goroutine
   126  	// to Gosched back into itself, so we require NumGoroutine to be
   127  	// the same 100 times in a row. This should be more than enough to
   128  	// ensure all goroutines get a chance to run to completion (or to
   129  	// some block point) for a small group of test goroutines.
   130  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
   131  
   132  	c := 0
   133  	ng := runtime.NumGoroutine()
   134  	for i := 0; i < 1000; i++ {
   135  		nng := runtime.NumGoroutine()
   136  		if nng == ng {
   137  			c++
   138  		} else {
   139  			c = 0
   140  			ng = nng
   141  		}
   142  		if c >= 100 {
   143  			// The same value 100 times in a row is good enough.
   144  			return ng
   145  		}
   146  		runtime.Gosched()
   147  	}
   148  	panic("failed to stabilize NumGoroutine after 1000 iterations")
   149  }
   150  
   151  func TestPullDoubleNext(t *testing.T) {
   152  	next, _ := Pull(doDoubleNext())
   153  	nextSlot = next
   154  	next()
   155  	if nextSlot != nil {
   156  		t.Fatal("double next did not fail")
   157  	}
   158  }
   159  
   160  var nextSlot func() (int, bool)
   161  
   162  func doDoubleNext() Seq[int] {
   163  	return func(_ func(int) bool) {
   164  		defer func() {
   165  			if recover() != nil {
   166  				nextSlot = nil
   167  			}
   168  		}()
   169  		nextSlot()
   170  	}
   171  }
   172  
   173  func TestPullDoubleNext2(t *testing.T) {
   174  	next, _ := Pull2(doDoubleNext2())
   175  	nextSlot2 = next
   176  	next()
   177  	if nextSlot2 != nil {
   178  		t.Fatal("double next did not fail")
   179  	}
   180  }
   181  
   182  var nextSlot2 func() (int, int, bool)
   183  
   184  func doDoubleNext2() Seq2[int, int] {
   185  	return func(_ func(int, int) bool) {
   186  		defer func() {
   187  			if recover() != nil {
   188  				nextSlot2 = nil
   189  			}
   190  		}()
   191  		nextSlot2()
   192  	}
   193  }
   194  
   195  func TestPullDoubleYield(t *testing.T) {
   196  	_, stop := Pull(storeYield())
   197  	defer func() {
   198  		if recover() != nil {
   199  			yieldSlot = nil
   200  		}
   201  		stop()
   202  	}()
   203  	yieldSlot(5)
   204  	if yieldSlot != nil {
   205  		t.Fatal("double yield did not fail")
   206  	}
   207  }
   208  
   209  func storeYield() Seq[int] {
   210  	return func(yield func(int) bool) {
   211  		yieldSlot = yield
   212  		if !yield(5) {
   213  			return
   214  		}
   215  	}
   216  }
   217  
   218  var yieldSlot func(int) bool
   219  
   220  func TestPullDoubleYield2(t *testing.T) {
   221  	_, stop := Pull2(storeYield2())
   222  	defer func() {
   223  		if recover() != nil {
   224  			yieldSlot2 = nil
   225  		}
   226  		stop()
   227  	}()
   228  	yieldSlot2(23, 77)
   229  	if yieldSlot2 != nil {
   230  		t.Fatal("double yield did not fail")
   231  	}
   232  }
   233  
   234  func storeYield2() Seq2[int, int] {
   235  	return func(yield func(int, int) bool) {
   236  		yieldSlot2 = yield
   237  		if !yield(23, 77) {
   238  			return
   239  		}
   240  	}
   241  }
   242  
   243  var yieldSlot2 func(int, int) bool
   244  
   245  func TestPullPanic(t *testing.T) {
   246  	t.Run("next", func(t *testing.T) {
   247  		next, stop := Pull(panicSeq())
   248  		if !panicsWith("boom", func() { next() }) {
   249  			t.Fatal("failed to propagate panic on first next")
   250  		}
   251  		// Make sure we don't panic again if we try to call next or stop.
   252  		if _, ok := next(); ok {
   253  			t.Fatal("next returned true after iterator panicked")
   254  		}
   255  		// Calling stop again should be a no-op.
   256  		stop()
   257  	})
   258  	t.Run("stop", func(t *testing.T) {
   259  		next, stop := Pull(panicCleanupSeq())
   260  		x, ok := next()
   261  		if !ok || x != 55 {
   262  			t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
   263  		}
   264  		if !panicsWith("boom", func() { stop() }) {
   265  			t.Fatal("failed to propagate panic on stop")
   266  		}
   267  		// Make sure we don't panic again if we try to call next or stop.
   268  		if _, ok := next(); ok {
   269  			t.Fatal("next returned true after iterator panicked")
   270  		}
   271  		// Calling stop again should be a no-op.
   272  		stop()
   273  	})
   274  }
   275  
   276  func panicSeq() Seq[int] {
   277  	return func(yield func(int) bool) {
   278  		panic("boom")
   279  	}
   280  }
   281  
   282  func panicCleanupSeq() Seq[int] {
   283  	return func(yield func(int) bool) {
   284  		for {
   285  			if !yield(55) {
   286  				panic("boom")
   287  			}
   288  		}
   289  	}
   290  }
   291  
   292  func TestPull2Panic(t *testing.T) {
   293  	t.Run("next", func(t *testing.T) {
   294  		next, stop := Pull2(panicSeq2())
   295  		if !panicsWith("boom", func() { next() }) {
   296  			t.Fatal("failed to propagate panic on first next")
   297  		}
   298  		// Make sure we don't panic again if we try to call next or stop.
   299  		if _, _, ok := next(); ok {
   300  			t.Fatal("next returned true after iterator panicked")
   301  		}
   302  		// Calling stop again should be a no-op.
   303  		stop()
   304  	})
   305  	t.Run("stop", func(t *testing.T) {
   306  		next, stop := Pull2(panicCleanupSeq2())
   307  		x, y, ok := next()
   308  		if !ok || x != 55 || y != 100 {
   309  			t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
   310  		}
   311  		if !panicsWith("boom", func() { stop() }) {
   312  			t.Fatal("failed to propagate panic on stop")
   313  		}
   314  		// Make sure we don't panic again if we try to call next or stop.
   315  		if _, _, ok := next(); ok {
   316  			t.Fatal("next returned true after iterator panicked")
   317  		}
   318  		// Calling stop again should be a no-op.
   319  		stop()
   320  	})
   321  }
   322  
   323  func panicSeq2() Seq2[int, int] {
   324  	return func(yield func(int, int) bool) {
   325  		panic("boom")
   326  	}
   327  }
   328  
   329  func panicCleanupSeq2() Seq2[int, int] {
   330  	return func(yield func(int, int) bool) {
   331  		for {
   332  			if !yield(55, 100) {
   333  				panic("boom")
   334  			}
   335  		}
   336  	}
   337  }
   338  
   339  func panicsWith(v any, f func()) (panicked bool) {
   340  	defer func() {
   341  		if r := recover(); r != nil {
   342  			if r != v {
   343  				panic(r)
   344  			}
   345  			panicked = true
   346  		}
   347  	}()
   348  	f()
   349  	return
   350  }
   351  
   352  func TestPullGoexit(t *testing.T) {
   353  	t.Run("next", func(t *testing.T) {
   354  		var next func() (int, bool)
   355  		var stop func()
   356  		if !goexits(t, func() {
   357  			next, stop = Pull(goexitSeq())
   358  			next()
   359  		}) {
   360  			t.Fatal("failed to Goexit from next")
   361  		}
   362  		if x, ok := next(); x != 0 || ok {
   363  			t.Fatal("iterator returned valid value after iterator Goexited")
   364  		}
   365  		stop()
   366  	})
   367  	t.Run("stop", func(t *testing.T) {
   368  		next, stop := Pull(goexitCleanupSeq())
   369  		x, ok := next()
   370  		if !ok || x != 55 {
   371  			t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
   372  		}
   373  		if !goexits(t, func() {
   374  			stop()
   375  		}) {
   376  			t.Fatal("failed to Goexit from stop")
   377  		}
   378  		// Make sure we don't panic again if we try to call next or stop.
   379  		if x, ok := next(); x != 0 || ok {
   380  			t.Fatal("next returned true or non-zero value after iterator Goexited")
   381  		}
   382  		// Calling stop again should be a no-op.
   383  		stop()
   384  	})
   385  }
   386  
   387  func goexitSeq() Seq[int] {
   388  	return func(yield func(int) bool) {
   389  		runtime.Goexit()
   390  	}
   391  }
   392  
   393  func goexitCleanupSeq() Seq[int] {
   394  	return func(yield func(int) bool) {
   395  		for {
   396  			if !yield(55) {
   397  				runtime.Goexit()
   398  			}
   399  		}
   400  	}
   401  }
   402  
   403  func TestPull2Goexit(t *testing.T) {
   404  	t.Run("next", func(t *testing.T) {
   405  		var next func() (int, int, bool)
   406  		var stop func()
   407  		if !goexits(t, func() {
   408  			next, stop = Pull2(goexitSeq2())
   409  			next()
   410  		}) {
   411  			t.Fatal("failed to Goexit from next")
   412  		}
   413  		if x, y, ok := next(); x != 0 || y != 0 || ok {
   414  			t.Fatal("iterator returned valid value after iterator Goexited")
   415  		}
   416  		stop()
   417  	})
   418  	t.Run("stop", func(t *testing.T) {
   419  		next, stop := Pull2(goexitCleanupSeq2())
   420  		x, y, ok := next()
   421  		if !ok || x != 55 || y != 100 {
   422  			t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
   423  		}
   424  		if !goexits(t, func() {
   425  			stop()
   426  		}) {
   427  			t.Fatal("failed to Goexit from stop")
   428  		}
   429  		// Make sure we don't panic again if we try to call next or stop.
   430  		if x, y, ok := next(); x != 0 || y != 0 || ok {
   431  			t.Fatal("next returned true or non-zero after iterator Goexited")
   432  		}
   433  		// Calling stop again should be a no-op.
   434  		stop()
   435  	})
   436  }
   437  
   438  func goexitSeq2() Seq2[int, int] {
   439  	return func(yield func(int, int) bool) {
   440  		runtime.Goexit()
   441  	}
   442  }
   443  
   444  func goexitCleanupSeq2() Seq2[int, int] {
   445  	return func(yield func(int, int) bool) {
   446  		for {
   447  			if !yield(55, 100) {
   448  				runtime.Goexit()
   449  			}
   450  		}
   451  	}
   452  }
   453  
   454  func goexits(t *testing.T, f func()) bool {
   455  	t.Helper()
   456  
   457  	exit := make(chan bool)
   458  	go func() {
   459  		cleanExit := false
   460  		defer func() {
   461  			exit <- recover() == nil && !cleanExit
   462  		}()
   463  		f()
   464  		cleanExit = true
   465  	}()
   466  	return <-exit
   467  }
   468  
   469  func TestPullImmediateStop(t *testing.T) {
   470  	next, stop := Pull(panicSeq())
   471  	stop()
   472  	// Make sure we don't panic if we try to call next or stop.
   473  	if _, ok := next(); ok {
   474  		t.Fatal("next returned true after iterator was stopped")
   475  	}
   476  }
   477  
   478  func TestPull2ImmediateStop(t *testing.T) {
   479  	next, stop := Pull2(panicSeq2())
   480  	stop()
   481  	// Make sure we don't panic if we try to call next or stop.
   482  	if _, _, ok := next(); ok {
   483  		t.Fatal("next returned true after iterator was stopped")
   484  	}
   485  }
   486  

View as plain text