Rate limit pubsub.V2DeviceData updates to be at most 1 per second

The DB writes are still instant, but the notifications are now delayed
by up to 1 second, so as not to swamp the pubsub channels.
Kegan Dougal 2023-06-26 21:04:02 -07:00
parent a78612e64a
commit f36c038cf8
6 changed files with 246 additions and 14 deletions
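Roughly, the new flow is: pollers keep writing device data to the DB immediately, but instead of publishing one pubsub notification per update, they hand the (user, device) pair to a ticker which flushes at most once per second. A minimal sketch of that wiring, using the names introduced in this commit (pub stands in for any pubsub.Notifier; the authoritative wiring is in the handler2 diff below):

	// Coalesce device-data notifications: DB writes stay inline, pubsub is batched.
	ticker := sync2.NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		pub.Notify(pubsub.ChanV2, p) // fires at most once per second
	})
	go ticker.Run()
	defer ticker.Stop()

	// On each poller update: upsert to the DB first, then just remember the pair.
	// The callback above later fires once with the whole accumulated batch.
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.com", DeviceID: "DEVICE_A"})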


@@ -91,9 +91,7 @@ type V2InitialSyncComplete struct {
 func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
 
 type V2DeviceData struct {
-	UserID   string
-	DeviceID string
-	Pos      int64
+	UserIDToDeviceIDs map[string][]string
 }
 
 func (*V2DeviceData) Type() string { return "V2DeviceData" }
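The payload thus goes from one (UserID, DeviceID, Pos) triple per notification to a batch of device IDs grouped by user, and consumers no longer receive a stream position. A single notification can now carry, for example (IDs illustrative):

	payload := &pubsub.V2DeviceData{
		UserIDToDeviceIDs: map[string][]string{
			"@alice:example.com": {"DEVICE_A", "DEVICE_B"},
			"@bob:example.com":   {"DEVICE_C"},
		},
	}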


@@ -0,0 +1,90 @@
package sync2

import (
	"sync"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

// DeviceDataTicker remembers user+device IDs to notify for, then periodically
// emits them all to the caller. Use it to rate-limit the frequency of device
// list updates.
type DeviceDataTicker struct {
	// data structures to periodically notify downstream about device data updates.
	// The ticker controls the frequency of updates. The done channel is used to stop
	// ticking and clean up the goroutine. The notify map contains the values to notify for.
	ticker    *time.Ticker
	done      chan struct{}
	notifyMap *sync.Map // map of PollerID to bool, unwrapped when notifying
	fn        func(payload *pubsub.V2DeviceData)
}

// NewDeviceDataTicker creates a new device data ticker, which batches calls to
// Remember and invokes a callback every d duration. If d is 0, no batching is
// performed and the callback is invoked synchronously, which is useful for testing.
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
	ddt := &DeviceDataTicker{
		done:      make(chan struct{}),
		notifyMap: &sync.Map{},
	}
	if d != 0 {
		ddt.ticker = time.NewTicker(d)
	}
	return ddt
}

// Stop ticking.
func (t *DeviceDataTicker) Stop() {
	if t.ticker != nil {
		t.ticker.Stop()
	}
	close(t.done)
}

// SetCallback sets the function which should be called when the tick happens.
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
	t.fn = fn
}

// Remember this user/device ID, and emit it later on.
func (t *DeviceDataTicker) Remember(pid PollerID) {
	t.notifyMap.Store(pid, true)
	if t.ticker == nil {
		t.emitUpdate()
	}
}

func (t *DeviceDataTicker) emitUpdate() {
	var p pubsub.V2DeviceData
	p.UserIDToDeviceIDs = make(map[string][]string)
	// populate the pubsub payload
	t.notifyMap.Range(func(key, value any) bool {
		pid := key.(PollerID)
		devices := p.UserIDToDeviceIDs[pid.UserID]
		devices = append(devices, pid.DeviceID)
		p.UserIDToDeviceIDs[pid.UserID] = devices
		// clear the map of this value
		t.notifyMap.Delete(key)
		return true // keep enumerating
	})
	// notify if we have entries
	if len(p.UserIDToDeviceIDs) > 0 {
		t.fn(&p)
	}
}

// Run blocks forever, ticking until Stop() is called.
func (t *DeviceDataTicker) Run() {
	if t.ticker == nil {
		return
	}
	for {
		select {
		case <-t.done:
			return
		case <-t.ticker.C:
			t.emitUpdate()
		}
	}
}
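As the constructor comment notes, a zero duration disables batching entirely: Remember sees a nil ticker and emits synchronously, with no goroutine involved. A quick sketch from inside the sync2 package (IDs hypothetical):

	ticker := NewDeviceDataTicker(0) // d == 0: no time.Ticker is created
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		fmt.Println(p.UserIDToDeviceIDs) // runs inside the Remember call
	})
	ticker.Remember(PollerID{UserID: "@u:example.com", DeviceID: "D1"})
	// prints: map[@u:example.com:[D1]]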


@@ -0,0 +1,124 @@
package sync2

import (
	"reflect"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

func TestDeviceTickerBasic(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		t.Log("starting the ticker")
		ticker.Run()
		wg.Done()
	}()
	time.Sleep(duration * 2) // wait until the ticker is consuming
	t.Log("remembering a poller")
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
	// check stopping works
	payloads = []*pubsub.V2DeviceData{}
	ticker.Stop()
	wg.Wait()
	time.Sleep(duration * 2)
	if len(payloads) != 0 {
		t.Fatalf("got extra payloads: %+v", payloads)
	}
}

func TestDeviceTickerBatchesCorrectly(t *testing.T) {
	duration := 100 * time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "bb", // different device, same user
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b", // dupe poller ID
	})
	ticker.Remember(PollerID{
		UserID:   "x",
		DeviceID: "y", // new device and user
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b", "bb"},
		"x": {"y"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
}

func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(10 * duration)
	if len(payloads) != 1 {
		t.Fatalf("got %d payloads, want 1", len(payloads))
	}
}

func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("got %+v\nwant %+v\n", got, want)
	}
	for userID, wantDeviceIDs := range want {
		gotDeviceIDs := got[userID]
		sort.Strings(wantDeviceIDs)
		sort.Strings(gotDeviceIDs)
		if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
			t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
		}
	}
}


@@ -4,11 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/jmoiron/sqlx"
-	"github.com/matrix-org/sliding-sync/sqlutil"
 	"hash/fnv"
 	"os"
 	"sync"
+	"time"
+
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"
 
 	"github.com/getsentry/sentry-go"
@@ -43,13 +45,15 @@ type Handler struct {
 	// room_id => fnv_hash([typing user ids])
 	typingMap map[string]uint64
 
+	deviceDataTicker *sync2.DeviceDataTicker
+
 	numPollers prometheus.Gauge
 	subSystem  string
 }
 
 func NewHandler(
 	connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
-	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
 ) (*Handler, error) {
 	h := &Handler{
 		pMap: pMap,
@@ -61,7 +65,8 @@ func NewHandler(
 			Highlight int
 			Notif     int
 		}),
-		typingMap: make(map[string]uint64),
+		typingMap:        make(map[string]uint64),
+		deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
 	}
 
 	if enablePrometheus {
@@ -86,6 +91,8 @@ func (h *Handler) Listen() {
 			sentry.CaptureException(err)
 		}
 	}()
+	h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
+	go h.deviceDataTicker.Run()
 }
 
 func (h *Handler) Teardown() {
@@ -95,6 +102,7 @@ func (h *Handler) Teardown() {
 	h.Store.Teardown()
 	h.v2Store.Teardown()
 	h.pMap.Terminate()
+	h.deviceDataTicker.Stop()
 	if h.numPollers != nil {
 		prometheus.Unregister(h.numPollers)
 	}
@@ -203,19 +211,24 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
 			New: deviceListChanges,
 		},
 	}
-	nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+	_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
 	if err != nil {
 		logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
 		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
 		return
 	}
-	h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
+	// remember this to notify on pubsub later
+	h.deviceDataTicker.Remember(sync2.PollerID{
 		UserID:   userID,
 		DeviceID: deviceID,
-		Pos:      nextPos,
 	})
 }
 
+// OnBulkDeviceDataUpdate is called periodically by deviceDataTicker; the payload contains many updates.
+func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
+	h.v2Pub.Notify(pubsub.ChanV2, payload)
+}
+
 func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
 	// Remember any transaction IDs that may be unique to this user
 	eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order


@@ -664,9 +664,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
 func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
 	ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
 	defer task.End()
-	conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
-	for _, conn := range conns {
-		conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+	internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
+	for userID, deviceIDs := range p.UserIDToDeviceIDs {
+		for _, deviceID := range deviceIDs {
+			conns := h.ConnMap.Conns(userID, deviceID)
+			for _, conn := range conns {
+				conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+			}
+		}
 	}
 }

v3.go

@@ -76,8 +76,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	store := state.NewStorage(postgresURI)
 	storev2 := sync2.NewStore(postgresURI, secret)
 	bufferSize := 50
+	deviceDataUpdateFrequency := time.Second
 	if opts.TestingSynchronousPubsub {
 		bufferSize = 0
+		deviceDataUpdateFrequency = 0 // don't batch
 	}
 	if opts.MaxPendingEventUpdates == 0 {
 		opts.MaxPendingEventUpdates = 2000
@@ -86,7 +88,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
 	// create v2 handler
-	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
+	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
 	if err != nil {
 		panic(err)
 	}