From b9bc83d93f1e834e7302efe6de1b769a8b69e680 Mon Sep 17 00:00:00 2001
From: Kegan Dougal
Date: Wed, 28 Jun 2023 16:32:23 -0500
Subject: [PATCH] Add WorkerPool and use it for OnE2EEData

- Allowing unlimited concurrency on OnE2EEData causes huge spikes in DB
  conns when device lists change.
- Using a high but bounded amount of concurrency ensures we don't breach
  DB conn limits.

With unit tests.
---
 internal/pool.go          |  74 +++++++++++++++
 internal/pool_test.go     | 189 ++++++++++++++++++++++++++++++++++++++
 sync2/handler2/handler.go |  52 +++++++----
 3 files changed, 295 insertions(+), 20 deletions(-)
 create mode 100644 internal/pool.go
 create mode 100644 internal/pool_test.go

diff --git a/internal/pool.go b/internal/pool.go
new file mode 100644
index 0000000..27c6bb0
--- /dev/null
+++ b/internal/pool.go
@@ -0,0 +1,74 @@
+package internal
+
+type WorkerPool struct {
+	N  int
+	ch chan func()
+}
+
+// NewWorkerPool creates a new worker pool of size N. Up to N pieces of work
+// can be done concurrently. The right size for N depends on the expected
+// frequency of work and the contention for shared resources: large values of
+// N allow more concurrent work at the cost of more contention for CPU, memory
+// and fds; small values allow less concurrent work but limit that contention.
+// Ideally, derive N from whatever shared resource constraint you are hitting
+// rather than picking a fixed value. For example, with a database connection
+// limit of 100, setting N to some fraction of that limit beats picking an
+// arbitrary number < 100. If more than N pieces of work are queued,
+// WorkerPool.Queue will eventually block until some work is done.
+//
+// The larger N is, the larger the up-front memory cost of the WorkerPool.
+func NewWorkerPool(n int) *WorkerPool {
+	return &WorkerPool{
+		N: n,
+		// If we have N workers, we can process N pieces of work concurrently.
+		// If we have more than N pieces of work, we need to apply backpressure
+		// to stop us accepting more and more work, which takes up more and more
+		// memory. By setting the channel size to N, we ensure that backpressure
+		// is applied to the producer, stopping it from creating more work and
+		// hence bounding memory consumption. Work is still being produced
+		// upstream on the homeserver, but we will consume it when we're ready
+		// rather than gobble it all at once.
+		//
+		// Note: we aren't forced to set this to N, it just serves as a useful
+		// value which scales with the number of workers. The amount of in-flight
+		// work is N, so it makes sense to allow up to N pieces of work to be
+		// queued up before applying backpressure. If the channel buffer is < N
+		// then the channel can become the bottleneck when there is a lot of
+		// instantaneous work to do. If the buffer is too large, we needlessly
+		// consume memory, as make() allocates the backing array up front (sad face).
+		ch: make(chan func(), n),
+	}
+}
+
+// Start the workers. Only call this once.
+func (wp *WorkerPool) Start() {
+	for i := 0; i < wp.N; i++ {
+		go wp.worker()
+	}
+}
+
+// Stop the worker pool. This is only really useful for tests, as a worker pool
+// should be started once and persist for the lifetime of the process, else it
+// causes needless goroutine churn. Only call this once.
+func (wp *WorkerPool) Stop() {
+	close(wp.ch)
+}
+
+// Queue some work on the pool. This blocks if N pieces of work are already queued.
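+//
+// A minimal usage sketch (illustrative only; processItem is a hypothetical
+// unit of work, not part of this package):
+//
+//	wp := NewWorkerPool(16)
+//	wp.Start()
+//	wp.Queue(func() { processItem() })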
+func (wp *WorkerPool) Queue(fn func()) {
+	wp.ch <- fn
+}
+
+// worker consumes work from the channel until Stop() closes it.
+func (wp *WorkerPool) worker() {
+	for fn := range wp.ch {
+		fn()
+	}
+}
diff --git a/internal/pool_test.go b/internal/pool_test.go
new file mode 100644
index 0000000..3422207
--- /dev/null
+++ b/internal/pool_test.go
@@ -0,0 +1,189 @@
+package internal
+
+import (
+	"sync"
+	"testing"
+	"time"
+)
+
+// Test basic functions of WorkerPool
+func TestWorkerPool(t *testing.T) {
+	wp := NewWorkerPool(2)
+	wp.Start()
+	defer wp.Stop()
+
+	// we should process these concurrently as N=2, so this should take ~1s not ~2s
+	var wg sync.WaitGroup
+	wg.Add(2)
+	start := time.Now()
+	wp.Queue(func() {
+		time.Sleep(time.Second)
+		wg.Done()
+	})
+	wp.Queue(func() {
+		time.Sleep(time.Second)
+		wg.Done()
+	})
+	wg.Wait()
+	took := time.Since(start)
+	if took > 2*time.Second {
+		t.Fatalf("took %v for queued work; it should have been faster than 2s", took)
+	}
+}
+
+func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) {
+	wp := NewWorkerPool(2)
+
+	// buffered channel used to observe when each piece of work is done
+	ch := make(chan int, 2)
+	wp.Queue(func() {
+		ch <- 1
+	})
+	wp.Queue(func() {
+		ch <- 2
+	})
+
+	// the work should not be done yet
+	time.Sleep(100 * time.Millisecond)
+	if len(ch) > 0 {
+		t.Fatalf("Queued work was done before Start()")
+	}
+
+	// the work should be starting now
+	wp.Start()
+	defer wp.Stop()
+
+	sum := 0
+	for {
+		select {
+		case <-time.After(time.Second):
+			t.Fatalf("timed out waiting for work to be done")
+		case val := <-ch:
+			sum += val
+		}
+		if sum == 3 { // 2 + 1
+			break
+		}
+	}
+}
+
+type workerState struct {
+	id      int
+	state   int             // not running, queued, running, finished
+	unblock *sync.WaitGroup // decrement to unblock this worker
+}
+
+func TestWorkerPoolBackpressure(t *testing.T) {
+	// this test assumes backpressure starts at n*2+1, due to a chan buffer of size n and n in-flight pieces of work
+	n := 2
+	wp := NewWorkerPool(n)
+	wp.Start()
+	defer wp.Stop()
+
+	var mu sync.Mutex
+	stateNotRunning := 0
+	stateQueued := 1
+	stateRunning := 2
+	stateFinished := 3
+	size := (2 * n) + 1
+	running := make([]*workerState, size)
+
+	go func() {
+		// we test backpressure by scheduling (n*2)+1 pieces of work and ensuring that we see the following states:
+		// [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE
+		// [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued
+		// [3,3,2,2,1] <-- 2 finished, 2 running, 1 queued
+		// [3,3,3,2,2] <-- 3 finished, 2 running
+		for i := 0; i < size; i++ {
+			// set initial state of this piece of work
+			wg := &sync.WaitGroup{}
+			wg.Add(1)
+			state := &workerState{
+				id:      i,
+				state:   stateNotRunning,
+				unblock: wg,
+			}
+			mu.Lock()
+			running[i] = state
+			mu.Unlock()
+
+			// queue the work on the pool. The final piece of work will block here, remaining in
+			// stateNotRunning, and won't transition to stateQueued until the first piece is done.
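+			// With n workers and a channel buffer of n, pieces 1..n are picked up by workers
+			// immediately, pieces n+1..2n sit in the channel buffer, and piece 2n+1 blocks
+			// inside Queue until a buffer slot frees up: that's the backpressure we assert.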
+			wp.Queue(func() {
+				mu.Lock()
+				if running[state.id].state != stateQueued {
+					// the worker picked this up before the code below .Queue marked it as queued, so let it catch up
+					mu.Unlock()
+					time.Sleep(10 * time.Millisecond)
+					mu.Lock()
+				}
+				running[state.id].state = stateRunning
+				mu.Unlock()
+
+				running[state.id].unblock.Wait()
+				mu.Lock()
+				running[state.id].state = stateFinished
+				mu.Unlock()
+			})
+
+			// mark this work as queued
+			mu.Lock()
+			running[i].state = stateQueued
+			mu.Unlock()
+		}
+	}()
+
+	// wait for the workers to pick up work, then assert the state of each task
+	time.Sleep(time.Second)
+
+	assertStates(t, &mu, running, []int{
+		stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning,
+	})
+
+	// now let the first task complete
+	running[0].unblock.Done()
+	// wait for the pool to grab more work
+	time.Sleep(100 * time.Millisecond)
+	// assert new states
+	assertStates(t, &mu, running, []int{
+		stateFinished, stateRunning, stateRunning, stateQueued, stateQueued,
+	})
+
+	// now let the second task complete
+	running[1].unblock.Done()
+	// wait for the pool to grab more work
+	time.Sleep(100 * time.Millisecond)
+	// assert new states
+	assertStates(t, &mu, running, []int{
+		stateFinished, stateFinished, stateRunning, stateRunning, stateQueued,
+	})
+
+	// now let the third task complete
+	running[2].unblock.Done()
+	// wait for the pool to grab more work
+	time.Sleep(100 * time.Millisecond)
+	// assert new states
+	assertStates(t, &mu, running, []int{
+		stateFinished, stateFinished, stateFinished, stateRunning, stateRunning,
+	})
+
+}
+
+func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) {
+	t.Helper()
+	mu.Lock()
+	defer mu.Unlock()
+	if len(running) != len(wantStates) {
+		t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running))
+	}
+	for i := range running {
+		state := running[i]
+		wantVal := wantStates[i]
+		if state.state != wantVal {
+			t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal)
+		}
+	}
+}
diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go
index f554300..37b068c 100644
--- a/sync2/handler2/handler.go
+++ b/sync2/handler2/handler.go
@@ -46,6 +46,7 @@ type Handler struct {
 	typingMap map[string]uint64
 
 	deviceDataTicker *sync2.DeviceDataTicker
+	e2eeWorkerPool   *internal.WorkerPool
 
 	numPollers prometheus.Gauge
 	subSystem  string
@@ -67,6 +68,7 @@ func NewHandler(
 		}),
 		typingMap:        make(map[string]uint64),
 		deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
+		e2eeWorkerPool:   internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
 	}
 
 	if enablePrometheus {
@@ -91,6 +93,7 @@ func (h *Handler) Listen() {
 			sentry.CaptureException(err)
 		}
 	}()
+	h.e2eeWorkerPool.Start()
 	h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
 	go h.deviceDataTicker.Run()
 }
@@ -201,27 +204,36 @@ func (h *Handler) UpdateDeviceSince(ctx context.Context, userID, deviceID, since
 }
 
 func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
-	// some of these fields may be set
-	partialDD := internal.DeviceData{
-		UserID:           userID,
-		DeviceID:         deviceID,
-		OTKCounts:        otkCounts,
-		FallbackKeyTypes: fallbackKeyTypes,
-		DeviceLists: internal.DeviceLists{
-			New: deviceListChanges,
-		},
-	}
-	_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
-	if err != nil {
-		logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
-		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
-		return
-	}
-	// remember this to notify on pubsub later
-	h.deviceDataTicker.Remember(sync2.PollerID{
-		UserID:   userID,
-		DeviceID: deviceID,
+	// Queue the upsert on the bounded worker pool so concurrent DB writes are
+	// capped, then wait for it to complete so callers still observe the same
+	// synchronous behaviour as before.
+	var wg sync.WaitGroup
+	wg.Add(1)
+	h.e2eeWorkerPool.Queue(func() {
+		defer wg.Done()
+		// some of these fields may be set
+		partialDD := internal.DeviceData{
+			UserID:           userID,
+			DeviceID:         deviceID,
+			OTKCounts:        otkCounts,
+			FallbackKeyTypes: fallbackKeyTypes,
+			DeviceLists: internal.DeviceLists{
+				New: deviceListChanges,
+			},
+		}
+		_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+		if err != nil {
+			logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
+			internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
+			return
+		}
+		// remember this to notify on pubsub later
+		h.deviceDataTicker.Remember(sync2.PollerID{
+			UserID:   userID,
+			DeviceID: deviceID,
+		})
 	})
+	wg.Wait()
 }
 
 // Called periodically by deviceDataTicker, contains many updates
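
A possible follow-up for the TODO above is to derive the pool size from the
database connection limit rather than hardcoding 500. A minimal sketch, assuming
the handler can reach the underlying *sql.DB; the e2eePoolSize name and the
half-of-limit fraction are illustrative, not part of this patch:

	import "database/sql"

	// e2eePoolSize derives a worker pool size from the configured DB connection
	// limit, leaving headroom for other queries. It falls back to the current
	// hardcoded default when no limit is set.
	func e2eePoolSize(db *sql.DB) int {
		maxConns := db.Stats().MaxOpenConnections // 0 means no configured limit
		if maxConns <= 0 {
			return 500
		}
		if n := maxConns / 2; n > 0 {
			return n
		}
		return 1
	}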