From 92a24a14c9d39f4837dc81d79ce36ab752f9e97c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 9 Jan 2023 15:37:51 +0000 Subject: [PATCH] BREAKING: add id column to account data table This will be used for optimising which events we send back. --- state/account_data.go | 13 +++-- state/account_data_test.go | 109 ++++++++++++++++++++++++++++++------- 2 files changed, 96 insertions(+), 26 deletions(-) diff --git a/state/account_data.go b/state/account_data.go index c154e1b..5ada848 100644 --- a/state/account_data.go +++ b/state/account_data.go @@ -11,6 +11,7 @@ import ( const AccountDataGlobalRoom = "" type AccountData struct { + ID int64 `db:"id"` UserID string `db:"user_id"` RoomID string `db:"room_id"` Type string `db:"type"` @@ -23,7 +24,9 @@ type AccountDataTable struct{} func NewAccountDataTable(db *sqlx.DB) *AccountDataTable { // make sure tables are made db.MustExec(` + CREATE SEQUENCE IF NOT EXISTS syncv3_account_data_seq; CREATE TABLE IF NOT EXISTS syncv3_account_data ( + id BIGINT NOT NULL DEFAULT nextval('syncv3_account_data_seq'), user_id TEXT NOT NULL, room_id TEXT NOT NULL, -- optional if global type TEXT NOT NULL, @@ -54,7 +57,7 @@ func (t *AccountDataTable) Insert(txn *sqlx.Tx, accDatas []AccountData) ([]Accou for _, chunk := range chunks { _, err := txn.NamedExec(` INSERT INTO syncv3_account_data (user_id, room_id, type, data) - VALUES (:user_id, :room_id, :type, :data) ON CONFLICT (user_id, room_id, type) DO UPDATE SET data = EXCLUDED.data`, chunk) + VALUES (:user_id, :room_id, :type, :data) ON CONFLICT (user_id, room_id, type) DO UPDATE SET data = EXCLUDED.data, id=nextval('syncv3_account_data_seq')`, chunk) if err != nil { return nil, err } @@ -63,24 +66,24 @@ func (t *AccountDataTable) Insert(txn *sqlx.Tx, accDatas []AccountData) ([]Accou } func (t *AccountDataTable) Select(txn *sqlx.Tx, userID string, eventTypes []string, roomID string) (datas []AccountData, err error) { - err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data + err = txn.Select(&datas, `SELECT id, user_id, room_id, type, data FROM syncv3_account_data WHERE user_id=$1 AND type=ANY($2) AND room_id=$3`, userID, pq.StringArray(eventTypes), roomID) return } func (t *AccountDataTable) SelectWithType(txn *sqlx.Tx, userID, evType string) (datas []AccountData, err error) { - err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data + err = txn.Select(&datas, `SELECT id, user_id, room_id, type, data FROM syncv3_account_data WHERE user_id=$1 AND type=$2 AND room_id != ''`, userID, evType) return } func (t *AccountDataTable) SelectMany(txn *sqlx.Tx, userID string, roomIDs ...string) (datas []AccountData, err error) { if len(roomIDs) == 0 { - err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data + err = txn.Select(&datas, `SELECT id, user_id, room_id, type, data FROM syncv3_account_data WHERE user_id=$1 AND room_id = $2`, userID, AccountDataGlobalRoom) return } - err = txn.Select(&datas, `SELECT user_id, room_id, type, data FROM syncv3_account_data + err = txn.Select(&datas, `SELECT id, user_id, room_id, type, data FROM syncv3_account_data WHERE user_id=$1 AND room_id=ANY($2)`, userID, pq.StringArray(roomIDs)) return } diff --git a/state/account_data_test.go b/state/account_data_test.go index de5eb33..0c325e4 100644 --- a/state/account_data_test.go +++ b/state/account_data_test.go @@ -1,7 +1,7 @@ package state import ( - "reflect" + "bytes" "sort" "testing" @@ -9,7 +9,8 @@ import ( "github.com/matrix-org/sliding-sync/sync2" ) -func accountDatasEqual(gots, wants []AccountData) bool { +func assertAccountDatasEqual(t *testing.T, msg string, gots, wants []AccountData) { + t.Helper() key := func(a AccountData) string { return a.UserID + a.RoomID + a.Type } @@ -19,7 +20,26 @@ func accountDatasEqual(gots, wants []AccountData) bool { sort.Slice(wants, func(i, j int) bool { return key(wants[i]) < key(wants[j]) }) - return reflect.DeepEqual(gots, wants) + if len(gots) != len(wants) { + t.Fatalf("%s: got %v want %v", msg, gots, wants) + } + for i := range wants { + if gots[i].RoomID != wants[i].RoomID { + t.Errorf("%s[%d]: got room id %v want %v", msg, i, gots[i].RoomID, wants[i].RoomID) + } + if gots[i].Type != wants[i].Type { + t.Errorf("%s[%d]: got type %v want %v", msg, i, gots[i].Type, wants[i].Type) + } + if gots[i].UserID != wants[i].UserID { + t.Errorf("%s[%d]: got user id %v want %v", msg, i, gots[i].UserID, wants[i].UserID) + } + if !bytes.Equal(gots[i].Data, wants[i].Data) { + t.Errorf("%s[%d]: got data %v want %v", msg, i, string(gots[i].Data), string(wants[i].Data)) + } + if wants[i].ID > 0 && gots[i].ID != wants[i].ID { + t.Errorf("%s[%d]: got id %v want %v", msg, i, gots[i].ID, wants[i].ID) + } + } } func TestAccountData(t *testing.T) { @@ -94,18 +114,14 @@ func TestAccountData(t *testing.T) { if err != nil { t.Fatalf("Select: %s", err) } - if !reflect.DeepEqual(gotData[0], accountData[len(accountData)-1]) { - t.Fatalf("Select: expected updated event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-1]) - } + assertAccountDatasEqual(t, "Select: expected updated event to be returned but wasn't", []AccountData{gotData[0]}, []AccountData{accountData[len(accountData)-1]}) // Select the global event gotData, err = table.Select(txn, alice, []string{eventType}, sync2.AccountDataGlobalRoom) if err != nil { t.Fatalf("Select: %s", err) } - if !reflect.DeepEqual(gotData[0], accountData[len(accountData)-3]) { - t.Fatalf("Select: expected global event to be returned but wasn't. Got %+v want %+v", gotData, accountData[len(accountData)-3]) - } + assertAccountDatasEqual(t, "Select: expected global event to be returned but wasn't", []AccountData{gotData[0]}, []AccountData{accountData[len(accountData)-3]}) // Select all global events for alice wantDatas := []AccountData{ @@ -115,9 +131,7 @@ func TestAccountData(t *testing.T) { if err != nil { t.Fatalf("SelectMany: %s", err) } - if !accountDatasEqual(gotDatas, wantDatas) { - t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas) - } + assertAccountDatasEqual(t, "SelectMany", gotDatas, wantDatas) // Select all room events for alice wantDatas = []AccountData{ @@ -127,9 +141,7 @@ func TestAccountData(t *testing.T) { if err != nil { t.Fatalf("SelectMany: %s", err) } - if !accountDatasEqual(gotDatas, wantDatas) { - t.Fatalf("SelectMany: got %v want %v", gotDatas, wantDatas) - } + assertAccountDatasEqual(t, "SelectMany", gotDatas, wantDatas) // Select all room events for unknown user gotDatas, err = table.SelectMany(txn, "@someone-else:localhost", roomA) @@ -148,9 +160,7 @@ func TestAccountData(t *testing.T) { wantDatas = []AccountData{ accountData[1], accountData[6], } - if !accountDatasEqual(gotDatas, wantDatas) { - t.Fatalf("SelectWithType: got %v want %v", gotDatas, wantDatas) - } + assertAccountDatasEqual(t, "SelectWithType", gotDatas, wantDatas) // Select all types in this room gotDatas, err = table.Select(txn, alice, []string{eventType, "dummy"}, roomB) @@ -160,8 +170,65 @@ func TestAccountData(t *testing.T) { wantDatas = []AccountData{ accountData[1], accountData[2], } - if !accountDatasEqual(gotDatas, wantDatas) { - t.Fatalf("Select(multi-types): got %v want %v", gotDatas, wantDatas) - } + assertAccountDatasEqual(t, "Select(multi-types)", gotDatas, wantDatas) } + +func TestAccountDataIDIncrements(t *testing.T) { + db, err := sqlx.Open("postgres", postgresConnectionString) + if err != nil { + t.Fatalf("failed to open SQL db: %s", err) + } + txn, err := db.Beginx() + if err != nil { + t.Fatalf("failed to start txn: %s", err) + } + alice := "@alice_TestAccountDataIDIncrements:localhost" + roomA := "!TestAccountData_A:localhost" + //roomB := "!TestAccountData_B:localhost" + eventType := "the_event_type" + data := AccountData{ + UserID: alice, + RoomID: roomA, + Type: eventType, + Data: []byte(`{"foo":"bar"}`), + } + table := NewAccountDataTable(db) + _, err = table.Insert(txn, []AccountData{ + data, + }) + assertNoError(t, err) + // make sure all selects return an id + gots, err := table.SelectWithType(txn, alice, eventType) + assertNoError(t, err) + assertAccountDatasEqual(t, "SelectWithType", gots, []AccountData{data}) + if gots[0].ID == 0 { + t.Fatalf("missing id field") + } + data.ID = gots[0].ID + gots, err = table.Select(txn, alice, []string{eventType}, roomA) + assertNoError(t, err) + assertAccountDatasEqual(t, "Select", gots, []AccountData{data}) + gots, err = table.SelectMany(txn, alice, roomA) + assertNoError(t, err) + assertAccountDatasEqual(t, "SelectMany", gots, []AccountData{data}) + // now replace the data, which should update the id + data.Data = []byte(`{"foo":"bar2"}`) + _, err = table.Insert(txn, []AccountData{ + data, + }) + assertNoError(t, err) + gots, err = table.Select(txn, alice, []string{eventType}, roomA) + assertNoError(t, err) + if gots[0].ID < data.ID { + t.Fatalf("id was not incremented, got %d want %d", gots[0].ID, data.ID) + } + data.ID = gots[0].ID + assertAccountDatasEqual(t, "Select", gots, []AccountData{data}) + gots, err = table.SelectMany(txn, alice, roomA) + assertNoError(t, err) + assertAccountDatasEqual(t, "SelectMany", gots, []AccountData{data}) + gots, err = table.SelectWithType(txn, alice, eventType) + assertNoError(t, err) + assertAccountDatasEqual(t, "SelectWithType", gots, []AccountData{data}) +}