Rate limit pubsub.V2DeviceData updates to be at most 1 per second

The DB writes are still instant, but the notifications are now delayed
by up to 1 second, so as not to swamp the pubsub channels.
Kegan Dougal 2023-06-26 21:04:02 -07:00
parent a78612e64a
commit f36c038cf8
6 changed files with 246 additions and 14 deletions


@@ -91,9 +91,7 @@ type V2InitialSyncComplete struct {
 func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
 
 type V2DeviceData struct {
-	UserID   string
-	DeviceID string
-	Pos      int64
+	UserIDToDeviceIDs map[string][]string
 }
 
 func (*V2DeviceData) Type() string { return "V2DeviceData" }
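
The payload now carries a whole batch keyed by user ID rather than a single (UserID, DeviceID, Pos) triple; Pos is dropped, presumably because no single stream position can describe a batch. For illustration, a coalesced payload might look like this (user and device IDs are made up):

	payload := &pubsub.V2DeviceData{
		UserIDToDeviceIDs: map[string][]string{
			"@alice:example.com": {"DEVICE_A", "DEVICE_B"}, // two updates for one user, batched
			"@bob:example.com":   {"DEVICE_C"},
		},
	}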


@@ -0,0 +1,90 @@
package sync2

import (
	"sync"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

// This struct remembers user+device IDs to notify for, then periodically
// emits them all to the caller. Use it to rate limit the frequency of device list
// updates.
type DeviceDataTicker struct {
	// data structures to periodically notify downstream about device data updates.
	// The ticker controls the frequency of updates. The done channel is used to stop ticking
	// and clean up the goroutine. The notify map contains the values to notify for.
	ticker    *time.Ticker
	done      chan struct{}
	notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
	fn        func(payload *pubsub.V2DeviceData)
}

// Create a new device data ticker, which batches calls to Remember and invokes a callback every
// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
// is useful for testing.
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
	ddt := &DeviceDataTicker{
		done:      make(chan struct{}),
		notifyMap: &sync.Map{},
	}
	if d != 0 {
		ddt.ticker = time.NewTicker(d)
	}
	return ddt
}

// Stop ticking.
func (t *DeviceDataTicker) Stop() {
	if t.ticker != nil {
		t.ticker.Stop()
	}
	close(t.done)
}

// Set the function which should be called when the tick happens.
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
	t.fn = fn
}

// Remember this user/device ID, and emit it later on.
func (t *DeviceDataTicker) Remember(pid PollerID) {
	t.notifyMap.Store(pid, true)
	if t.ticker == nil {
		t.emitUpdate()
	}
}

func (t *DeviceDataTicker) emitUpdate() {
	var p pubsub.V2DeviceData
	p.UserIDToDeviceIDs = make(map[string][]string)
	// populate the pubsub payload
	t.notifyMap.Range(func(key, value any) bool {
		pid := key.(PollerID)
		devices := p.UserIDToDeviceIDs[pid.UserID]
		devices = append(devices, pid.DeviceID)
		p.UserIDToDeviceIDs[pid.UserID] = devices
		// clear the map of this value
		t.notifyMap.Delete(key)
		return true // keep enumerating
	})
	// notify if we have entries
	if len(p.UserIDToDeviceIDs) > 0 {
		t.fn(&p)
	}
}

// Blocks forever, ticking until Stop() is called.
func (t *DeviceDataTicker) Run() {
	if t.ticker == nil {
		return
	}
	for {
		select {
		case <-t.done:
			return
		case <-t.ticker.C:
			t.emitUpdate()
		}
	}
}
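
The intended wiring, using only the API above, looks roughly like this (a sketch; the real wiring is in handler2 below, and the user/device IDs are made up):

	ticker := NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		// publish p downstream; invoked at most once per second
	})
	go ticker.Run()
	defer ticker.Stop()

	// on each device data update:
	ticker.Remember(PollerID{UserID: "@alice:example.com", DeviceID: "DEVICE_A"})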


@@ -0,0 +1,124 @@
package sync2

import (
	"reflect"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

func TestDeviceTickerBasic(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		t.Log("starting the ticker")
		ticker.Run()
		wg.Done()
	}()
	time.Sleep(duration * 2) // wait until the ticker is consuming
	t.Log("remembering a poller")
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
	// check stopping works
	payloads = []*pubsub.V2DeviceData{}
	ticker.Stop()
	wg.Wait()
	time.Sleep(duration * 2)
	if len(payloads) != 0 {
		t.Fatalf("got extra payloads: %+v", payloads)
	}
}

func TestDeviceTickerBatchesCorrectly(t *testing.T) {
	duration := 100 * time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "bb", // different device, same user
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b", // dupe poller ID
	})
	ticker.Remember(PollerID{
		UserID:   "x",
		DeviceID: "y", // new device and user
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b", "bb"},
		"x": {"y"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
}

func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(10 * duration)
	if len(payloads) != 1 {
		t.Fatalf("got %d payloads, want 1", len(payloads))
	}
}

func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("got %+v\nwant %+v\n", got, want)
	}
	for userID, wantDeviceIDs := range want {
		gotDeviceIDs := got[userID]
		sort.Strings(wantDeviceIDs)
		sort.Strings(gotDeviceIDs)
		if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
			t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
		}
	}
}
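
Assuming the test file lives in the sync2 package directory as shown, these can be run on their own with:

	go test -run TestDeviceTicker ./sync2

The longer 100ms duration in TestDeviceTickerBatchesCorrectly gives all four Remember calls time to land inside a single tick window, so they coalesce into one payload.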


@@ -4,11 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/jmoiron/sqlx"
-	"github.com/matrix-org/sliding-sync/sqlutil"
 	"hash/fnv"
 	"os"
 	"sync"
+	"time"
 
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"
+
 	"github.com/getsentry/sentry-go"
@@ -43,13 +45,15 @@ type Handler struct {
 	// room_id => fnv_hash([typing user ids])
 	typingMap map[string]uint64
 
+	deviceDataTicker *sync2.DeviceDataTicker
+
 	numPollers prometheus.Gauge
 	subSystem  string
 }
 
 func NewHandler(
 	connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
-	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
 ) (*Handler, error) {
 	h := &Handler{
 		pMap: pMap,
@@ -62,6 +66,7 @@ func NewHandler(
 			Notif int
 		}),
 		typingMap: make(map[string]uint64),
+		deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
 	}
 
 	if enablePrometheus {
@@ -86,6 +91,8 @@ func (h *Handler) Listen() {
 			sentry.CaptureException(err)
 		}
 	}()
+	h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
+	go h.deviceDataTicker.Run()
 }
 
 func (h *Handler) Teardown() {
@@ -95,6 +102,7 @@ func (h *Handler) Teardown() {
 	h.Store.Teardown()
 	h.v2Store.Teardown()
 	h.pMap.Terminate()
+	h.deviceDataTicker.Stop()
 	if h.numPollers != nil {
 		prometheus.Unregister(h.numPollers)
 	}
@@ -203,19 +211,24 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
 			New: deviceListChanges,
 		},
 	}
-	nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+	_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
 	if err != nil {
 		logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
 		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
 		return
 	}
-	h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
+	// remember this to notify on pubsub later
+	h.deviceDataTicker.Remember(sync2.PollerID{
 		UserID:   userID,
 		DeviceID: deviceID,
-		Pos:      nextPos,
 	})
 }
 
+// Called periodically by deviceDataTicker, contains many updates
+func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
+	h.v2Pub.Notify(pubsub.ChanV2, payload)
+}
+
 func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
 	// Remember any transaction IDs that may be unique to this user
 	eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
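
The resulting sequence, sketched as a timeline (times are illustrative; only the 1-second default tick matters):

	// t=0ms:   OnE2EEData -> Upsert (instant DB write) -> deviceDataTicker.Remember(pid1)
	// t=300ms: OnE2EEData -> Upsert (instant DB write) -> deviceDataTicker.Remember(pid2)
	// t=1s:    tick -> OnBulkDeviceDataUpdate(payload) -> h.v2Pub.Notify(pubsub.ChanV2, payload)
	//          (one pubsub notification covering both updates)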


@@ -664,10 +664,15 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
 func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
 	ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
 	defer task.End()
-	conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
+	internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
+	for userID, deviceIDs := range p.UserIDToDeviceIDs {
+		for _, deviceID := range deviceIDs {
+			conns := h.ConnMap.Conns(userID, deviceID)
 			for _, conn := range conns {
 				conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
 			}
+		}
+	}
 }
 
 func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
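
Each batched payload fans out on the consumer side: a payload of {"a": ["b", "bb"], "x": ["y"]} (as in the batching test above) wakes the conns for (a, b), (a, bb) and (x, y) in turn.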

v3.go

@@ -76,8 +76,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	store := state.NewStorage(postgresURI)
 	storev2 := sync2.NewStore(postgresURI, secret)
 	bufferSize := 50
+	deviceDataUpdateFrequency := time.Second
 	if opts.TestingSynchronousPubsub {
 		bufferSize = 0
+		deviceDataUpdateFrequency = 0 // don't batch
 	}
 	if opts.MaxPendingEventUpdates == 0 {
 		opts.MaxPendingEventUpdates = 2000
@@ -86,7 +88,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
 	// create v2 handler
-	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
+	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
 	if err != nil {
 		panic(err)
 	}
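
When the frequency is zero, NewDeviceDataTicker never creates a time.Ticker, Run returns immediately, and Remember invokes the callback inline; this is what the TestingSynchronousPubsub mode relies on. A minimal sketch of that path (made-up IDs):

	ticker := sync2.NewDeviceDataTicker(0) // 0 => no batching, synchronous callbacks
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		fmt.Println("notified inline:", p.UserIDToDeviceIDs)
	})
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.com", DeviceID: "DEVICE_A"})
	// prints immediately: notified inline: map[@alice:example.com:[DEVICE_A]]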