Merge pull request #242 from matrix-org/dmr/purge-inactive-pollers

This commit is contained in:
David Robertson 2023-08-16 13:43:46 +01:00 committed by GitHub
commit ff7120245a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 337 additions and 7 deletions

View File

@ -3,6 +3,7 @@ package sync2
import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"time"
)
type Device struct {
@ -45,3 +46,22 @@ func (t *DevicesTable) UpdateDeviceSince(userID, deviceID, since string) error {
_, err := t.db.Exec(`UPDATE syncv3_sync2_devices SET since = $1 WHERE user_id = $2 AND device_id = $3`, since, userID, deviceID)
return err
}
// FindOldDevices returns the user_id and device_id of every device whose most
// recently seen token was last seen more than inactivityPeriod ago. The rows come
// back in no particular order.
//
// Devices are joined to syncv3_sync2_tokens, so a device with no tokens at all is
// never reported. A device with several tokens is judged by the newest last_seen
// among them.
//
// The last_seen column is only refreshed at most once per day to save DB throughput
// (see TokensTable.MaybeUpdateLastSeen). Callers should therefore pass an
// inactivityPeriod of at least two days, otherwise a recently-used device may be
// reported as old.
func (t *DevicesTable) FindOldDevices(inactivityPeriod time.Duration) (devices []Device, err error) {
	const query = `
SELECT user_id, device_id
FROM syncv3_sync2_devices JOIN syncv3_sync2_tokens USING(user_id, device_id)
GROUP BY (user_id, device_id)
HAVING MAX(last_seen) < $1
`
	// Anything last seen strictly before this instant counts as inactive.
	cutoff := time.Now().Add(-inactivityPeriod)
	err = t.db.Select(&devices, query, cutoff)
	return devices, err
}

View File

@ -1,9 +1,11 @@
package sync2
import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"os"
"reflect"
"sort"
"testing"
"time"
@ -100,7 +102,7 @@ func TestTokenForEachDevice(t *testing.T) {
// HACK: discard rows inserted by other tests. We don't normally need to do this,
// but this is testing a query that scans the entire devices table.
db.MustExec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
tokens := NewTokensTable(db, "my_secret")
devices := NewDevicesTable(db)
@ -184,3 +186,69 @@ func TestTokenForEachDevice(t *testing.T) {
assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
}
}
// TestDevicesTable_FindOldDevices inserts devices with tokens of various ages and
// checks that FindOldDevices reports only the devices whose every token was last
// seen longer ago than the requested inactivity period.
func TestDevicesTable_FindOldDevices(t *testing.T) {
	db, closeDB := connectToDB(t)
	defer closeDB()

	// HACK: discard rows inserted by other tests. We don't normally need to do this,
	// but this is testing a query that scans the entire devices table.
	// NOTE(review): the error is ignored here; presumably the tables may not exist
	// until the New*Table calls below have run -- confirm before tightening this.
	db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")

	tokens := NewTokensTable(db, "my_secret")
	devices := NewDevicesTable(db)

	tcs := []struct {
		UserID   string
		DeviceID string
		// tokenAges lists how long ago each of the device's tokens was last seen.
		tokenAges []time.Duration
	}{
		{UserID: "@alice:test", DeviceID: "no_tokens", tokenAges: nil},
		{UserID: "@bob:test", DeviceID: "one_active_token", tokenAges: []time.Duration{time.Hour}},
		{UserID: "@bob:test", DeviceID: "one_old_token", tokenAges: []time.Duration{7 * 24 * time.Hour}},
		{UserID: "@chris:test", DeviceID: "one_old_one_active", tokenAges: []time.Duration{time.Hour, 7 * 24 * time.Hour}},
		{UserID: "@delia:test", DeviceID: "two_old_tokens", tokenAges: []time.Duration{7 * 24 * time.Hour, 14 * 24 * time.Hour}},
	}

	txn, err := db.Beginx()
	if err != nil {
		t.Fatal(err)
	}
	// Make sure the transaction is released if we bail out before Commit. Once the
	// transaction has been committed this call just returns an error we can ignore.
	defer txn.Rollback()

	numTokens := 0
	for _, tc := range tcs {
		err = devices.InsertDevice(txn, tc.UserID, tc.DeviceID)
		if err != nil {
			t.Fatal(err)
		}
		for _, age := range tc.tokenAges {
			numTokens++
			_, err = tokens.Insert(
				txn,
				fmt.Sprintf("token-%d", numTokens),
				tc.UserID,
				tc.DeviceID,
				time.Now().Add(-age),
			)
			// Previously this error was silently overwritten by the next iteration.
			if err != nil {
				t.Fatal(err)
			}
		}
	}
	err = txn.Commit()
	if err != nil {
		t.Fatal(err)
	}

	oldDevices, err := devices.FindOldDevices(24 * time.Hour)
	if err != nil {
		t.Fatal(err)
	}

	// FindOldDevices makes no ordering promise, so sort before comparing. The device
	// ID tie-break keeps the order deterministic if a user ever has two old devices.
	sort.Slice(oldDevices, func(i, j int) bool {
		if oldDevices[i].UserID != oldDevices[j].UserID {
			return oldDevices[i].UserID < oldDevices[j].UserID
		}
		return oldDevices[i].DeviceID < oldDevices[j].DeviceID
	})
	expectedDevices := []Device{
		{UserID: "@bob:test", DeviceID: "one_old_token"},
		{UserID: "@delia:test", DeviceID: "two_old_tokens"},
	}

	if !reflect.DeepEqual(oldDevices, expectedDevices) {
		t.Errorf("Got %+v, but expected %+v", oldDevices, expectedDevices)
	}
}

View File

@ -48,8 +48,9 @@ type Handler struct {
typingMu *sync.Mutex
PendingTxnIDs *sync2.PendingTransactionIDs
deviceDataTicker *sync2.DeviceDataTicker
e2eeWorkerPool *internal.WorkerPool
deviceDataTicker *sync2.DeviceDataTicker
pollerExpiryTicker *time.Ticker
e2eeWorkerPool *internal.WorkerPool
numPollers prometheus.Gauge
subSystem string
@ -111,6 +112,9 @@ func (h *Handler) Teardown() {
h.v2Store.Teardown()
h.pMap.Terminate()
h.deviceDataTicker.Stop()
if h.pollerExpiryTicker != nil {
h.pollerExpiryTicker.Stop()
}
if h.numPollers != nil {
prometheus.Unregister(h.numPollers)
}
@ -163,6 +167,7 @@ func (h *Handler) StartV2Pollers() {
wg.Wait()
logger.Info().Msg("StartV2Pollers finished")
h.updateMetrics()
h.startPollerExpiryTicker()
}
func (h *Handler) updateMetrics() {
@ -567,6 +572,40 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
}()
}
// startPollerExpiryTicker starts an hourly ticker which calls ExpireOldPollers on
// every tick. Calling it again once a ticker exists is a no-op.
//
// NOTE(review): time.Ticker.Stop does not close the ticker's channel, so the
// goroutine started here keeps blocking on the channel after the ticker is stopped
// rather than returning.
func (h *Handler) startPollerExpiryTicker() {
	if h.pollerExpiryTicker == nil {
		ticker := time.NewTicker(time.Hour)
		h.pollerExpiryTicker = ticker
		go func() {
			for range ticker.C {
				h.ExpireOldPollers()
			}
		}()
	}
}
// ExpireOldPollers asks the devices table for every device that has been inactive
// for 30 days and hands the matching poller IDs to the poller map for expiry.
//
// A failure to fetch the devices is logged and reported to sentry, and nothing is
// expired. A summary line is logged only when at least one old device was found.
//
// StartV2Pollers arranges for this to run hourly, so it does not normally need to
// be called by hand; it is exported only so that tests can trigger it.
func (h *Handler) ExpireOldPollers() {
	const maxInactivity = 30 * 24 * time.Hour

	oldDevices, err := h.v2Store.DevicesTable.FindOldDevices(maxInactivity)
	if err != nil {
		logger.Err(err).Msg("Error fetching old devices")
		sentry.CaptureException(err)
		return
	}

	// Translate each (user, device) row into the key type used by the poller map.
	pollerIDs := make([]sync2.PollerID, 0, len(oldDevices))
	for _, device := range oldDevices {
		pollerIDs = append(pollerIDs, sync2.PollerID{
			UserID:   device.UserID,
			DeviceID: device.DeviceID,
		})
	}

	expired := h.pMap.ExpirePollers(pollerIDs)
	if len(oldDevices) == 0 {
		return
	}
	logger.Info().Int("old", len(oldDevices)).Int("expired", expired).Msg("poller cleanup old devices")
}
func fnvHash(event json.RawMessage) uint64 {
h := fnv.New64a()
h.Write(event)

View File

@ -48,13 +48,18 @@ func (p *mockPollerMap) DeviceIDs(userID string) []string {
return nil
}
func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
// ExpirePollers is a stub: it ignores the given poller IDs and always reports that
// zero pollers were terminated.
func (p *mockPollerMap) ExpirePollers(_ []sync2.PollerID) int {
	const numTerminated = 0
	return numTerminated
}
// EnsurePolling records the arguments of the call so that tests can assert on them
// later. The logger is not recorded. It always reports that no poller was created.
func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) bool {
	call := pollInfo{
		pid:         pid,
		accessToken: accessToken,
		v2since:     v2since,
		isStartup:   isStartup,
	}
	p.calls = append(p.calls, call)
	return false
}
func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) {

View File

@ -68,10 +68,13 @@ type V2DataReceiver interface {
}
// IPollerMap describes the operations available on a collection of pollers, each
// identified by a PollerID.
type IPollerMap interface {
EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger)
// EnsurePolling makes sure there is a poller for pid. The created result reports
// whether a new poller had to be made (true) or one already existed (false).
EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) (created bool)
NumPollers() int
Terminate()
// DeviceIDs returns device IDs for the given user.
// NOTE(review): presumably only devices which currently have a poller -- confirm.
DeviceIDs(userID string) []string
// ExpirePollers requests that the given pollers are terminated as if their access
// tokens had expired. Returns the number of pollers successfully terminated.
ExpirePollers(ids []PollerID) int
}
// PollerMap is a map of device ID to Poller
@ -217,6 +220,24 @@ func (h *PollerMap) DeviceIDs(userID string) []string {
return devices
}
// ExpirePollers terminates the poller for each of the given IDs and reports each
// one to the callbacks via OnExpiredToken, exactly as if its access token had
// expired. IDs with no poller, or whose poller has already terminated, are skipped.
// The poller mutex is held for the whole call. Returns how many pollers were
// terminated.
func (h *PollerMap) ExpirePollers(pids []PollerID) int {
	h.pollerMu.Lock()
	defer h.pollerMu.Unlock()

	expired := 0
	for i := range pids {
		poller, exists := h.Pollers[pids[i]]
		if exists && !poller.terminated.Load() {
			poller.Terminate()
			// Reporting the token as expired ensures this poller is not recreated on
			// startup. Should the device come back later, a fresh EnsurePolling call
			// will recreate the poller.
			h.callbacks.OnExpiredToken(context.Background(), hashToken(poller.accessToken), poller.userID, poller.deviceID)
			expired++
		}
	}
	return expired
}
// EnsurePolling makes sure there is a poller for this device, making one if need be.
// Blocks until at least 1 sync is done if and only if the poller was just created.
// This ensures that calls to the database will return data.
@ -224,7 +245,7 @@ func (h *PollerMap) DeviceIDs(userID string) []string {
// Note that we will immediately return if there is a poller for the same user but a different device.
// We do this to allow for logins on clients to be snappy fast, even though they won't yet have the
// to-device msgs to decrypt E2EE rooms.
func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) bool {
h.pollerMu.Lock()
if !h.executorRunning {
h.executorRunning = true
@ -240,7 +261,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
// this existing poller may not have completed the initial sync yet, so we need to make sure
// it has before we return.
poller.WaitUntilInitialSync()
return
return false
}
// check if we need to wait at all: we don't need to if this user is already syncing on a different device
// This is O(n) so we may want to map this if we get a lot of users...
@ -274,6 +295,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
} else {
logger.Info().Str("user", poller.userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
}
return true
}
func (h *PollerMap) execute() {

View File

@ -198,6 +198,72 @@ func TestPollerMapEnsurePollingIdempotent(t *testing.T) {
t.Logf("EnsurePolling unblocked")
}
// TestPollerMap_ExpirePollers starts several pollers, expires a subset of them, and
// checks that ExpirePollers reports the right count and that only the expired
// pollers need to be recreated by a subsequent EnsurePolling call.
func TestPollerMap_ExpirePollers(t *testing.T) {
	// Every sync v2 request succeeds, so no poller terminates of its own accord.
	receiver, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
		r := SyncResponse{
			NextBatch: "batchy-mc-batchface",
		}
		return &r, 200, nil
	})
	pm := NewPollerMap(client, false)
	pm.SetCallbacks(receiver)

	// Start 5 pollers.
	pollerSpecs := []struct {
		UserID   string
		DeviceID string
		Token    string
	}{
		{UserID: "alice", DeviceID: "a_device", Token: "a_token"},
		{UserID: "bob", DeviceID: "b_device2", Token: "b_token1"},
		{UserID: "bob", DeviceID: "b_device1", Token: "b_token2"},
		{UserID: "chris", DeviceID: "phone", Token: "c_token"},
		{UserID: "delia", DeviceID: "phone", Token: "d_token"},
	}

	for _, spec := range pollerSpecs {
		created := pm.EnsurePolling(
			PollerID{UserID: spec.UserID, DeviceID: spec.DeviceID},
			spec.Token, "", true, logger,
		)
		if !created {
			t.Errorf("Poller for %v was not newly created", spec)
		}
	}

	// Expire some of them. This tests that:
	toExpire := []PollerID{
		// - Easy mode: if you have one poller and ask it to be deleted, it is deleted.
		{"alice", "a_device"},
		// - If you have two devices and ask for one of their pollers to be expired,
		//   only that poller is terminated.
		{"bob", "b_device1"},
		// - If there is a device ID clash, only the specified user's poller is expired.
		//   I.e. Delia unaffected
		{"chris", "phone"},
	}
	numExpired := pm.ExpirePollers(toExpire)
	// Every requested poller exists and is still running, so each one should be
	// counted as terminated.
	if numExpired != len(toExpire) {
		t.Errorf("ExpirePollers reported %d pollers terminated, expected %d", numExpired, len(toExpire))
	}

	// Try to recreate each poller. EnsurePolling should only report having to create a
	// poller for the pollers we asked to be deleted.
	expectDeleted := []bool{
		true,
		false,
		true,
		true,
		false,
	}

	for i, spec := range pollerSpecs {
		created := pm.EnsurePolling(
			PollerID{UserID: spec.UserID, DeviceID: spec.DeviceID},
			spec.Token, "", true, logger,
		)
		if created != expectDeleted[i] {
			t.Errorf("Poller #%d (%v): created=%t, expected %t", i, spec, created, expectDeleted[i])
		}
	}
}
// Check that a call to Poll starts polling and accumulating, and terminates on 401s.
func TestPollerPollFromNothing(t *testing.T) {
nextSince := "next"

View File

@ -3,7 +3,10 @@ package syncv3
import (
"encoding/json"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"net/http"
"os"
"testing"
"time"
@ -311,3 +314,110 @@ func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) {
m.MatchRoomSubscription(roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bobLeave})),
)
}
// TestPollersCanBeResumedAfterExpiry checks the full lifecycle of an expired poller:
// a device with a stale token has its poller cleaned up by ExpireOldPollers, its
// since token stops advancing, and a later sliding sync request from that device
// brings the poller back so that new data is delivered.
func TestPollersCanBeResumedAfterExpiry(t *testing.T) {
	const (
		aliceDevice = "alice_phone"
		bobDevice   = "bob_desktop"
	)
	pqString := testutils.PrepareDBConnectionString()

	// Bring up the mock sync v2 server and register one device each for alice and bob.
	v2 := runTestV2Server(t)
	defer v2.close()
	v2.addAccountWithDeviceID(alice, aliceDevice, aliceToken)
	v2.addAccountWithDeviceID(bob, bobDevice, bobToken)

	// Give each of them one sync v2 response to consume.
	v2.queueResponse(aliceToken, sync2.SyncResponse{NextBatch: "alice_response_1"})
	v2.queueResponse(bobToken, sync2.SyncResponse{NextBatch: "bob_response_1"})

	// Seed the DB directly: Alice's token is timestamped at the Unix epoch (i.e. very
	// old), whereas Bob's is timestamped now.
	v2Store := sync2.NewStore(pqString, os.Getenv("SYNCV3_SECRET"))
	seedDevicesAndTokens := func(txn *sqlx.Tx) error {
		if err := v2Store.DevicesTable.InsertDevice(txn, alice, aliceDevice); err != nil {
			return err
		}
		if err := v2Store.DevicesTable.InsertDevice(txn, bob, bobDevice); err != nil {
			return err
		}
		if _, err := v2Store.TokensTable.Insert(txn, aliceToken, alice, aliceDevice, time.UnixMicro(0)); err != nil {
			return err
		}
		_, err := v2Store.TokensTable.Insert(txn, bobToken, bob, bobDevice, time.Now())
		return err
	}
	if err := sqlutil.WithTransaction(v2Store.DB, seedDevicesAndTokens); err != nil {
		t.Fatal(err)
	}

	t.Log("Start the v3 server and its pollers.")
	v3 := runTestServer(t, v2, pqString)
	go v3.h2.StartV2Pollers()
	defer v3.close()

	t.Log("Alice's poller should be active.")
	v2.waitUntilEmpty(t, aliceToken)
	t.Log("Bob's poller should be active.")
	v2.waitUntilEmpty(t, bobToken)

	t.Log("Manually trigger a poller cleanup.")
	v3.h2.ExpireOldPollers()

	t.Log("Queue up a sync v2 response for both Alice and Bob. Alice's response includes account data.")
	accdata := testutils.NewAccountData(t, "dummytype", map[string]any{})
	// Alice's second response is queued twice below, so build it from one place. A
	// fresh value is constructed for each call.
	aliceSecondResponse := func() sync2.SyncResponse {
		return sync2.SyncResponse{
			NextBatch: "alice_response_2",
			AccountData: sync2.EventsResponse{
				Events: []json.RawMessage{
					accdata,
				},
			},
		}
	}
	v2.queueResponse(aliceToken, aliceSecondResponse())
	v2.queueResponse(bobToken, sync2.SyncResponse{NextBatch: "bob_response_2"})

	t.Log("Wait for Bob's poller to poll")
	v2.waitUntilEmpty(t, bobToken)

	// Alice's poller has likely already made an HTTP request. But her poller should
	// have been terminated before the response was received, so the new since token
	// should not have been persisted to the DB.
	t.Log("Alice's since token in the DB should not have advanced.")
	// TODO: surprising that there isn't a function to get the since token for a device!
	var storedSince string
	if err := v2Store.DB.Get(&storedSince, `SELECT since FROM syncv3_sync2_devices WHERE user_id = $1 AND device_id = $2`, alice, aliceDevice); err != nil {
		t.Fatal(err)
	}
	if storedSince != "alice_response_1" {
		t.Errorf("Alice's sync token in DB was %s, expected alice_response_1", storedSince)
	}

	t.Log("Requeue the same response for Alice's restarted poller to consume.")
	v2.queueResponse(aliceToken, aliceSecondResponse())

	t.Log("Alice makes a new sliding sync request")
	accountDataRequest := &extensions.AccountDataRequest{
		extensions.Core{
			Enabled: &boolTrue,
		},
	}
	res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
		Extensions: extensions.Request{
			AccountData: accountDataRequest,
		},
	})

	t.Log("Alice's poller should have been polled.")
	v2.waitUntilEmpty(t, aliceToken)

	t.Log("Alice should see her account data")
	m.MatchResponse(t, res, m.MatchAccountData([]json.RawMessage{accdata}, nil))
}