2021-09-20 18:09:28 +01:00
|
|
|
package sync2
|
|
|
|
|
|
|
|
import (
|
2023-08-08 18:48:16 +01:00
|
|
|
"fmt"
|
2023-04-27 19:14:35 +01:00
|
|
|
"github.com/jmoiron/sqlx"
|
2023-06-19 16:42:05 +01:00
|
|
|
"github.com/matrix-org/sliding-sync/sqlutil"
|
2021-09-20 18:09:28 +01:00
|
|
|
"os"
|
2023-08-08 18:48:16 +01:00
|
|
|
"reflect"
|
2022-07-14 10:48:45 +01:00
|
|
|
"sort"
|
2021-09-20 18:09:28 +01:00
|
|
|
"testing"
|
2023-04-28 16:10:43 +01:00
|
|
|
"time"
|
2021-09-20 18:09:28 +01:00
|
|
|
|
2022-12-15 11:08:50 +00:00
|
|
|
"github.com/matrix-org/sliding-sync/testutils"
|
2021-09-20 18:09:28 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// postgresConnectionString is the connection string connectToDB opens.
// TestMain replaces this placeholder with testutils.PrepareDBConnectionString()
// before any test runs.
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
|
|
|
|
|
|
|
|
func TestMain(m *testing.M) {
|
2021-11-09 10:15:48 +00:00
|
|
|
postgresConnectionString = testutils.PrepareDBConnectionString()
|
2021-09-20 18:09:28 +01:00
|
|
|
exitCode := m.Run()
|
|
|
|
os.Exit(exitCode)
|
|
|
|
}
|
|
|
|
|
2023-05-02 16:57:11 +01:00
|
|
|
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
|
2023-04-27 19:14:35 +01:00
|
|
|
db, err := sqlx.Open("postgres", postgresConnectionString)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("failed to open SQL db: %s", err)
|
|
|
|
}
|
2023-05-02 16:57:11 +01:00
|
|
|
return db, func() {
|
|
|
|
db.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Note that we currently only ever read from (devices JOIN tokens), so there is some
|
|
|
|
// overlap with tokens_table_test.go here.
|
|
|
|
func TestDevicesTableSinceColumn(t *testing.T) {
|
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
2023-04-28 16:10:43 +01:00
|
|
|
tokens := NewTokensTable(db, "my_secret")
|
2023-05-02 16:57:11 +01:00
|
|
|
devices := NewDevicesTable(db)
|
2023-04-27 19:14:35 +01:00
|
|
|
|
2023-04-28 16:10:43 +01:00
|
|
|
alice := "@alice:localhost"
|
|
|
|
aliceDevice := "alice_phone"
|
|
|
|
aliceSecret1 := "mysecret1"
|
2023-05-02 16:57:11 +01:00
|
|
|
aliceSecret2 := "mysecret2"
|
2023-04-28 16:10:43 +01:00
|
|
|
|
2023-06-19 16:42:05 +01:00
|
|
|
var aliceToken, aliceToken2 *Token
|
|
|
|
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
|
|
|
t.Log("Insert two tokens for Alice.")
|
|
|
|
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert token: %s", err)
|
|
|
|
}
|
|
|
|
aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert token: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
t.Log("Add a devices row for Alice")
|
|
|
|
err = devices.InsertDevice(txn, alice, aliceDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert device: %s", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
})
|
2023-04-28 16:10:43 +01:00
|
|
|
|
|
|
|
t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
|
|
|
|
accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
t.Log("The since token should be empty.")
|
|
|
|
assertEqual(t, accessToken, aliceToken.AccessToken, "Token.AccessToken mismatch")
|
|
|
|
assertEqual(t, since, "", "Device.Since mismatch")
|
|
|
|
|
|
|
|
t.Log("Update the since column.")
|
|
|
|
sinceValue := "s-1-2-3-4"
|
|
|
|
err = devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to update since column: %s", err)
|
|
|
|
}
|
|
|
|
|
2023-05-02 16:57:11 +01:00
|
|
|
t.Log("We should see the new since value when the poller refetches alice's token")
|
2023-04-28 16:10:43 +01:00
|
|
|
_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
|
|
|
}
|
|
|
|
assertEqual(t, since, sinceValue, "Device.Since mismatch")
|
|
|
|
|
2023-05-02 16:57:11 +01:00
|
|
|
t.Log("We should also see the new since value when the poller fetches alice's second token")
|
2023-04-28 16:10:43 +01:00
|
|
|
_, since, err = tokens.GetTokenAndSince(alice, aliceDevice, aliceToken2.AccessTokenHash)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to GetTokenAndSince: %s", err)
|
2021-09-20 18:09:28 +01:00
|
|
|
}
|
2023-04-28 16:10:43 +01:00
|
|
|
assertEqual(t, since, sinceValue, "Device.Since mismatch")
|
2023-05-02 16:57:11 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestTokenForEachDevice(t *testing.T) {
|
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
|
|
|
|
|
|
|
// HACK: discard rows inserted by other tests. We don't normally need to do this,
|
|
|
|
// but this is testing a query that scans the entire devices table.
|
2023-08-08 18:47:52 +01:00
|
|
|
db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
|
2023-05-02 16:57:11 +01:00
|
|
|
|
|
|
|
tokens := NewTokensTable(db, "my_secret")
|
|
|
|
devices := NewDevicesTable(db)
|
2022-07-14 10:48:45 +01:00
|
|
|
|
2023-05-02 16:57:11 +01:00
|
|
|
alice := "alice"
|
|
|
|
aliceDevice := "alice_phone"
|
2023-04-28 16:10:43 +01:00
|
|
|
bob := "bob"
|
|
|
|
bobDevice := "bob_laptop"
|
2023-05-02 16:57:11 +01:00
|
|
|
chris := "chris"
|
|
|
|
chrisDevice := "chris_desktop"
|
|
|
|
|
2023-06-19 16:42:05 +01:00
|
|
|
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
|
|
|
t.Log("Add a device for Alice, Bob and Chris.")
|
|
|
|
err := devices.InsertDevice(txn, alice, aliceDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("InsertDevice returned error: %s", err)
|
|
|
|
}
|
|
|
|
err = devices.InsertDevice(txn, bob, bobDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("InsertDevice returned error: %s", err)
|
|
|
|
}
|
|
|
|
err = devices.InsertDevice(txn, chris, chrisDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("InsertDevice returned error: %s", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
})
|
2023-04-28 16:10:43 +01:00
|
|
|
|
2023-05-02 16:57:11 +01:00
|
|
|
t.Log("Mark Alice's device with a since token.")
|
|
|
|
sinceValue := "s-1-2-3-4"
|
2023-06-19 16:42:05 +01:00
|
|
|
err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
2023-05-02 16:57:11 +01:00
|
|
|
if err != nil {
|
2023-06-19 16:42:05 +01:00
|
|
|
t.Fatalf("UpdateDeviceSince returned error: %s", err)
|
2023-05-02 16:57:11 +01:00
|
|
|
}
|
|
|
|
|
2023-06-19 16:42:05 +01:00
|
|
|
var aliceToken2, bobToken *Token
|
|
|
|
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
|
|
|
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
|
|
|
|
aliceLastSeen1 := time.Now()
|
|
|
|
_, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert token: %s", err)
|
|
|
|
}
|
|
|
|
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
|
|
|
|
aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert token: %s", err)
|
|
|
|
}
|
|
|
|
bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{})
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to Insert token: %s", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
|
2023-04-28 16:10:43 +01:00
|
|
|
t.Log("Fetch a token for every device")
|
2023-05-17 14:54:03 +01:00
|
|
|
gotTokens, err := tokens.TokenForEachDevice(nil)
|
2023-04-28 16:10:43 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed TokenForEachDevice: %s", err)
|
2022-07-14 10:48:45 +01:00
|
|
|
}
|
2023-04-28 16:10:43 +01:00
|
|
|
|
2023-05-02 17:44:07 +01:00
|
|
|
expectAlice := TokenForPoller{Token: aliceToken2, Since: sinceValue}
|
|
|
|
expectBob := TokenForPoller{Token: bobToken, Since: ""}
|
|
|
|
wantTokens := []*TokenForPoller{&expectAlice, &expectBob}
|
2023-04-28 16:10:43 +01:00
|
|
|
|
|
|
|
if len(gotTokens) != len(wantTokens) {
|
|
|
|
t.Fatalf("AllDevices: got %d tokens, want %d", len(gotTokens), len(wantTokens))
|
2022-07-14 10:48:45 +01:00
|
|
|
}
|
2023-04-28 16:10:43 +01:00
|
|
|
|
|
|
|
sort.Slice(gotTokens, func(i, j int) bool {
|
|
|
|
if gotTokens[i].UserID < gotTokens[j].UserID {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
return gotTokens[i].DeviceID < gotTokens[j].DeviceID
|
|
|
|
})
|
|
|
|
|
|
|
|
for i := range gotTokens {
|
|
|
|
assertEqual(t, gotTokens[i].Since, wantTokens[i].Since, "Device.Since mismatch")
|
|
|
|
assertEqual(t, gotTokens[i].UserID, wantTokens[i].UserID, "Token.UserID mismatch")
|
|
|
|
assertEqual(t, gotTokens[i].DeviceID, wantTokens[i].DeviceID, "Token.DeviceID mismatch")
|
|
|
|
assertEqual(t, gotTokens[i].AccessToken, wantTokens[i].AccessToken, "Token.AccessToken mismatch")
|
2022-07-14 10:48:45 +01:00
|
|
|
}
|
2021-09-20 18:09:28 +01:00
|
|
|
}
|
2023-08-08 18:48:16 +01:00
|
|
|
|
|
|
|
func TestDevicesTable_FindOldDevices(t *testing.T) {
|
|
|
|
db, close := connectToDB(t)
|
|
|
|
defer close()
|
|
|
|
|
|
|
|
// HACK: discard rows inserted by other tests. We don't normally need to do this,
|
|
|
|
// but this is testing a query that scans the entire devices table.
|
|
|
|
db.Exec("TRUNCATE syncv3_sync2_devices, syncv3_sync2_tokens;")
|
|
|
|
|
|
|
|
tokens := NewTokensTable(db, "my_secret")
|
|
|
|
devices := NewDevicesTable(db)
|
|
|
|
|
|
|
|
tcs := []struct {
|
|
|
|
UserID string
|
|
|
|
DeviceID string
|
|
|
|
tokenAges []time.Duration
|
|
|
|
}{
|
|
|
|
{UserID: "@alice:test", DeviceID: "no_tokens", tokenAges: nil},
|
|
|
|
{UserID: "@bob:test", DeviceID: "one_active_token", tokenAges: []time.Duration{time.Hour}},
|
|
|
|
{UserID: "@bob:test", DeviceID: "one_old_token", tokenAges: []time.Duration{7 * 24 * time.Hour}},
|
|
|
|
{UserID: "@chris:test", DeviceID: "one_old_one_active", tokenAges: []time.Duration{time.Hour, 7 * 24 * time.Hour}},
|
|
|
|
{UserID: "@delia:test", DeviceID: "two_old_tokens", tokenAges: []time.Duration{7 * 24 * time.Hour, 14 * 24 * time.Hour}},
|
|
|
|
}
|
|
|
|
|
|
|
|
txn, err := db.Beginx()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
numTokens := 0
|
|
|
|
for _, tc := range tcs {
|
|
|
|
err = devices.InsertDevice(txn, tc.UserID, tc.DeviceID)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
for _, age := range tc.tokenAges {
|
|
|
|
numTokens++
|
|
|
|
_, err = tokens.Insert(
|
|
|
|
txn,
|
|
|
|
fmt.Sprintf("token-%d", numTokens),
|
|
|
|
tc.UserID,
|
|
|
|
tc.DeviceID,
|
|
|
|
time.Now().Add(-age),
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
err = txn.Commit()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
oldDevices, err := devices.FindOldDevices(24 * time.Hour)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
sort.Slice(oldDevices, func(i, j int) bool {
|
|
|
|
return oldDevices[i].UserID < oldDevices[j].UserID
|
|
|
|
})
|
|
|
|
expectedDevices := []Device{
|
|
|
|
{UserID: "@bob:test", DeviceID: "one_old_token"},
|
|
|
|
{UserID: "@delia:test", DeviceID: "two_old_tokens"},
|
|
|
|
}
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(oldDevices, expectedDevices) {
|
|
|
|
t.Errorf("Got %+v, but expected %v+", oldDevices, expectedDevices)
|
|
|
|
}
|
|
|
|
}
|