Mirror of https://github.com/matrix-org/sliding-sync.git, synced 2025-03-10 13:37:11 +00:00

Merge pull request #242 from matrix-org/dmr/purge-inactive-pollers

Commit ff7120245a
@@ -3,6 +3,7 @@ package sync2
 import (
 	"github.com/jmoiron/sqlx"
 	_ "github.com/lib/pq"
+	"time"
 )
 
 type Device struct {
@@ -45,3 +46,22 @@ func (t *DevicesTable) UpdateDeviceSince(userID, deviceID, since string) error {
 	_, err := t.db.Exec(`UPDATE syncv3_sync2_devices SET since = $1 WHERE user_id = $2 AND device_id = $3`, since, userID, deviceID)
 	return err
 }
+
+// FindOldDevices fetches the user_id and device_id of all devices which haven't /synced
+// for at least as long as the given inactivityPeriod. Such devices are returned in
+// no particular order.
+//
+// This is determined using the syncv3_sync2_tokens.last_seen column, which is updated
+// at most once per day to save DB throughput (see TokensTable.MaybeUpdateLastSeen).
+// The caller should therefore use an inactivityPeriod of at least two days to avoid
+// considering a recently-used device as old.
+func (t *DevicesTable) FindOldDevices(inactivityPeriod time.Duration) (devices []Device, err error) {
+	err = t.db.Select(&devices, `
+	SELECT user_id, device_id
+	FROM syncv3_sync2_devices JOIN syncv3_sync2_tokens USING(user_id, device_id)
+	GROUP BY (user_id, device_id)
+	HAVING MAX(last_seen) < $1
+	`, time.Now().Add(-inactivityPeriod),
+	)
+	return
+}
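A usage sketch (not part of the commit): given the two-day guidance in the comment above, a periodic cleanup job might call FindOldDevices as below. The helper name and logging are illustrative assumptions; only FindOldDevices, the Device fields, and the 30-day window (used later in handler.go) come from the diff. Assumes package sync2 context plus "fmt", "log", and "time" imports.

// Hypothetical caller, assuming a *DevicesTable wired to a live Postgres DB.
// The 30-day window sits comfortably above the two-day minimum implied by
// the once-per-day last_seen updates.
func logIdleDevices(devices *DevicesTable) error {
	old, err := devices.FindOldDevices(30 * 24 * time.Hour)
	if err != nil {
		return fmt.Errorf("FindOldDevices: %w", err)
	}
	for _, d := range old {
		log.Printf("device %s/%s looks idle", d.UserID, d.DeviceID)
	}
	return nil
}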
@@ -1,9 +1,11 @@
 package sync2
 
 import (
+	"fmt"
 	"github.com/jmoiron/sqlx"
 	"github.com/matrix-org/sliding-sync/sqlutil"
 	"os"
 	"reflect"
+	"sort"
 	"testing"
 	"time"
@@ -100,7 +102,7 @@ func TestTokenForEachDevice(t *testing.T) {
 
 	// HACK: discard rows inserted by other tests. We don't normally need to do this,
 	// but this is testing a query that scans the entire devices table.
-	db.MustExec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
+	db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
 
 	tokens := NewTokensTable(db, "my_secret")
 	devices := NewDevicesTable(db)
@@ -184,3 +186,69 @@ func TestTokenForEachDevice(t *testing.T) {
 		assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
 	}
 }
+
+func TestDevicesTable_FindOldDevices(t *testing.T) {
+	db, close := connectToDB(t)
+	defer close()
+
+	// HACK: discard rows inserted by other tests. We don't normally need to do this,
+	// but this is testing a query that scans the entire devices table.
+	db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
+
+	tokens := NewTokensTable(db, "my_secret")
+	devices := NewDevicesTable(db)
+
+	tcs := []struct {
+		UserID    string
+		DeviceID  string
+		tokenAges []time.Duration
+	}{
+		{UserID: "@alice:test", DeviceID: "no_tokens", tokenAges: nil},
+		{UserID: "@bob:test", DeviceID: "one_active_token", tokenAges: []time.Duration{time.Hour}},
+		{UserID: "@bob:test", DeviceID: "one_old_token", tokenAges: []time.Duration{7 * 24 * time.Hour}},
+		{UserID: "@chris:test", DeviceID: "one_old_one_active", tokenAges: []time.Duration{time.Hour, 7 * 24 * time.Hour}},
+		{UserID: "@delia:test", DeviceID: "two_old_tokens", tokenAges: []time.Duration{7 * 24 * time.Hour, 14 * 24 * time.Hour}},
+	}
+
+	txn, err := db.Beginx()
+	if err != nil {
+		t.Fatal(err)
+	}
+	numTokens := 0
+	for _, tc := range tcs {
+		err = devices.InsertDevice(txn, tc.UserID, tc.DeviceID)
+		if err != nil {
+			t.Fatal(err)
+		}
+		for _, age := range tc.tokenAges {
+			numTokens++
+			_, err = tokens.Insert(
+				txn,
+				fmt.Sprintf("token-%d", numTokens),
+				tc.UserID,
+				tc.DeviceID,
+				time.Now().Add(-age),
+			)
+		}
+	}
+	err = txn.Commit()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	oldDevices, err := devices.FindOldDevices(24 * time.Hour)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sort.Slice(oldDevices, func(i, j int) bool {
+		return oldDevices[i].UserID < oldDevices[j].UserID
+	})
+	expectedDevices := []Device{
+		{UserID: "@bob:test", DeviceID: "one_old_token"},
+		{UserID: "@delia:test", DeviceID: "two_old_tokens"},
+	}
+
+	if !reflect.DeepEqual(oldDevices, expectedDevices) {
+		t.Errorf("Got %+v, but expected %+v", oldDevices, expectedDevices)
+	}
+}
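To make the aggregation concrete: the query joins each device to all of its tokens and keeps a device only if its newest token is stale. So in the test above, @chris:test's one_old_one_active device has a MAX(last_seen) of roughly an hour ago and is not returned, while @delia:test's two_old_tokens device was last seen seven days ago and is. @alice:test's no_tokens device produces no rows in the join at all, so it can never be reported as old.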
@@ -48,8 +48,9 @@ type Handler struct {
 	typingMu      *sync.Mutex
 	PendingTxnIDs *sync2.PendingTransactionIDs
 
-	deviceDataTicker *sync2.DeviceDataTicker
-	e2eeWorkerPool   *internal.WorkerPool
+	deviceDataTicker   *sync2.DeviceDataTicker
+	pollerExpiryTicker *time.Ticker
+	e2eeWorkerPool     *internal.WorkerPool
 
 	numPollers prometheus.Gauge
 	subSystem  string
@@ -111,6 +112,9 @@ func (h *Handler) Teardown() {
 	h.v2Store.Teardown()
 	h.pMap.Terminate()
 	h.deviceDataTicker.Stop()
+	if h.pollerExpiryTicker != nil {
+		h.pollerExpiryTicker.Stop()
+	}
 	if h.numPollers != nil {
 		prometheus.Unregister(h.numPollers)
 	}
@@ -163,6 +167,7 @@ func (h *Handler) StartV2Pollers() {
 	wg.Wait()
 	logger.Info().Msg("StartV2Pollers finished")
 	h.updateMetrics()
+	h.startPollerExpiryTicker()
 }
 
 func (h *Handler) updateMetrics() {
@@ -567,6 +572,40 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
 	}()
 }
 
+func (h *Handler) startPollerExpiryTicker() {
+	if h.pollerExpiryTicker != nil {
+		return
+	}
+	h.pollerExpiryTicker = time.NewTicker(time.Hour)
+	go func() {
+		for range h.pollerExpiryTicker.C {
+			h.ExpireOldPollers()
+		}
+	}()
+}
+
+// ExpireOldPollers looks for pollers whose devices have not made a sliding sync query
+// in the last 30 days, and asks the poller map to expire their corresponding pollers.
+// This function does not normally need to be called manually (StartV2Pollers queues it
+// up to run hourly); we expose it publicly only for testing purposes.
+func (h *Handler) ExpireOldPollers() {
+	devices, err := h.v2Store.DevicesTable.FindOldDevices(30 * 24 * time.Hour)
+	if err != nil {
+		logger.Err(err).Msg("Error fetching old devices")
+		sentry.CaptureException(err)
+		return
+	}
+	pids := make([]sync2.PollerID, len(devices))
+	for i := range devices {
+		pids[i].UserID = devices[i].UserID
+		pids[i].DeviceID = devices[i].DeviceID
+	}
+	numExpired := h.pMap.ExpirePollers(pids)
+	if len(devices) > 0 {
+		logger.Info().Int("old", len(devices)).Int("expired", numExpired).Msg("poller cleanup old devices")
+	}
+}
+
 func fnvHash(event json.RawMessage) uint64 {
 	h := fnv.New64a()
 	h.Write(event)
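One subtlety in the ticker wiring above: time.Ticker.Stop does not close the ticker's channel, so after Teardown the `for range` goroutine blocks forever rather than returning. That is harmless for a handler that lives as long as the process, but a quit channel would make the shutdown clean. A minimal sketch, assuming the Handler grew a hypothetical pollerExpiryQuit field; this is our variant, not part of the commit:

// Sketch only: clean shutdown for the hourly expiry loop.
// Assumes a hypothetical `pollerExpiryQuit chan struct{}` field on Handler,
// closed by Teardown alongside the ticker Stop.
func (h *Handler) startPollerExpiryTicker() {
	if h.pollerExpiryTicker != nil {
		return
	}
	h.pollerExpiryTicker = time.NewTicker(time.Hour)
	h.pollerExpiryQuit = make(chan struct{})
	go func() {
		for {
			select {
			case <-h.pollerExpiryTicker.C:
				h.ExpireOldPollers()
			case <-h.pollerExpiryQuit:
				return // Teardown closed the quit channel; exit the goroutine
			}
		}
	}()
}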
@@ -48,13 +48,18 @@ func (p *mockPollerMap) DeviceIDs(userID string) []string {
 	return nil
 }
 
-func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
+func (p *mockPollerMap) ExpirePollers([]sync2.PollerID) int {
+	return 0
+}
+
+func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) bool {
 	p.calls = append(p.calls, pollInfo{
 		pid:         pid,
 		accessToken: accessToken,
 		v2since:     v2since,
 		isStartup:   isStartup,
 	})
+	return false
 }
 
 func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) {
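Since mockPollerMap must keep pace with IPollerMap as methods are added (ExpirePollers here), a compile-time conformance check is a cheap guard one could add alongside the mock. A sketch, assuming the mock's remaining interface methods (NumPollers, Terminate) exist elsewhere in the file, as its use as an IPollerMap implies:

// Fails the build, rather than a test run, if the mock drifts from the interface.
var _ sync2.IPollerMap = (*mockPollerMap)(nil)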
@@ -68,10 +68,13 @@ type V2DataReceiver interface {
 }
 
 type IPollerMap interface {
-	EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger)
+	EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) (created bool)
 	NumPollers() int
 	Terminate()
 	DeviceIDs(userID string) []string
+	// ExpirePollers requests that the given pollers are terminated as if their access
+	// tokens had expired. Returns the number of pollers successfully terminated.
+	ExpirePollers(ids []PollerID) int
 }
 
 // PollerMap is a map of device ID to Poller
@@ -217,6 +220,24 @@ func (h *PollerMap) DeviceIDs(userID string) []string {
 	return devices
 }
 
+func (h *PollerMap) ExpirePollers(pids []PollerID) int {
+	h.pollerMu.Lock()
+	defer h.pollerMu.Unlock()
+	numTerminated := 0
+	for _, pid := range pids {
+		p, ok := h.Pollers[pid]
+		if !ok || p.terminated.Load() {
+			continue
+		}
+		p.Terminate()
+		// Ensure that we won't recreate this poller on startup. If it reappears later,
+		// we'll make another EnsurePolling call which will recreate the poller.
+		h.callbacks.OnExpiredToken(context.Background(), hashToken(p.accessToken), p.userID, p.deviceID)
+		numTerminated++
+	}
+	return numTerminated
+}
+
 // EnsurePolling makes sure there is a poller for this device, making one if need be.
 // Blocks until at least 1 sync is done if and only if the poller was just created.
 // This ensures that calls to the database will return data.
@@ -224,7 +245,7 @@ func (h *PollerMap) DeviceIDs(userID string) []string {
 // Note that we will immediately return if there is a poller for the same user but a different device.
 // We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
 // to-device msgs to decrypt E2EE rooms.
-func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
+func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) bool {
 	h.pollerMu.Lock()
 	if !h.executorRunning {
 		h.executorRunning = true
@@ -240,7 +261,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
 		// this existing poller may not have completed the initial sync yet, so we need to make sure
 		// it has before we return.
 		poller.WaitUntilInitialSync()
-		return
+		return false
 	}
 	// check if we need to wait at all: we don't need to if this user is already syncing on a different device
 	// This is O(n) so we may want to map this if we get a lot of users...
@@ -274,6 +295,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
 	} else {
 		logger.Info().Str("user", poller.userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
 	}
+	return true
 }
 
 func (h *PollerMap) execute() {
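Putting the new contract together: EnsurePolling's boolean reports whether a poller was created, and ExpirePollers returns how many were torn down, so an expire-then-ensure round trip behaves as in the sketch below. The token and logger values are placeholders, and this assumes code inside package sync2 with a started *PollerMap `pm`; the expected values mirror the test that follows.

// Hypothetical sequence illustrating the return values added in this commit.
pid := PollerID{UserID: "@alice:test", DeviceID: "phone"}
created := pm.EnsurePolling(pid, "token", "", false, logger) // true: poller created
created = pm.EnsurePolling(pid, "token", "", false, logger)  // false: already running
n := pm.ExpirePollers([]PollerID{pid})                       // n == 1: one poller terminated
created = pm.EnsurePolling(pid, "token", "", false, logger)  // true: recreated after expiry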
@@ -198,6 +198,72 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
 	t.Logf("EnsurePolling unblocked")
 }
 
+func TestPollerMap_ExpirePollers(t *testing.T) {
+	receiver, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
+		r := SyncResponse{
+			NextBatch: "batchy-mc-batchface",
+		}
+		return &r, 200, nil
+	})
+	pm := NewPollerMap(client, false)
+	pm.SetCallbacks(receiver)
+
+	// Start 5 pollers.
+	pollerSpecs := []struct {
+		UserID   string
+		DeviceID string
+		Token    string
+	}{
+		{UserID: "alice", DeviceID: "a_device", Token: "a_token"},
+		{UserID: "bob", DeviceID: "b_device2", Token: "b_token1"},
+		{UserID: "bob", DeviceID: "b_device1", Token: "b_token2"},
+		{UserID: "chris", DeviceID: "phone", Token: "c_token"},
+		{UserID: "delia", DeviceID: "phone", Token: "d_token"},
+	}
+
+	for _, spec := range pollerSpecs {
+		created := pm.EnsurePolling(
+			PollerID{UserID: spec.UserID, DeviceID: spec.DeviceID},
+			spec.Token, "", true, logger,
+		)
+		if !created {
+			t.Errorf("Poller for %v was not newly created", spec)
+		}
+	}
+
+	// Expire some of them. This tests that:
+	pm.ExpirePollers([]PollerID{
+		// - Easy mode: if you have one poller and ask it to be deleted, it is deleted.
+		{"alice", "a_device"},
+		// - If you have two devices and ask for one of their pollers to be expired,
+		//   only that poller is terminated.
+		{"bob", "b_device1"},
+		// - If there is a device ID clash, only the specified user's poller is expired.
+		//   I.e. Delia unaffected.
+		{"chris", "phone"},
+	})
+
+	// Try to recreate each poller. EnsurePolling should only report having to create a
+	// poller for the pollers we asked to be deleted.
+	expectDeleted := []bool{
+		true,
+		false,
+		true,
+		true,
+		false,
+	}
+
+	for i, spec := range pollerSpecs {
+		created := pm.EnsurePolling(
+			PollerID{UserID: spec.UserID, DeviceID: spec.DeviceID},
+			spec.Token, "", true, logger,
+		)
+		if created != expectDeleted[i] {
+			t.Errorf("Poller #%d (%v): created=%t, expected %t", i, spec, created, expectDeleted[i])
+		}
+	}
+}
+
 // Check that a call to Poll starts polling and accumulating, and terminates on 401s.
 func TestPollerPollFromNothing(t *testing.T) {
 	nextSince := "next"
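A small readability note on the test above: expectDeleted is a parallel slice whose indices must stay aligned with pollerSpecs by hand. Folding the expectation into the spec struct removes that coupling; a possible refactor sketch (the extra field name is ours, not from the commit):

// Each case carries its own expectation, so reordering cases stays safe.
pollerSpecs := []struct {
	UserID, DeviceID, Token string
	RecreatedAfterExpiry    bool // true for the pollers expired above
}{
	{UserID: "alice", DeviceID: "a_device", Token: "a_token", RecreatedAfterExpiry: true},
	{UserID: "bob", DeviceID: "b_device2", Token: "b_token1", RecreatedAfterExpiry: false},
	{UserID: "bob", DeviceID: "b_device1", Token: "b_token2", RecreatedAfterExpiry: true},
	{UserID: "chris", DeviceID: "phone", Token: "c_token", RecreatedAfterExpiry: true},
	{UserID: "delia", DeviceID: "phone", Token: "d_token", RecreatedAfterExpiry: false},
}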
@@ -3,7 +3,10 @@ package syncv3
 import (
 	"encoding/json"
 	"fmt"
+	"github.com/jmoiron/sqlx"
+	"github.com/matrix-org/sliding-sync/sqlutil"
 	"net/http"
+	"os"
 	"testing"
 	"time"
 
@@ -311,3 +314,110 @@ func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) {
 		m.MatchRoomSubscription(roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bobLeave})),
 	)
 }
+
+func TestPollersCanBeResumedAfterExpiry(t *testing.T) {
+	pqString := testutils.PrepareDBConnectionString()
+
+	// Start the mock sync v2 server and add a device for alice and for bob.
+	v2 := runTestV2Server(t)
+	defer v2.close()
+	const aliceDevice = "alice_phone"
+	const bobDevice = "bob_desktop"
+	v2.addAccountWithDeviceID(alice, aliceDevice, aliceToken)
+	v2.addAccountWithDeviceID(bob, bobDevice, bobToken)
+
+	// Queue up a sync v2 response for both Alice and Bob.
+	v2.queueResponse(aliceToken, sync2.SyncResponse{NextBatch: "alice_response_1"})
+	v2.queueResponse(bobToken, sync2.SyncResponse{NextBatch: "bob_response_1"})
+
+	// Inject an old token from Alice and a new token from Bob into the DB.
+	v2Store := sync2.NewStore(pqString, os.Getenv("SYNCV3_SECRET"))
+	err := sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) (err error) {
+		err = v2Store.DevicesTable.InsertDevice(txn, alice, aliceDevice)
+		if err != nil {
+			return
+		}
+		err = v2Store.DevicesTable.InsertDevice(txn, bob, bobDevice)
+		if err != nil {
+			return
+		}
+		_, err = v2Store.TokensTable.Insert(txn, aliceToken, alice, aliceDevice, time.UnixMicro(0))
+		if err != nil {
+			return
+		}
+		_, err = v2Store.TokensTable.Insert(txn, bobToken, bob, bobDevice, time.Now())
+		return
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log("Start the v3 server and its pollers.")
+	v3 := runTestServer(t, v2, pqString)
+	go v3.h2.StartV2Pollers()
+	defer v3.close()
+
+	t.Log("Alice's poller should be active.")
+	v2.waitUntilEmpty(t, aliceToken)
+	t.Log("Bob's poller should be active.")
+	v2.waitUntilEmpty(t, bobToken)
+
+	t.Log("Manually trigger a poller cleanup.")
+	v3.h2.ExpireOldPollers()
+
+	t.Log("Queue up a sync v2 response for both Alice and Bob. Alice's response includes account data.")
+	accdata := testutils.NewAccountData(t, "dummytype", map[string]any{})
+	v2.queueResponse(aliceToken, sync2.SyncResponse{
+		NextBatch: "alice_response_2",
+		AccountData: sync2.EventsResponse{
+			Events: []json.RawMessage{
+				accdata,
+			},
+		},
+	})
+	v2.queueResponse(bobToken, sync2.SyncResponse{NextBatch: "bob_response_2"})
+
+	t.Log("Wait for Bob's poller to poll")
+	v2.waitUntilEmpty(t, bobToken)
+
+	// Alice's poller has likely already made an HTTP request. But her poller should
+	// have been terminated before the response was received, so its since token
+	// should not have been persisted to the DB.
+	t.Log("Alice's since token in the DB should not have advanced.")
+	// TODO: surprising that there isn't a function to get the since token for a device!
+	var since string
+	err = v2Store.DB.Get(&since, `SELECT since FROM syncv3_sync2_devices WHERE user_id = $1 AND device_id = $2`, alice, aliceDevice)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if since != "alice_response_1" {
+		t.Errorf("Alice's sync token in DB was %s, expected alice_response_1", since)
+	}
+
+	t.Log("Requeue the same response for Alice's restarted poller to consume.")
+	v2.queueResponse(aliceToken, sync2.SyncResponse{
+		NextBatch: "alice_response_2",
+		AccountData: sync2.EventsResponse{
+			Events: []json.RawMessage{
+				accdata,
+			},
+		},
+	})
+
+	t.Log("Alice makes a new sliding sync request")
+	res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
+		Extensions: extensions.Request{
+			AccountData: &extensions.AccountDataRequest{
+				extensions.Core{
+					Enabled: &boolTrue,
+				},
+			},
+		},
+	})
+
+	t.Log("Alice's poller should have been polled.")
+	v2.waitUntilEmpty(t, aliceToken)
+
+	t.Log("Alice should see her account data")
+	m.MatchResponse(t, res, m.MatchAccountData([]json.RawMessage{accdata}, nil))
+}