Mirror of https://github.com/matrix-org/sliding-sync.git (synced 2025-03-10 13:37:11 +00:00)
Rate limit pubsub.V2DeviceData updates to be at most 1 per second
The DB writes are still instant, but the notifications are now delayed by up to 1 second so as not to swamp the pubsub channels.
This commit is contained in: parent a78612e64a, commit f36c038cf8
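In outline: rather than publishing one pubsub.V2DeviceData per poller update, the handler now remembers user/device pairs in a new sync2.DeviceDataTicker and flushes them to a single callback at most once per tick, batching all pending IDs into one payload. A minimal sketch of the pattern using the API added in this commit (the standalone main() wiring and example IDs are illustrative, not part of the diff):

package main

import (
	"fmt"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
	"github.com/matrix-org/sliding-sync/sync2"
)

func main() {
	// Flush remembered user/device IDs at most once per second.
	ticker := sync2.NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		// One payload per tick covers every device remembered since the last flush.
		fmt.Printf("notifying %d user(s)\n", len(p.UserIDToDeviceIDs))
	})
	go ticker.Run()
	defer ticker.Stop()

	// Duplicate pairs within a tick coalesce into a single map entry.
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})
	ticker.Remember(sync2.PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})
	time.Sleep(2 * time.Second) // allow one flush; prints "notifying 1 user(s)"
}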
@@ -91,9 +91,7 @@ type V2InitialSyncComplete struct {
 func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }

 type V2DeviceData struct {
-	UserID   string
-	DeviceID string
-	Pos      int64
+	UserIDToDeviceIDs map[string][]string
 }

 func (*V2DeviceData) Type() string { return "V2DeviceData" }
sync2/device_data_ticker.go (new file, 90 lines)
@@ -0,0 +1,90 @@
package sync2

import (
	"sync"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

// This struct remembers user+device IDs to notify for then periodically
// emits them all to the caller. Use to rate limit the frequency of device list
// updates.
type DeviceDataTicker struct {
	// data structures to periodically notify downstream about device data updates
	// The ticker controls the frequency of updates. The done channel is used to stop ticking
	// and clean up the goroutine. The notify map contains the values to notify for.
	ticker    *time.Ticker
	done      chan struct{}
	notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
	fn        func(payload *pubsub.V2DeviceData)
}

// Create a new device data ticker, which batches calls to Remember and invokes a callback every
// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
// is useful for testing.
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
	ddt := &DeviceDataTicker{
		done:      make(chan struct{}),
		notifyMap: &sync.Map{},
	}
	if d != 0 {
		ddt.ticker = time.NewTicker(d)
	}
	return ddt
}

// Stop ticking.
func (t *DeviceDataTicker) Stop() {
	if t.ticker != nil {
		t.ticker.Stop()
	}
	close(t.done)
}

// Set the function which should be called when the tick happens.
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
	t.fn = fn
}

// Remember this user/device ID, and emit it later on.
func (t *DeviceDataTicker) Remember(pid PollerID) {
	t.notifyMap.Store(pid, true)
	if t.ticker == nil {
		t.emitUpdate()
	}
}

func (t *DeviceDataTicker) emitUpdate() {
	var p pubsub.V2DeviceData
	p.UserIDToDeviceIDs = make(map[string][]string)
	// populate the pubsub payload
	t.notifyMap.Range(func(key, value any) bool {
		pid := key.(PollerID)
		devices := p.UserIDToDeviceIDs[pid.UserID]
		devices = append(devices, pid.DeviceID)
		p.UserIDToDeviceIDs[pid.UserID] = devices
		// clear the map of this value
		t.notifyMap.Delete(key)
		return true // keep enumerating
	})
	// notify if we have entries
	if len(p.UserIDToDeviceIDs) > 0 {
		t.fn(&p)
	}
}

// Blocks forever, ticking until Stop() is called.
func (t *DeviceDataTicker) Run() {
	if t.ticker == nil {
		return
	}
	for {
		select {
		case <-t.done:
			return
		case <-t.ticker.C:
			t.emitUpdate()
		}
	}
}
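As the constructor's doc comment notes, a zero duration disables batching entirely: Remember invokes the callback synchronously, and Run returns immediately because no ticker exists. A short continuation of the sketch above (hypothetical IDs):

	ticker := sync2.NewDeviceDataTicker(0) // no batching: flush inline
	ticker.SetCallback(func(p *pubsub.V2DeviceData) {
		fmt.Println(p.UserIDToDeviceIDs) // map[@bob:example.org:[DEV2]]
	})
	ticker.Remember(sync2.PollerID{UserID: "@bob:example.org", DeviceID: "DEV2"}) // fires immediately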
sync2/device_data_ticker_test.go (new file, 124 lines)
@@ -0,0 +1,124 @@
package sync2

import (
	"reflect"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

func TestDeviceTickerBasic(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		t.Log("starting the ticker")
		ticker.Run()
		wg.Done()
	}()
	time.Sleep(duration * 2) // wait until the ticker is consuming
	t.Log("remembering a poller")
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
	// check stopping works
	payloads = []*pubsub.V2DeviceData{}
	ticker.Stop()
	wg.Wait()
	time.Sleep(duration * 2)
	if len(payloads) != 0 {
		t.Fatalf("got extra payloads: %+v", payloads)
	}
}

func TestDeviceTickerBatchesCorrectly(t *testing.T) {
	duration := 100 * time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "bb", // different device, same user
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b", // dupe poller ID
	})
	ticker.Remember(PollerID{
		UserID:   "x",
		DeviceID: "y", // new device and user
	})
	time.Sleep(duration * 2)
	if len(payloads) != 1 {
		t.Fatalf("expected 1 callback, got %d", len(payloads))
	}
	want := map[string][]string{
		"a": {"b", "bb"},
		"x": {"y"},
	}
	assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
}

func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
	duration := time.Millisecond
	ticker := NewDeviceDataTicker(duration)
	var payloads []*pubsub.V2DeviceData
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		payloads = append(payloads, payload)
	})
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})

	go ticker.Run()
	defer ticker.Stop()
	ticker.Remember(PollerID{
		UserID:   "a",
		DeviceID: "b",
	})
	time.Sleep(10 * duration)
	if len(payloads) != 1 {
		t.Fatalf("got %d payloads, want 1", len(payloads))
	}
}

func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
	t.Helper()
	if len(got) != len(want) {
		t.Fatalf("got %+v\nwant %+v\n", got, want)
	}
	for userID, wantDeviceIDs := range want {
		gotDeviceIDs := got[userID]
		sort.Strings(wantDeviceIDs)
		sort.Strings(gotDeviceIDs)
		if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
			t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
		}
	}
}
@@ -4,11 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/jmoiron/sqlx"
-	"github.com/matrix-org/sliding-sync/sqlutil"
 	"hash/fnv"
 	"os"
 	"sync"
 	"time"
+
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"

 	"github.com/getsentry/sentry-go"
@@ -43,13 +45,15 @@ type Handler struct {
 	// room_id => fnv_hash([typing user ids])
 	typingMap map[string]uint64

+	deviceDataTicker *sync2.DeviceDataTicker
+
 	numPollers prometheus.Gauge
 	subSystem  string
 }

 func NewHandler(
 	connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
-	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+	pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
 ) (*Handler, error) {
 	h := &Handler{
 		pMap: pMap,
@@ -61,7 +65,8 @@ func NewHandler(
 			Highlight int
 			Notif     int
 		}),
-		typingMap: make(map[string]uint64),
+		typingMap:        make(map[string]uint64),
+		deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
 	}

 	if enablePrometheus {
@@ -86,6 +91,8 @@ func (h *Handler) Listen() {
 			sentry.CaptureException(err)
 		}
 	}()
+	h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
+	go h.deviceDataTicker.Run()
 }

 func (h *Handler) Teardown() {
@@ -95,6 +102,7 @@ func (h *Handler) Teardown() {
 	h.Store.Teardown()
 	h.v2Store.Teardown()
 	h.pMap.Terminate()
+	h.deviceDataTicker.Stop()
 	if h.numPollers != nil {
 		prometheus.Unregister(h.numPollers)
 	}
@@ -203,19 +211,24 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
 			New: deviceListChanges,
 		},
 	}
-	nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+	_, err := h.Store.DeviceDataTable.Upsert(&partialDD)
 	if err != nil {
 		logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
 		internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
 		return
 	}
-	h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
+	// remember this to notify on pubsub later
+	h.deviceDataTicker.Remember(sync2.PollerID{
 		UserID:   userID,
 		DeviceID: deviceID,
-		Pos:      nextPos,
 	})
 }

+// Called periodically by deviceDataTicker, contains many updates
+func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
+	h.v2Pub.Notify(pubsub.ChanV2, payload)
+}
+
 func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
 	// Remember any transaction IDs that may be unique to this user
 	eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
@@ -664,9 +664,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {

 func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
 	ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
 	defer task.End()
-	conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
-	for _, conn := range conns {
-		conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+	internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
+	for userID, deviceIDs := range p.UserIDToDeviceIDs {
+		for _, deviceID := range deviceIDs {
+			conns := h.ConnMap.Conns(userID, deviceID)
+			for _, conn := range conns {
+				conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+			}
+		}
 	}
 }
v3.go (4 lines changed)
@@ -76,8 +76,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 	store := state.NewStorage(postgresURI)
 	storev2 := sync2.NewStore(postgresURI, secret)
 	bufferSize := 50
+	deviceDataUpdateFrequency := time.Second
 	if opts.TestingSynchronousPubsub {
 		bufferSize = 0
+		deviceDataUpdateFrequency = 0 // don't batch
 	}
 	if opts.MaxPendingEventUpdates == 0 {
 		opts.MaxPendingEventUpdates = 2000
@@ -86,7 +88,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han

 	pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
 	// create v2 handler
-	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
+	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
 	if err != nil {
 		panic(err)
 	}
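End to end, then: production deployments batch device-data notifications with a one-second flush, while opts.TestingSynchronousPubsub passes a frequency of 0 so each Remember publishes inline and tests observe updates deterministically. A caller tuning the interval would simply thread a different value through the new final parameter of handler2.NewHandler shown above, e.g. (hypothetical value, same signature as the diff):

	h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, 250*time.Millisecond) // hypothetical: flush every 250ms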