sliding-sync/internal/pool_test.go
Kegan Dougal b9bc83d93f Add WorkerPool and use it for OnE2EEData
- Allowing unlimited concurrency on OnE2EEData causes huge spikes in DB conns
  when device lists change.
- Using a high, bounded amount of concurrency ensures we don't breach DB conn limits.

With unit tests.
2023-06-28 16:32:23 -05:00


package internal

import (
	"sync"
	"testing"
	"time"
)

// Test basic functions of WorkerPool
func TestWorkerPool(t *testing.T) {
	wp := NewWorkerPool(2)
	wp.Start()
	defer wp.Stop()
	// we should process this concurrently as N=2 so it should take 1s not 2s
	var wg sync.WaitGroup
	wg.Add(2)
	start := time.Now()
	wp.Queue(func() {
		time.Sleep(time.Second)
		wg.Done()
	})
	wp.Queue(func() {
		time.Sleep(time.Second)
		wg.Done()
	})
	wg.Wait()
	took := time.Since(start)
	if took > 2*time.Second {
		t.Fatalf("took %v for queued work, it should have been faster than 2s", took)
	}
}

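// Test that work Queued before Start() is buffered and only runs once Start() is called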
func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) {
	wp := NewWorkerPool(2)
	// return channel to use to see when work is done
	ch := make(chan int, 2)
	wp.Queue(func() {
		ch <- 1
	})
	wp.Queue(func() {
		ch <- 2
	})
	// the work should not be done yet
	time.Sleep(100 * time.Millisecond)
	if len(ch) > 0 {
		t.Fatalf("Queued work was done before Start()")
	}
	// the work should be starting now
	wp.Start()
	defer wp.Stop()
	sum := 0
	for {
		select {
		case <-time.After(time.Second):
			t.Fatalf("timed out waiting for work to be done")
		case val := <-ch:
			sum += val
		}
		if sum == 3 { // 2 + 1
			break
		}
	}
}

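// workerState tracks the lifecycle of a single piece of queued work so the
// backpressure test can assert which tasks are queued, running or finished.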
type workerState struct {
	id      int
	state   int             // not running, queued, running, finished
	unblock *sync.WaitGroup // decrement to unblock this worker
}

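// Test that Queue blocks once the pool has n tasks running and n more buffered,
// i.e. backpressure kicks in for the (n*2)+1th piece of work.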
func TestWorkerPoolBackpressure(t *testing.T) {
	// this test assumes backpressure starts at n*2+1 due to a chan buffer of size n, and n in-flight work.
	n := 2
	wp := NewWorkerPool(n)
	wp.Start()
	defer wp.Stop()
	var mu sync.Mutex
	stateNotRunning := 0
	stateQueued := 1
	stateRunning := 2
	stateFinished := 3
	size := (2 * n) + 1
	running := make([]*workerState, size)
	go func() {
		// we test backpressure by scheduling (n*2)+1 pieces of work and ensuring that we see the following running states:
		// [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE
		// [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued
		// [3,3,2,2,1] <-- 2 finished, 2 running, 1 queued
		// [3,3,3,2,2] <-- 3 finished, 2 running
		for i := 0; i < size; i++ {
			// set initial state of this piece of work
			wg := &sync.WaitGroup{}
			wg.Add(1)
			state := &workerState{
				id:      i,
				state:   stateNotRunning,
				unblock: wg,
			}
			mu.Lock()
			running[i] = state
			mu.Unlock()
			// queue the work on the pool. The final piece of work will block here and remain in
			// stateNotRunning and not transition to stateQueued until the first piece of work is done.
			wp.Queue(func() {
				mu.Lock()
				if running[state.id].state != stateQueued {
					// we ran work in the worker faster than the code underneath .Queue, so let it catch up
					mu.Unlock()
					time.Sleep(10 * time.Millisecond)
					mu.Lock()
				}
				running[state.id].state = stateRunning
				mu.Unlock()
				running[state.id].unblock.Wait()
				mu.Lock()
				running[state.id].state = stateFinished
				mu.Unlock()
			})
			// mark this work as queued
			mu.Lock()
			running[i].state = stateQueued
			mu.Unlock()
		}
	}()
	// wait for the workers to be doing work and assert the states of each task
	time.Sleep(time.Second)
	assertStates(t, &mu, running, []int{
		stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning,
	})
	// now let the first task complete
	running[0].unblock.Done()
	// wait for the pool to grab more work
	time.Sleep(100 * time.Millisecond)
	// assert new states
	assertStates(t, &mu, running, []int{
		stateFinished, stateRunning, stateRunning, stateQueued, stateQueued,
	})
	// now let the second task complete
	running[1].unblock.Done()
	// wait for the pool to grab more work
	time.Sleep(100 * time.Millisecond)
	// assert new states
	assertStates(t, &mu, running, []int{
		stateFinished, stateFinished, stateRunning, stateRunning, stateQueued,
	})
	// now let the third task complete
	running[2].unblock.Done()
	// wait for the pool to grab more work
	time.Sleep(100 * time.Millisecond)
	// assert new states
	assertStates(t, &mu, running, []int{
		stateFinished, stateFinished, stateFinished, stateRunning, stateRunning,
	})
}

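// assertStates locks mu and checks that each piece of work in running is in the
// corresponding state in wantStates.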
func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) {
	t.Helper()
	mu.Lock()
	defer mu.Unlock()
	if len(running) != len(wantStates) {
		t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running))
	}
	for i := range running {
		state := running[i]
		wantVal := wantStates[i]
		if state.state != wantVal {
			t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal)
		}
	}
}
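
For reference, the tests above only depend on the pool's public behaviour: NewWorkerPool(n) accepts Queue()d work before Start(), Start() launches n concurrent workers, and Queue() blocks once n tasks are in flight and a further n are buffered (the (n*2)+1 backpressure point assumed by TestWorkerPoolBackpressure). The sketch below is a minimal implementation consistent with that contract, inferred from the tests only; the actual pool.go may differ in details.

package internal

// WorkerPool runs queued work with at most N concurrent workers (sketch; see note above).
type WorkerPool struct {
	N  int
	ch chan func()
}

// NewWorkerPool creates a pool whose queue buffer has size n: Queue only blocks once
// n tasks are executing and another n are buffered, giving the (n*2)+1 backpressure point.
func NewWorkerPool(n int) *WorkerPool {
	return &WorkerPool{
		N:  n,
		ch: make(chan func(), n),
	}
}

// Start launches N worker goroutines; work Queued before Start sits in the buffer until now.
func (wp *WorkerPool) Start() {
	for i := 0; i < wp.N; i++ {
		go func() {
			for fn := range wp.ch {
				fn()
			}
		}()
	}
}

// Stop closes the queue; workers exit once they have drained the remaining work.
func (wp *WorkerPool) Stop() {
	close(wp.ch)
}

// Queue submits work, blocking when the buffer is full.
func (wp *WorkerPool) Queue(fn func()) {
	wp.ch <- fn
}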