Add function to query for old devices

This commit is contained in:
David Robertson 2023-08-08 18:48:16 +01:00
parent a3ba6cb724
commit 188e10e77b
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
2 changed files with 88 additions and 0 deletions

View File

@ -3,6 +3,7 @@ package sync2
import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"time"
)
type Device struct {
@ -45,3 +46,22 @@ func (t *DevicesTable) UpdateDeviceSince(userID, deviceID, since string) error {
_, err := t.db.Exec(`UPDATE syncv3_sync2_devices SET since = $1 WHERE user_id = $2 AND device_id = $3`, since, userID, deviceID)
return err
}
// FindOldDevices fetches the user_id and device_id of all devices which haven't /synced
// for at least as long as the given inactivityPeriod. Such devices are returned in
// no particular order.
//
// This is determined using the syncv3_sync2_tokens.last_seen column, which is updated
// at most once per day to save DB throughtput (see MaybeUpdateLastSeen). The caller
// should therefore use an inactivityPeriod of at least two days to avoid considering
// a recently-used device as old.
func (t *DevicesTable) FindOldDevices(inactivityPeriod time.Duration) (devices []Device, err error) {
err = t.db.Select(&devices, `
SELECT user_id, device_id
FROM syncv3_sync2_devices JOIN syncv3_sync2_tokens USING(user_id, device_id)
GROUP BY (user_id, device_id)
HAVING MAX(last_seen) < $1
`, time.Now().Add(-inactivityPeriod),
)
return
}

View File

@ -1,9 +1,11 @@
package sync2
import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"os"
"reflect"
"sort"
"testing"
"time"
@ -184,3 +186,69 @@ func TestTokenForEachDevice(t *testing.T) {
assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
}
}
func TestDevicesTable_FindOldDevices(t *testing.T) {
db, close := connectToDB(t)
defer close()
// HACK: discard rows inserted by other tests. We don't normally need to do this,
// but this is testing a query that scans the entire devices table.
db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
tokens := NewTokensTable(db, "my_secret")
devices := NewDevicesTable(db)
tcs := []struct {
UserID string
DeviceID string
tokenAges []time.Duration
}{
{UserID: "@alice:test", DeviceID: "no_tokens", tokenAges: nil},
{UserID: "@bob:test", DeviceID: "one_active_token", tokenAges: []time.Duration{time.Hour}},
{UserID: "@bob:test", DeviceID: "one_old_token", tokenAges: []time.Duration{7 * 24 * time.Hour}},
{UserID: "@chris:test", DeviceID: "one_old_one_active", tokenAges: []time.Duration{time.Hour, 7 * 24 * time.Hour}},
{UserID: "@delia:test", DeviceID: "two_old_tokens", tokenAges: []time.Duration{7 * 24 * time.Hour, 14 * 24 * time.Hour}},
}
txn, err := db.Beginx()
if err != nil {
t.Fatal(err)
}
numTokens := 0
for _, tc := range tcs {
err = devices.InsertDevice(txn, tc.UserID, tc.DeviceID)
if err != nil {
t.Fatal(err)
}
for _, age := range tc.tokenAges {
numTokens++
_, err = tokens.Insert(
txn,
fmt.Sprintf("token-%d", numTokens),
tc.UserID,
tc.DeviceID,
time.Now().Add(-age),
)
}
}
err = txn.Commit()
if err != nil {
t.Fatal(err)
}
oldDevices, err := devices.FindOldDevices(24 * time.Hour)
if err != nil {
t.Fatal(err)
}
sort.Slice(oldDevices, func(i, j int) bool {
return oldDevices[i].UserID < oldDevices[j].UserID
})
expectedDevices := []Device{
{UserID: "@bob:test", DeviceID: "one_old_token"},
{UserID: "@delia:test", DeviceID: "two_old_tokens"},
}
if !reflect.DeepEqual(oldDevices, expectedDevices) {
t.Errorf("Got %+v, but expected %v+", oldDevices, expectedDevices)
}
}