Mirror of https://github.com/matrix-org/sliding-sync.git
Rate limit pubsub.V2DeviceData updates to be at most 1 per second

The DB writes are still instant, but the notifications are now delayed by up to 1 second so as not to swamp the pubsub channels.
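The change boils down to a classic coalescing pattern: writers record (user, device) keys in a shared set, and a single ticker goroutine flushes the set at most once per interval. A minimal standalone sketch of the idea (illustrative names only; the real implementation is sync2.DeviceDataTicker in the diff below):

package main

import (
    "fmt"
    "sync"
    "time"
)

type pollerKey struct{ UserID, DeviceID string }

func main() {
    var mu sync.Mutex
    pending := map[pollerKey]bool{}

    // Hot path: record the key instantly; nothing is published yet.
    remember := func(k pollerKey) {
        mu.Lock()
        pending[k] = true // duplicate updates within a tick collapse to one entry
        mu.Unlock()
    }

    // Cold path: at most one notification per second, however many writes arrived.
    go func() {
        for range time.Tick(time.Second) {
            mu.Lock()
            batch := pending
            pending = map[pollerKey]bool{}
            mu.Unlock()
            if len(batch) > 0 {
                fmt.Println("notify once for batch:", batch)
            }
        }
    }()

    remember(pollerKey{"@alice:example.org", "DEV1"})
    remember(pollerKey{"@alice:example.org", "DEV1"}) // deduped
    time.Sleep(2 * time.Second)
}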
commit f36c038cf8 (parent a78612e64a)
@@ -91,9 +91,7 @@ type V2InitialSyncComplete struct {
 func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
 
 type V2DeviceData struct {
-    UserID   string
-    DeviceID string
-    Pos      int64
+    UserIDToDeviceIDs map[string][]string
 }
 
 func (*V2DeviceData) Type() string { return "V2DeviceData" }
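The payload thus changes from one (UserID, DeviceID, Pos) triple per message to a single map covering every poller touched in the last interval. An illustrative batched payload (example values, assuming the pubsub package import):

payload := &pubsub.V2DeviceData{
    UserIDToDeviceIDs: map[string][]string{
        "@alice:example.org": {"DEV1", "DEV2"},
        "@bob:example.org":   {"PHONE"},
    },
}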
sync2/device_data_ticker.go (new file, 90 lines)
@@ -0,0 +1,90 @@
package sync2

import (
    "sync"
    "time"

    "github.com/matrix-org/sliding-sync/pubsub"
)

// This struct remembers user+device IDs to notify for, then periodically
// emits them all to the caller. Used to rate limit the frequency of device list
// updates.
type DeviceDataTicker struct {
    // data structures to periodically notify downstream about device data updates.
    // The ticker controls the frequency of updates. The done channel is used to stop ticking
    // and clean up the goroutine. The notify map contains the values to notify for.
    ticker    *time.Ticker
    done      chan struct{}
    notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
    fn        func(payload *pubsub.V2DeviceData)
}

// Create a new device data ticker, which batches calls to Remember and invokes a callback every
// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
// is useful for testing.
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
    ddt := &DeviceDataTicker{
        done:      make(chan struct{}),
        notifyMap: &sync.Map{},
    }
    if d != 0 {
        ddt.ticker = time.NewTicker(d)
    }
    return ddt
}

// Stop ticking.
func (t *DeviceDataTicker) Stop() {
    if t.ticker != nil {
        t.ticker.Stop()
    }
    close(t.done)
}

// Set the function which should be called when the tick happens.
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
    t.fn = fn
}

// Remember this user/device ID, and emit it later on.
func (t *DeviceDataTicker) Remember(pid PollerID) {
    t.notifyMap.Store(pid, true)
    if t.ticker == nil {
        t.emitUpdate()
    }
}

func (t *DeviceDataTicker) emitUpdate() {
    var p pubsub.V2DeviceData
    p.UserIDToDeviceIDs = make(map[string][]string)
    // populate the pubsub payload
    t.notifyMap.Range(func(key, value any) bool {
        pid := key.(PollerID)
        devices := p.UserIDToDeviceIDs[pid.UserID]
        devices = append(devices, pid.DeviceID)
        p.UserIDToDeviceIDs[pid.UserID] = devices
        // clear the map of this value
        t.notifyMap.Delete(key)
        return true // keep enumerating
    })
    // notify if we have entries
    if len(p.UserIDToDeviceIDs) > 0 {
        t.fn(&p)
    }
}

// Blocks forever, ticking until Stop() is called.
func (t *DeviceDataTicker) Run() {
    if t.ticker == nil {
        return
    }
    for {
        select {
        case <-t.done:
            return
        case <-t.ticker.C:
            t.emitUpdate()
        }
    }
}
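Typical wiring for this type, mirroring what the handler changes below do (a hedged fragment; imports and the callback body are illustrative):

ticker := NewDeviceDataTicker(time.Second)
ticker.SetCallback(func(p *pubsub.V2DeviceData) {
    // Invoked at most once per second, with every PollerID remembered since the last tick.
    fmt.Printf("notifying %d users\n", len(p.UserIDToDeviceIDs))
})
go ticker.Run()
defer ticker.Stop()

// Hot path: record instantly, publish later.
ticker.Remember(PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})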
sync2/device_data_ticker_test.go (new file, 124 lines)
@@ -0,0 +1,124 @@
package sync2

import (
    "reflect"
    "sort"
    "sync"
    "testing"
    "time"

    "github.com/matrix-org/sliding-sync/pubsub"
)

func TestDeviceTickerBasic(t *testing.T) {
    duration := time.Millisecond
    ticker := NewDeviceDataTicker(duration)
    var payloads []*pubsub.V2DeviceData
    ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
        payloads = append(payloads, payload)
    })
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        t.Log("starting the ticker")
        ticker.Run()
        wg.Done()
    }()
    time.Sleep(duration * 2) // wait until the ticker is consuming
    t.Log("remembering a poller")
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "b",
    })
    time.Sleep(duration * 2)
    if len(payloads) != 1 {
        t.Fatalf("expected 1 callback, got %d", len(payloads))
    }
    want := map[string][]string{
        "a": {"b"},
    }
    assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
    // check stopping works
    payloads = []*pubsub.V2DeviceData{}
    ticker.Stop()
    wg.Wait()
    time.Sleep(duration * 2)
    if len(payloads) != 0 {
        t.Fatalf("got extra payloads: %+v", payloads)
    }
}

func TestDeviceTickerBatchesCorrectly(t *testing.T) {
    duration := 100 * time.Millisecond
    ticker := NewDeviceDataTicker(duration)
    var payloads []*pubsub.V2DeviceData
    ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
        payloads = append(payloads, payload)
    })
    go ticker.Run()
    defer ticker.Stop()
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "b",
    })
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "bb", // different device, same user
    })
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "b", // dupe poller ID
    })
    ticker.Remember(PollerID{
        UserID:   "x",
        DeviceID: "y", // new device and user
    })
    time.Sleep(duration * 2)
    if len(payloads) != 1 {
        t.Fatalf("expected 1 callback, got %d", len(payloads))
    }
    want := map[string][]string{
        "a": {"b", "bb"},
        "x": {"y"},
    }
    assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
}

func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
    duration := time.Millisecond
    ticker := NewDeviceDataTicker(duration)
    var payloads []*pubsub.V2DeviceData
    ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
        payloads = append(payloads, payload)
    })
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "b",
    })

    go ticker.Run()
    defer ticker.Stop()
    ticker.Remember(PollerID{
        UserID:   "a",
        DeviceID: "b",
    })
    time.Sleep(10 * duration)
    if len(payloads) != 1 {
        t.Fatalf("got %d payloads, want 1", len(payloads))
    }
}

func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
    t.Helper()
    if len(got) != len(want) {
        t.Fatalf("got %+v\nwant %+v\n", got, want)
    }
    for userID, wantDeviceIDs := range want {
        gotDeviceIDs := got[userID]
        sort.Strings(wantDeviceIDs)
        sort.Strings(gotDeviceIDs)
        if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
            t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
        }
    }
}
@@ -4,11 +4,13 @@ import (
     "context"
     "encoding/json"
     "fmt"
-    "github.com/jmoiron/sqlx"
-    "github.com/matrix-org/sliding-sync/sqlutil"
     "hash/fnv"
     "os"
     "sync"
+    "time"
+
+    "github.com/jmoiron/sqlx"
+    "github.com/matrix-org/sliding-sync/sqlutil"
 
     "github.com/getsentry/sentry-go"
 
@@ -43,13 +45,15 @@ type Handler struct {
     // room_id => fnv_hash([typing user ids])
     typingMap map[string]uint64
 
+    deviceDataTicker *sync2.DeviceDataTicker
+
     numPollers prometheus.Gauge
     subSystem  string
 }
 
 func NewHandler(
     connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
-    pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
+    pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
 ) (*Handler, error) {
     h := &Handler{
         pMap: pMap,
@@ -61,7 +65,8 @@ func NewHandler(
             Highlight int
             Notif     int
         }),
         typingMap: make(map[string]uint64),
+        deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
     }
 
     if enablePrometheus {
@@ -86,6 +91,8 @@ func (h *Handler) Listen() {
             sentry.CaptureException(err)
         }
     }()
+    h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
+    go h.deviceDataTicker.Run()
 }
 
 func (h *Handler) Teardown() {
@@ -95,6 +102,7 @@ func (h *Handler) Teardown() {
     h.Store.Teardown()
     h.v2Store.Teardown()
     h.pMap.Terminate()
+    h.deviceDataTicker.Stop()
     if h.numPollers != nil {
         prometheus.Unregister(h.numPollers)
     }
@@ -203,19 +211,24 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
             New: deviceListChanges,
         },
     }
-    nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
+    _, err := h.Store.DeviceDataTable.Upsert(&partialDD)
     if err != nil {
         logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
         internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
         return
     }
-    h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
+    // remember this to notify on pubsub later
+    h.deviceDataTicker.Remember(sync2.PollerID{
         UserID:   userID,
         DeviceID: deviceID,
-        Pos:      nextPos,
     })
 }
 
+// Called periodically by deviceDataTicker, contains many updates
+func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
+    h.v2Pub.Notify(pubsub.ChanV2, payload)
+}
+
 func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
     // Remember any transaction IDs that may be unique to this user
     eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
@@ -664,9 +664,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
 func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
     ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
     defer task.End()
-    conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
-    for _, conn := range conns {
-        conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+    internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
+    for userID, deviceIDs := range p.UserIDToDeviceIDs {
+        for _, deviceID := range deviceIDs {
+            conns := h.ConnMap.Conns(userID, deviceID)
+            for _, conn := range conns {
+                conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
+            }
+        }
     }
 }
 
v3.go (4 changed lines)
@@ -76,8 +76,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
     store := state.NewStorage(postgresURI)
     storev2 := sync2.NewStore(postgresURI, secret)
     bufferSize := 50
+    deviceDataUpdateFrequency := time.Second
     if opts.TestingSynchronousPubsub {
         bufferSize = 0
+        deviceDataUpdateFrequency = 0 // don't batch
     }
     if opts.MaxPendingEventUpdates == 0 {
         opts.MaxPendingEventUpdates = 2000
@@ -86,7 +88,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
 
     pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
     // create v2 handler
-    h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
+    h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
     if err != nil {
         panic(err)
     }
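One detail worth noting from the wiring above: with deviceDataUpdateFrequency set to 0 (the TestingSynchronousPubsub path), NewDeviceDataTicker never creates a time.Ticker, so Remember invokes the callback inline and nothing is batched. A hedged fragment showing that path (illustrative values, imports elided):

ddt := sync2.NewDeviceDataTicker(0) // zero duration: no batching
ddt.SetCallback(func(p *pubsub.V2DeviceData) {
    fmt.Println("emitted synchronously:", p.UserIDToDeviceIDs)
})
// No Run() loop is needed; Remember emits immediately when there is no ticker.
ddt.Remember(sync2.PollerID{UserID: "@alice:example.org", DeviceID: "DEV1"})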