From f36c038cf87994fc53229c7e7fa2f4e2b09c8560 Mon Sep 17 00:00:00 2001
From: Kegan Dougal
Date: Mon, 26 Jun 2023 21:04:02 -0700
Subject: [PATCH] Rate limit pubsub.V2DeviceData updates to be at most 1 per
 second

The DB writes are still instant, but the notifications are now delayed
by up to 1 second, so as not to swamp the pubsub channels.
---
 pubsub/v2.go                     |   4 +-
 sync2/device_data_ticker.go      |  90 ++++++++++++++++++++++
 sync2/device_data_ticker_test.go | 124 +++++++++++++++++++++++++++++++
 sync2/handler2/handler.go        |  27 +++++--
 sync3/handler/handler.go         |  11 ++-
 v3.go                            |   4 +-
 6 files changed, 246 insertions(+), 14 deletions(-)
 create mode 100644 sync2/device_data_ticker.go
 create mode 100644 sync2/device_data_ticker_test.go

diff --git a/pubsub/v2.go b/pubsub/v2.go
index 2ed379f..7dfb01e 100644
--- a/pubsub/v2.go
+++ b/pubsub/v2.go
@@ -91,9 +91,7 @@ type V2InitialSyncComplete struct {
 func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
 
 type V2DeviceData struct {
-	UserID   string
-	DeviceID string
-	Pos      int64
+	UserIDToDeviceIDs map[string][]string
 }
 
 func (*V2DeviceData) Type() string { return "V2DeviceData" }
diff --git a/sync2/device_data_ticker.go b/sync2/device_data_ticker.go
new file mode 100644
index 0000000..7d77aae
--- /dev/null
+++ b/sync2/device_data_ticker.go
@@ -0,0 +1,90 @@
+package sync2
+
+import (
+	"sync"
+	"time"
+
+	"github.com/matrix-org/sliding-sync/pubsub"
+)
+
+// This struct remembers user+device IDs to notify for, then periodically
+// emits them all to the caller. Use it to rate-limit the frequency of
+// device list updates.
+type DeviceDataTicker struct {
+	// data structures to periodically notify downstream about device data updates.
+	// The ticker controls the frequency of updates. The done channel is used to stop ticking
+	// and clean up the goroutine. The notify map contains the values to notify for.
+	ticker    *time.Ticker
+	done      chan struct{}
+	notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
+	fn        func(payload *pubsub.V2DeviceData)
+}
+
+// Create a new device data ticker, which batches calls to Remember and invokes a callback every
+// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
+// is useful for testing.
+func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
+	ddt := &DeviceDataTicker{
+		done:      make(chan struct{}),
+		notifyMap: &sync.Map{},
+	}
+	if d != 0 {
+		ddt.ticker = time.NewTicker(d)
+	}
+	return ddt
+}
+
+// Stop ticking.
+func (t *DeviceDataTicker) Stop() {
+	if t.ticker != nil {
+		t.ticker.Stop()
+	}
+	close(t.done)
+}
+
+// Set the function which should be called when the tick happens.
+func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
+	t.fn = fn
+}
+
+// Remember this user/device ID, and emit it later on.
+func (t *DeviceDataTicker) Remember(pid PollerID) {
+	t.notifyMap.Store(pid, true)
+	if t.ticker == nil {
+		t.emitUpdate()
+	}
+}
+
+func (t *DeviceDataTicker) emitUpdate() {
+	var p pubsub.V2DeviceData
+	p.UserIDToDeviceIDs = make(map[string][]string)
+	// populate the pubsub payload
+	t.notifyMap.Range(func(key, value any) bool {
+		pid := key.(PollerID)
+		devices := p.UserIDToDeviceIDs[pid.UserID]
+		devices = append(devices, pid.DeviceID)
+		p.UserIDToDeviceIDs[pid.UserID] = devices
+		// delete the entry so we don't notify for it twice
+		t.notifyMap.Delete(key)
+		return true // keep enumerating
+	})
+	// notify if we have entries
+	if len(p.UserIDToDeviceIDs) > 0 {
+		t.fn(&p)
+	}
+}
+
+// Blocks, ticking until Stop() is called; returns immediately if no batch duration was set.
+func (t *DeviceDataTicker) Run() {
+	if t.ticker == nil {
+		return
+	}
+	for {
+		select {
+		case <-t.done:
+			return
+		case <-t.ticker.C:
+			t.emitUpdate()
+		}
+	}
+}
diff --git a/sync2/device_data_ticker_test.go b/sync2/device_data_ticker_test.go
new file mode 100644
index 0000000..470e537
--- /dev/null
+++ b/sync2/device_data_ticker_test.go
@@ -0,0 +1,124 @@
+package sync2
+
+import (
+	"reflect"
+	"sort"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/matrix-org/sliding-sync/pubsub"
+)
+
+func TestDeviceTickerBasic(t *testing.T) {
+	duration := time.Millisecond
+	ticker := NewDeviceDataTicker(duration)
+	var payloads []*pubsub.V2DeviceData
+	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
+		payloads = append(payloads, payload)
+	})
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		t.Log("starting the ticker")
+		ticker.Run()
+		wg.Done()
+	}()
+	time.Sleep(duration * 2) // wait until the ticker is consuming
+	t.Log("remembering a poller")
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "b",
+	})
+	time.Sleep(duration * 2)
+	if len(payloads) != 1 {
+		t.Fatalf("expected 1 callback, got %d", len(payloads))
+	}
+	want := map[string][]string{
+		"a": {"b"},
+	}
+	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
+	// check stopping works
+	payloads = []*pubsub.V2DeviceData{}
+	ticker.Stop()
+	wg.Wait()
+	time.Sleep(duration * 2)
+	if len(payloads) != 0 {
+		t.Fatalf("got extra payloads: %+v", payloads)
+	}
+}
+
+func TestDeviceTickerBatchesCorrectly(t *testing.T) {
+	duration := 100 * time.Millisecond
+	ticker := NewDeviceDataTicker(duration)
+	var payloads []*pubsub.V2DeviceData
+	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
+		payloads = append(payloads, payload)
+	})
+	go ticker.Run()
+	defer ticker.Stop()
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "b",
+	})
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "bb", // different device, same user
+	})
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "b", // dupe poller ID
+	})
+	ticker.Remember(PollerID{
+		UserID:   "x",
+		DeviceID: "y", // new device and user
+	})
+	time.Sleep(duration * 2)
+	if len(payloads) != 1 {
+		t.Fatalf("expected 1 callback, got %d", len(payloads))
+	}
+	want := map[string][]string{
+		"a": {"b", "bb"},
+		"x": {"y"},
+	}
+	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
+}
+
+func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
+	duration := time.Millisecond
+	ticker := NewDeviceDataTicker(duration)
+	var payloads []*pubsub.V2DeviceData
+	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
+		payloads = append(payloads, payload)
+	})
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "b",
+	})
+
+	go ticker.Run()
+	defer ticker.Stop()
+	ticker.Remember(PollerID{
+		UserID:   "a",
+		DeviceID: "b",
+	})
+	time.Sleep(10 * duration)
+	if len(payloads) != 1 {
+		t.Fatalf("got %d payloads, want 1", len(payloads))
+	}
+}
+
+func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
+	t.Helper()
+	if len(got) != len(want) {
+		t.Fatalf("got %+v\nwant %+v\n", got, want)
+	}
+	for userID, wantDeviceIDs := range want {
+		gotDeviceIDs := got[userID]
+		sort.Strings(wantDeviceIDs)
+		sort.Strings(gotDeviceIDs)
+		if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
+			t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
+		}
+	}
+}
diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go
index c95a276..f554300 100644
--- a/sync2/handler2/handler.go
+++ b/sync2/handler2/handler.go
@@ -4,11 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/jmoiron/sqlx"
-	"github.com/matrix-org/sliding-sync/sqlutil"
 	"hash/fnv"
 	"os"
 	"sync"
+	"time"
+
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"
 
 	"github.com/getsentry/sentry-go"
@@ -43,13 +45,15 @@ type Handler struct {
 	// room_id => fnv_hash([typing user ids])
 	typingMap map[string]uint64
 
+	deviceDataTicker *sync2.DeviceDataTicker
+
 	numPollers prometheus.Gauge
 	subSystem  string
 }
 
 func NewHandler(
 	connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
-	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
 ) (*Handler, error) {
 	h := &Handler{
 		pMap: pMap,
@@ -61,7 +65,8 @@ func NewHandler(
 			Highlight int
 			Notif     int
 		}),
-		typingMap: make(map[string]uint64),
+		typingMap:        make(map[string]uint64),
+		deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
 	}
 
 	if enablePrometheus {
@@ -86,6 +91,8 @@ func (h *Handler) Listen() {
 			sentry.CaptureException(err)
 		}
 	}()
+	h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
+	go h.deviceDataTicker.Run()
 }
 
 func (h *Handler) Teardown() {
@@ -95,6 +102,7 @@ func (h *Handler) Teardown() {
 	h.Store.Teardown()
 	h.v2Store.Teardown()
 	h.pMap.Terminate()
+	h.deviceDataTicker.Stop()
 	if h.numPollers != nil {
 		prometheus.Unregister(h.numPollers)
 	}
@@ -203,19 +211,24 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
 			New: deviceListChanges,
 		},
 	}
-	nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+	_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
 	if err != nil {
 		logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
 		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
 		return
 	}
-	h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
+	// remember this to notify on pubsub later
+	h.deviceDataTicker.Remember(sync2.PollerID{
 		UserID:   userID,
 		DeviceID: deviceID,
-		Pos:      nextPos,
 	})
 }
 
+// Called periodically by deviceDataTicker; the payload may contain many updates.
+func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
+	h.v2Pub.Notify(pubsub.ChanV2, payload)
+}
+
 func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
 	// Remember any transaction IDs that may be unique to this user
 	eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index 6717e99..68305e8 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -664,9 +664,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
 
 func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
 	ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
 	defer task.End()
-	conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
-	for _, conn := range conns {
-		conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+	internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
+	for userID, deviceIDs := range p.UserIDToDeviceIDs {
+		for _, deviceID := range deviceIDs {
+			conns := h.ConnMap.Conns(userID, deviceID)
+			for _, conn := range conns {
+				conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+			}
+		}
 	}
 }
diff --git a/v3.go b/v3.go
index a286064..e189e97 100644
--- a/v3.go
+++ b/v3.go
@@ -76,8 +76,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	store := state.NewStorage(postgresURI)
 	storev2 := sync2.NewStore(postgresURI, secret)
 	bufferSize := 50
+	deviceDataUpdateFrequency := time.Second
 	if opts.TestingSynchronousPubsub {
 		bufferSize = 0
+		deviceDataUpdateFrequency = 0 // don't batch
 	}
 	if opts.MaxPendingEventUpdates == 0 {
 		opts.MaxPendingEventUpdates = 2000
@@ -86,7 +88,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
 
 	// create v2 handler
-	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
+	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
 	if err != nil {
 		panic(err)
 	}
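A usage sketch, not part of the patch itself: the standalone program below wires up the new DeviceDataTicker the same way handler2 does, assuming the module path github.com/matrix-org/sliding-sync and only the API introduced above (NewDeviceDataTicker, SetCallback, Run, Remember, Stop); the user and device IDs are invented for illustration.

package main

import (
	"fmt"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
	"github.com/matrix-org/sliding-sync/sync2"
)

func main() {
	// Batch device-data notifications so downstream sees at most one payload per second.
	ticker := sync2.NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		// handler2.OnBulkDeviceDataUpdate forwards the payload to pubsub here.
		fmt.Printf("notifying for %d users\n", len(payload.UserIDToDeviceIDs))
	})
	go ticker.Run()
	defer ticker.Stop()

	// Duplicate (user, device) pairs collapse into a single map entry, so
	// repeated E2EE updates from one poller yield at most one notification per tick.
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})
	ticker.Remember(sync2.PollerID{UserID: "@bob:example.org", DeviceID: "DEV2"})

	time.Sleep(1500 * time.Millisecond) // let one tick flush the batch
}

Passing d=0, as Setup does when opts.TestingSynchronousPubsub is set, disables batching entirely: Remember invokes the callback synchronously and Run returns immediately.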