mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Add WorkerPool and use it for OnE2EEData
- Allowing unlimited concurrency on OnE2EEData causes huge spikes in DB conns when device lists change. - Using a high, bounded amount of concurrency ensure we don't breach DB conn limits. With unit tests.
This commit is contained in:
parent
82c21e6d5a
commit
b9bc83d93f
67
internal/pool.go
Normal file
67
internal/pool.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
type WorkerPool struct {
|
||||||
|
N int
|
||||||
|
ch chan func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new worker pool of size N. Up to N work can be done concurrently.
|
||||||
|
// The size of N depends on the expected frequency of work and contention for
|
||||||
|
// shared resources. Large values of N allow more frequent work at the cost of
|
||||||
|
// more contention for shared resources like cpu, memory and fds. Small values
|
||||||
|
// of N allow less frequent work but control the amount of shared resource contention.
|
||||||
|
// Ideally this value will be derived from whatever shared resource constraints you
|
||||||
|
// are hitting up against, rather than set to a fixed value. For example, if you have
|
||||||
|
// a database connection limit of 100, then setting N to some fraction of the limit is
|
||||||
|
// preferred to setting this to an arbitrary number < 100. If more than N work is requested,
|
||||||
|
// eventually WorkerPool.Queue will block until some work is done.
|
||||||
|
//
|
||||||
|
// The larger N is, the larger the up front memory costs are due to the implementation of WorkerPool.
|
||||||
|
func NewWorkerPool(n int) *WorkerPool {
|
||||||
|
return &WorkerPool{
|
||||||
|
N: n,
|
||||||
|
// If we have N workers, we can process N work concurrently.
|
||||||
|
// If we have >N work, we need to apply backpressure to stop us
|
||||||
|
// making more and more work which takes up more and more memory.
|
||||||
|
// By setting the channel size to N, we ensure that backpressure is
|
||||||
|
// being applied on the producer, stopping it from creating more work,
|
||||||
|
// and hence bounding memory consumption. Work is still being produced
|
||||||
|
// upstream on the homeserver, but we will consume it when we're ready
|
||||||
|
// rather than gobble it all at once.
|
||||||
|
//
|
||||||
|
// Note: we aren't forced to set this to N, it just serves as a useful
|
||||||
|
// metric which scales on the number of workers. The amount of in-flight
|
||||||
|
// work is N, so it makes sense to allow up to N work to be queued up before
|
||||||
|
// applying backpressure. If the channel buffer is < N then the channel can
|
||||||
|
// become the bottleneck in the case where we have lots of instantaneous work
|
||||||
|
// to do. If the channel buffer is too large, we needlessly consume memory as
|
||||||
|
// make() will allocate a backing array of whatever size you give it up front (sad face)
|
||||||
|
ch: make(chan func(), n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the workers. Only call this once.
|
||||||
|
func (wp *WorkerPool) Start() {
|
||||||
|
for i := 0; i < wp.N; i++ {
|
||||||
|
go wp.worker()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the worker pool. Only really useful for tests as a worker pool should be started once
|
||||||
|
// and persist for the lifetime of the process, else it causes needless goroutine churn.
|
||||||
|
// Only call this once.
|
||||||
|
func (wp *WorkerPool) Stop() {
|
||||||
|
close(wp.ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queue some work on the pool. May or may not block until some work is processed.
|
||||||
|
func (wp *WorkerPool) Queue(fn func()) {
|
||||||
|
wp.ch <- fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker impl
|
||||||
|
func (wp *WorkerPool) worker() {
|
||||||
|
for fn := range wp.ch {
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
}
|
186
internal/pool_test.go
Normal file
186
internal/pool_test.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test basic functions of WorkerPool
|
||||||
|
func TestWorkerPool(t *testing.T) {
|
||||||
|
wp := NewWorkerPool(2)
|
||||||
|
wp.Start()
|
||||||
|
defer wp.Stop()
|
||||||
|
|
||||||
|
// we should process this concurrently as N=2 so it should take 1s not 2s
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
start := time.Now()
|
||||||
|
wp.Queue(func() {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
wp.Queue(func() {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
took := time.Since(start)
|
||||||
|
if took > 2*time.Second {
|
||||||
|
t.Fatalf("took %v for queued work, it should have been faster than 2s", took)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) {
|
||||||
|
wp := NewWorkerPool(2)
|
||||||
|
|
||||||
|
// return channel to use to see when work is done
|
||||||
|
ch := make(chan int, 2)
|
||||||
|
wp.Queue(func() {
|
||||||
|
ch <- 1
|
||||||
|
})
|
||||||
|
wp.Queue(func() {
|
||||||
|
ch <- 2
|
||||||
|
})
|
||||||
|
|
||||||
|
// the work should not be done yet
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if len(ch) > 0 {
|
||||||
|
t.Fatalf("Queued work was done before Start()")
|
||||||
|
}
|
||||||
|
|
||||||
|
// the work should be starting now
|
||||||
|
wp.Start()
|
||||||
|
defer wp.Stop()
|
||||||
|
|
||||||
|
sum := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timed out waiting for work to be done")
|
||||||
|
case val := <-ch:
|
||||||
|
sum += val
|
||||||
|
}
|
||||||
|
if sum == 3 { // 2 + 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type workerState struct {
|
||||||
|
id int
|
||||||
|
state int // not running, queued, running, finished
|
||||||
|
unblock *sync.WaitGroup // decrement to unblock this worker
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerPoolBackpressure(t *testing.T) {
|
||||||
|
// this test assumes backpressure starts at n*2+1 due to a chan buffer of size n, and n in-flight work.
|
||||||
|
n := 2
|
||||||
|
wp := NewWorkerPool(n)
|
||||||
|
wp.Start()
|
||||||
|
defer wp.Stop()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
stateNotRunning := 0
|
||||||
|
stateQueued := 1
|
||||||
|
stateRunning := 2
|
||||||
|
stateFinished := 3
|
||||||
|
size := (2 * n) + 1
|
||||||
|
running := make([]*workerState, size)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// we test backpressure by scheduling (n*2)+1 work and ensuring that we see the following running states:
|
||||||
|
// [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE
|
||||||
|
// [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued
|
||||||
|
// [3,3,2,2,1] <-- 2 finished, 2 running , 1 queued
|
||||||
|
// [3,3,3,2,2] <-- 3 finished, 2 running
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
// set initial state of this piece of work
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
state := &workerState{
|
||||||
|
id: i,
|
||||||
|
state: stateNotRunning,
|
||||||
|
unblock: wg,
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
running[i] = state
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// queue the work on the pool. The final piece of work will block here and remain in
|
||||||
|
// stateNotRunning and not transition to stateQueued until the first piece of work is done.
|
||||||
|
wp.Queue(func() {
|
||||||
|
mu.Lock()
|
||||||
|
if running[state.id].state != stateQueued {
|
||||||
|
// we ran work in the worker faster than the code underneath .Queue, so let it catch up
|
||||||
|
mu.Unlock()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
mu.Lock()
|
||||||
|
}
|
||||||
|
running[state.id].state = stateRunning
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
running[state.id].unblock.Wait()
|
||||||
|
mu.Lock()
|
||||||
|
running[state.id].state = stateFinished
|
||||||
|
mu.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// mark this work as queued
|
||||||
|
mu.Lock()
|
||||||
|
running[i].state = stateQueued
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for the workers to be doing work and assert the states of each task
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
assertStates(t, &mu, running, []int{
|
||||||
|
stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning,
|
||||||
|
})
|
||||||
|
|
||||||
|
// now let the first task complete
|
||||||
|
running[0].unblock.Done()
|
||||||
|
// wait for the pool to grab more work
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// assert new states
|
||||||
|
assertStates(t, &mu, running, []int{
|
||||||
|
stateFinished, stateRunning, stateRunning, stateQueued, stateQueued,
|
||||||
|
})
|
||||||
|
|
||||||
|
// now let the second task complete
|
||||||
|
running[1].unblock.Done()
|
||||||
|
// wait for the pool to grab more work
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// assert new states
|
||||||
|
assertStates(t, &mu, running, []int{
|
||||||
|
stateFinished, stateFinished, stateRunning, stateRunning, stateQueued,
|
||||||
|
})
|
||||||
|
|
||||||
|
// now let the third task complete
|
||||||
|
running[2].unblock.Done()
|
||||||
|
// wait for the pool to grab more work
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// assert new states
|
||||||
|
assertStates(t, &mu, running, []int{
|
||||||
|
stateFinished, stateFinished, stateFinished, stateRunning, stateRunning,
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) {
|
||||||
|
t.Helper()
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(running) != len(wantStates) {
|
||||||
|
t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running))
|
||||||
|
}
|
||||||
|
for i := range running {
|
||||||
|
state := running[i]
|
||||||
|
wantVal := wantStates[i]
|
||||||
|
if state.state != wantVal {
|
||||||
|
t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -46,6 +46,7 @@ type Handler struct {
|
|||||||
typingMap map[string]uint64
|
typingMap map[string]uint64
|
||||||
|
|
||||||
deviceDataTicker *sync2.DeviceDataTicker
|
deviceDataTicker *sync2.DeviceDataTicker
|
||||||
|
e2eeWorkerPool *internal.WorkerPool
|
||||||
|
|
||||||
numPollers prometheus.Gauge
|
numPollers prometheus.Gauge
|
||||||
subSystem string
|
subSystem string
|
||||||
@ -67,6 +68,7 @@ func NewHandler(
|
|||||||
}),
|
}),
|
||||||
typingMap: make(map[string]uint64),
|
typingMap: make(map[string]uint64),
|
||||||
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
|
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
|
||||||
|
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
|
||||||
}
|
}
|
||||||
|
|
||||||
if enablePrometheus {
|
if enablePrometheus {
|
||||||
@ -91,6 +93,7 @@ func (h *Handler) Listen() {
|
|||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
h.e2eeWorkerPool.Start()
|
||||||
h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
|
h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
|
||||||
go h.deviceDataTicker.Run()
|
go h.deviceDataTicker.Run()
|
||||||
}
|
}
|
||||||
@ -201,27 +204,33 @@ func (h *Handler) UpdateDeviceSince(ctx context.Context, userID, deviceID, since
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
||||||
// some of these fields may be set
|
var wg sync.WaitGroup
|
||||||
partialDD := internal.DeviceData{
|
wg.Add(1)
|
||||||
UserID: userID,
|
h.e2eeWorkerPool.Queue(func() {
|
||||||
DeviceID: deviceID,
|
defer wg.Done()
|
||||||
OTKCounts: otkCounts,
|
// some of these fields may be set
|
||||||
FallbackKeyTypes: fallbackKeyTypes,
|
partialDD := internal.DeviceData{
|
||||||
DeviceLists: internal.DeviceLists{
|
UserID: userID,
|
||||||
New: deviceListChanges,
|
DeviceID: deviceID,
|
||||||
},
|
OTKCounts: otkCounts,
|
||||||
}
|
FallbackKeyTypes: fallbackKeyTypes,
|
||||||
_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
|
DeviceLists: internal.DeviceLists{
|
||||||
if err != nil {
|
New: deviceListChanges,
|
||||||
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
},
|
||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
}
|
||||||
return
|
_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
|
||||||
}
|
if err != nil {
|
||||||
// remember this to notify on pubsub later
|
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
||||||
h.deviceDataTicker.Remember(sync2.PollerID{
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||||
UserID: userID,
|
return
|
||||||
DeviceID: deviceID,
|
}
|
||||||
|
// remember this to notify on pubsub later
|
||||||
|
h.deviceDataTicker.Remember(sync2.PollerID{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Called periodically by deviceDataTicker, contains many updates
|
// Called periodically by deviceDataTicker, contains many updates
|
||||||
|
Loading…
x
Reference in New Issue
Block a user