From a7eed937228cf6376da82de98dda9e4bf67c52ba Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 18 Jan 2023 14:54:26 +0000 Subject: [PATCH] Add comprehensive regression test for GlobalSnapshot(); ensure we clear db conns when tests end --- state/account_data_test.go | 13 +-- state/accumulator_test.go | 53 +++------ state/device_data_table_test.go | 17 +-- state/event_table_test.go | 199 +++++++++++++++----------------- state/invites_table_test.go | 8 +- state/main_test.go | 11 ++ state/receipt_table_test.go | 7 +- state/rooms_table_test.go | 8 +- state/snapshot_table_test.go | 7 +- state/spaces_table_test.go | 14 +-- state/storage_test.go | 145 +++++++++++++++++++++++ state/to_device_table_test.go | 30 ++--- state/txn_table_test.go | 8 +- state/typing_table_test.go | 8 +- state/unread_table_test.go | 8 +- 15 files changed, 306 insertions(+), 230 deletions(-) diff --git a/state/account_data_test.go b/state/account_data_test.go index 0c325e4..d0af472 100644 --- a/state/account_data_test.go +++ b/state/account_data_test.go @@ -5,7 +5,6 @@ import ( "sort" "testing" - "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sync2" ) @@ -43,10 +42,8 @@ func assertAccountDatasEqual(t *testing.T, msg string, gots, wants []AccountData } func TestAccountData(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -175,10 +172,8 @@ func TestAccountData(t *testing.T) { } func TestAccountDataIDIncrements(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) diff --git a/state/accumulator_test.go b/state/accumulator_test.go index e1bb476..5338646 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -7,7 +7,6 @@ import ( "sort" "testing" - "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/testutils" "github.com/tidwall/gjson" @@ -21,10 +20,8 @@ func TestAccumulatorInitialise(t *testing.T) { []byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), } roomEventIDs := []string{"A", "B", "C"} - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) added, initSnapID, err := accumulator.Initialise(roomID, roomEvents) if err != nil { @@ -99,12 +96,10 @@ func TestAccumulatorAccumulate(t *testing.T) { []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), []byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), } - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) - _, _, err = accumulator.Initialise(roomID, roomEvents) + _, _, err := accumulator.Initialise(roomID, roomEvents) if err != nil { t.Fatalf("failed to Initialise accumulator: %s", err) } @@ -197,12 +192,10 @@ func TestAccumulatorAccumulate(t *testing.T) { func TestAccumulatorDelta(t *testing.T) { roomID := "!TestAccumulatorDelta:localhost" - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) - _, _, err = accumulator.Initialise(roomID, nil) + _, _, err := accumulator.Initialise(roomID, nil) if err != nil { t.Fatalf("failed to Initialise accumulator: %s", err) } @@ -248,12 +241,10 @@ func TestAccumulatorDelta(t *testing.T) { func TestAccumulatorMembershipLogs(t *testing.T) { roomID := "!TestAccumulatorMembershipLogs:localhost" - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) - _, _, err = accumulator.Initialise(roomID, nil) + _, _, err := accumulator.Initialise(roomID, nil) if err != nil { t.Fatalf("failed to Initialise accumulator: %s", err) } @@ -389,13 +380,11 @@ func TestAccumulatorDupeEvents(t *testing.T) { t.Fatalf("failed to unmarshal: %s", err) } - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) roomID := "!buggy:localhost" - _, _, err = accumulator.Initialise(roomID, joinRoom.State.Events) + _, _, err := accumulator.Initialise(roomID, joinRoom.State.Events) if err != nil { t.Fatalf("failed to Initialise accumulator: %s", err) } @@ -432,14 +421,12 @@ func TestAccumulatorMisorderedGraceful(t *testing.T) { ) t.Logf("A=member-alice, B=msg, C=create, D=member-bob") - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() accumulator := NewAccumulator(db) roomID := "!TestAccumulatorStateReset:localhost" // Create a room with initial state A,C - _, _, err = accumulator.Initialise(roomID, []json.RawMessage{ + _, _, err := accumulator.Initialise(roomID, []json.RawMessage{ eventA, eventC, }) if err != nil { @@ -509,10 +496,8 @@ func TestCalculateNewSnapshotDupe(t *testing.T) { t.Errorf("assertNIDsEqual: got %v want %v", a, b) } } - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() testCases := []struct { input StrippedEvents inputEvent Event diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index 60c3435..c49c4bd 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -4,7 +4,6 @@ import ( "reflect" "testing" - "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/internal" ) @@ -25,10 +24,8 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) { } func TestDeviceDataTableSwaps(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewDeviceDataTable(db) userID := "@bob" deviceID := "BOB" @@ -67,7 +64,7 @@ func TestDeviceDataTableSwaps(t *testing.T) { }, } for _, dd := range deltas { - _, err = table.Upsert(&dd) + _, err := table.Upsert(&dd) assertNoError(t, err) } @@ -173,10 +170,8 @@ func TestDeviceDataTableSwaps(t *testing.T) { } func TestDeviceDataTableBitset(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewDeviceDataTable(db) userID := "@bobTestDeviceDataTableBitset" deviceID := "BOBTestDeviceDataTableBitset" @@ -202,7 +197,7 @@ func TestDeviceDataTableBitset(t *testing.T) { }, } - _, err = table.Upsert(&otkUpdate) + _, err := table.Upsert(&otkUpdate) assertNoError(t, err) got, err := table.Select(userID, deviceID, true) assertNoError(t, err) diff --git a/state/event_table_test.go b/state/event_table_test.go index 9a0781a..5e7c7e2 100644 --- a/state/event_table_test.go +++ b/state/event_table_test.go @@ -6,7 +6,6 @@ import ( "fmt" "testing" - "github.com/jmoiron/sqlx" "github.com/tidwall/gjson" "github.com/matrix-org/sliding-sync/sqlutil" @@ -14,10 +13,8 @@ import ( ) func TestEventTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -199,10 +196,8 @@ func TestEventTable(t *testing.T) { } func TestEventTableNullValue(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -241,10 +236,8 @@ func TestEventTableNullValue(t *testing.T) { } func TestEventTableDupeInsert(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() // first insert txn, err := db.Beginx() if err != nil { @@ -303,10 +296,8 @@ func TestEventTableDupeInsert(t *testing.T) { } func TestEventTableSelectEventsBetween(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -347,85 +338,85 @@ func TestEventTableSelectEventsBetween(t *testing.T) { } txn.Commit() - t.Run("selecting multiple events known lower bound", func(t *testing.T) { - t.Parallel() - txn2, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn2.Rollback() - events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]}) - if err != nil || len(events) == 0 { - t.Fatalf("failed to extract event for lower bound: %s", err) - } - events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // 3 as 1 is from a different room - if len(events) != 3 { - t.Fatalf("wanted 3 events, got %d", len(events)) - } - }) - t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) { - t.Parallel() - txn3, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn3.Rollback() - events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]}) - if err != nil || len(events) == 0 { - t.Fatalf("failed to extract event for lower/upper bound: %s", err) - } - events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // eventIDs[1] and eventIDs[2] - if len(events) != 2 { - t.Fatalf("wanted 2 events, got %d", len(events)) - } - }) - t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) { - t.Parallel() - txn4, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn4.Rollback() - gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // one less as one event is for a different room - if len(gotEvents) != (len(events) - 1) { - t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents)) - } - }) - t.Run("selecting multiple events hitting the limit", func(t *testing.T) { - t.Parallel() - txn5, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn5.Rollback() - limit := 2 - gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - if len(gotEvents) != limit { - t.Fatalf("wanted %d events, got %d", limit, len(gotEvents)) - } + t.Run("subgroup", func(t *testing.T) { + t.Run("selecting multiple events known lower bound", func(t *testing.T) { + t.Parallel() + txn2, err := db.Beginx() + if err != nil { + t.Fatalf("failed to start txn: %s", err) + } + defer txn2.Rollback() + events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]}) + if err != nil || len(events) == 0 { + t.Fatalf("failed to extract event for lower bound: %s", err) + } + events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000) + if err != nil { + t.Fatalf("failed to SelectEventsBetween: %s", err) + } + // 3 as 1 is from a different room + if len(events) != 3 { + t.Fatalf("wanted 3 events, got %d", len(events)) + } + }) + t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) { + t.Parallel() + txn3, err := db.Beginx() + if err != nil { + t.Fatalf("failed to start txn: %s", err) + } + defer txn3.Rollback() + events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]}) + if err != nil || len(events) == 0 { + t.Fatalf("failed to extract event for lower/upper bound: %s", err) + } + events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000) + if err != nil { + t.Fatalf("failed to SelectEventsBetween: %s", err) + } + // eventIDs[1] and eventIDs[2] + if len(events) != 2 { + t.Fatalf("wanted 2 events, got %d", len(events)) + } + }) + t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) { + t.Parallel() + txn4, err := db.Beginx() + if err != nil { + t.Fatalf("failed to start txn: %s", err) + } + defer txn4.Rollback() + gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000) + if err != nil { + t.Fatalf("failed to SelectEventsBetween: %s", err) + } + // one less as one event is for a different room + if len(gotEvents) != (len(events) - 1) { + t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents)) + } + }) + t.Run("selecting multiple events hitting the limit", func(t *testing.T) { + t.Parallel() + txn5, err := db.Beginx() + if err != nil { + t.Fatalf("failed to start txn: %s", err) + } + defer txn5.Rollback() + limit := 2 + gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit) + if err != nil { + t.Fatalf("failed to SelectEventsBetween: %s", err) + } + if len(gotEvents) != limit { + t.Fatalf("wanted %d events, got %d", limit, len(gotEvents)) + } + }) }) } func TestEventTableMembershipDetection(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -546,10 +537,8 @@ func TestChunkify(t *testing.T) { } func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -645,10 +634,8 @@ func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) { // Do a massive insert/select for event IDs (greater than postgres limit) and ensure it works. func TestTortureEventTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -699,10 +686,8 @@ func TestTortureEventTable(t *testing.T) { // 3: SelectClosestPrevBatch with an event without a prev_batch returns the next newest (stream order) event with a prev_batch // 4: SelectClosestPrevBatch with an event without a prev_batch returns nothing if there are no newer events with a prev_batch func TestEventTablePrevBatch(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -827,10 +812,8 @@ func TestEventTablePrevBatch(t *testing.T) { } func TestRemoveUnsignedTXNID(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) diff --git a/state/invites_table_test.go b/state/invites_table_test.go index a81428e..92c0077 100644 --- a/state/invites_table_test.go +++ b/state/invites_table_test.go @@ -4,15 +4,11 @@ import ( "encoding/json" "reflect" "testing" - - "github.com/jmoiron/sqlx" ) func TestInviteTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewInvitesTable(db) alice := "@alice:localhost" bob := "@bob:localhost" diff --git a/state/main_test.go b/state/main_test.go index 4727bbe..30938c6 100644 --- a/state/main_test.go +++ b/state/main_test.go @@ -4,6 +4,7 @@ import ( "os" "testing" + "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/testutils" ) @@ -14,3 +15,13 @@ func TestMain(m *testing.M) { exitCode := m.Run() os.Exit(exitCode) } + +func connectToDB(t *testing.T) (*sqlx.DB, func()) { + db, err := sqlx.Open("postgres", postgresConnectionString) + if err != nil { + t.Fatalf("failed to open SQL db: %s", err) + } + return db, func() { + db.Close() + } +} diff --git a/state/receipt_table_test.go b/state/receipt_table_test.go index e9d8ab6..c5234c9 100644 --- a/state/receipt_table_test.go +++ b/state/receipt_table_test.go @@ -6,7 +6,6 @@ import ( "sort" "testing" - "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/internal" ) @@ -33,10 +32,8 @@ func parsedReceiptsEqual(t *testing.T, got, want []internal.Receipt) { } func TestReceiptTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() roomA := "!A:ReceiptTable" roomB := "!B:ReceiptTable" edu := json.RawMessage(`{ diff --git a/state/rooms_table_test.go b/state/rooms_table_test.go index 167d076..88acf04 100644 --- a/state/rooms_table_test.go +++ b/state/rooms_table_test.go @@ -7,11 +7,9 @@ import ( ) func TestRoomsTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } - _, err = db.Exec(`DROP TABLE IF EXISTS syncv3_rooms`) + db, close := connectToDB(t) + defer close() + _, err := db.Exec(`DROP TABLE IF EXISTS syncv3_rooms`) if err != nil { t.Fatalf("failed to drop rooms table: %s", err) } diff --git a/state/snapshot_table_test.go b/state/snapshot_table_test.go index ea13b79a..884551a 100644 --- a/state/snapshot_table_test.go +++ b/state/snapshot_table_test.go @@ -4,15 +4,12 @@ import ( "reflect" "testing" - "github.com/jmoiron/sqlx" "github.com/lib/pq" ) func TestSnapshotTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) diff --git a/state/spaces_table_test.go b/state/spaces_table_test.go index cab1b48..50801f6 100644 --- a/state/spaces_table_test.go +++ b/state/spaces_table_test.go @@ -6,8 +6,6 @@ import ( "reflect" "sort" "testing" - - "github.com/jmoiron/sqlx" ) func matchAnyOrder(t *testing.T, gots, wants []SpaceRelation) { @@ -35,10 +33,8 @@ func noError(t *testing.T, err error) { } func TestSpacesTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) @@ -220,10 +216,8 @@ func TestNewSpaceRelationFromEvent(t *testing.T) { } func TestHandleSpaceUpdates(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() txn, err := db.Beginx() if err != nil { t.Fatalf("failed to start txn: %s", err) diff --git a/state/storage_test.go b/state/storage_test.go index 64c781e..8363cdd 100644 --- a/state/storage_test.go +++ b/state/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "reflect" + "sort" "testing" "time" @@ -19,6 +20,7 @@ import ( func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) { ctx := context.Background() store := NewStorage(postgresConnectionString) + defer store.Teardown() roomID := "!TestStorageRoomStateAfterEventPosition:localhost" alice := "@alice:localhost" bob := "@bob:localhost" @@ -112,6 +114,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) { func TestStorageJoinedRoomsAfterPosition(t *testing.T) { store := NewStorage(postgresConnectionString) + defer store.Teardown() joinedRoomID := "!joined:bar" invitedRoomID := "!invited:bar" leftRoomID := "!left:bar" @@ -245,6 +248,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) { // Test the examples on VisibleEventNIDsBetween docs func TestVisibleEventNIDsBetween(t *testing.T) { store := NewStorage(postgresConnectionString) + defer store.Teardown() roomA := "!a:localhost" roomB := "!b:localhost" roomC := "!c:localhost" @@ -474,6 +478,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) { func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) { store := NewStorage(postgresConnectionString) + defer store.Teardown() roomID := "!joined:bar" alice := "@alice_TestStorageLatestEventsInRoomsPrevBatch:localhost" stateEvents := []json.RawMessage{ @@ -557,6 +562,146 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) { } } +func TestGlobalSnapshot(t *testing.T) { + alice := "@TestGlobalSnapshot_alice:localhost" + bob := "@TestGlobalSnapshot_bob:localhost" + roomAlice := "!alice" + roomBob := "!bob" + roomAliceBob := "!alicebob" + roomSpace := "!space" + oldRoomID := "!old" + newRoomID := "!new" + roomType := "room_type_here" + spaceRoomType := "m.space" + roomIDToEventMap := map[string][]json.RawMessage{ + roomAlice: { + testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{"creator": alice, "predecessor": map[string]string{ + "room_id": oldRoomID, + "event_id": "$something", + }}), + testutils.NewJoinEvent(t, alice), + testutils.NewStateEvent(t, "m.room.encryption", "", alice, map[string]interface{}{"algorithm": "m.megolm.v1.aes-sha2"}), + }, + roomBob: { + testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob, "type": roomType}), + testutils.NewJoinEvent(t, bob), + testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "My Room"}), + }, + roomAliceBob: { + testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob}), + testutils.NewJoinEvent(t, bob), + testutils.NewJoinEvent(t, alice), + testutils.NewStateEvent(t, "m.room.canonical_alias", "", alice, map[string]interface{}{"alias": "#alias"}), + testutils.NewStateEvent(t, "m.room.tombstone", "", alice, map[string]interface{}{"replacement_room": newRoomID, "body": "yep"}), + }, + roomSpace: { + testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob, "type": spaceRoomType}), + testutils.NewJoinEvent(t, bob), + testutils.NewStateEvent(t, "m.space.child", newRoomID, bob, map[string]interface{}{"via": []string{"somewhere"}}), + testutils.NewStateEvent(t, "m.space.child", "!no_via", bob, map[string]interface{}{}), + testutils.NewStateEvent(t, "m.room.member", alice, bob, map[string]interface{}{"membership": "invite"}), + }, + } + // make a fresh DB which is unpolluted from other tests + db, close := connectToDB(t) + _, err := db.Exec(` + DROP TABLE IF EXISTS syncv3_rooms; + DROP TABLE IF EXISTS syncv3_invites; + DROP TABLE IF EXISTS syncv3_snapshots; + DROP TABLE IF EXISTS syncv3_spaces;`) + if err != nil { + t.Fatalf("failed to wipe DB: %s", err) + } + close() + store := NewStorage(postgresConnectionString) + defer store.Teardown() + for roomID, stateEvents := range roomIDToEventMap { + _, _, err := store.Initialise(roomID, stateEvents) + assertNoError(t, err) + } + snapshot, err := store.GlobalSnapshot() + assertNoError(t, err) + wantJoinedMembers := map[string][]string{ + roomAlice: {alice}, + roomBob: {bob}, + roomAliceBob: {bob, alice}, // user IDs are ordered by event nid, and bob joined first so he is first + roomSpace: {bob}, + } + if !reflect.DeepEqual(snapshot.AllJoinedMembers, wantJoinedMembers) { + t.Errorf("Snapshot.AllJoinedMembers:\ngot: %+v\nwant: %+v", snapshot.AllJoinedMembers, wantJoinedMembers) + } + wantMetadata := map[string]internal.RoomMetadata{ + roomAlice: { + RoomID: roomAlice, + JoinCount: 1, + LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomAlice][len(roomIDToEventMap[roomAlice])-1]).Get("origin_server_ts").Uint(), + Heroes: []internal.Hero{{ID: alice}}, + Encrypted: true, + PredecessorRoomID: &oldRoomID, + }, + roomBob: { + RoomID: roomBob, + JoinCount: 1, + LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomBob][len(roomIDToEventMap[roomBob])-1]).Get("origin_server_ts").Uint(), + Heroes: []internal.Hero{{ID: bob}}, + NameEvent: "My Room", + RoomType: &roomType, + }, + roomAliceBob: { + RoomID: roomAliceBob, + JoinCount: 2, + LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomAliceBob][len(roomIDToEventMap[roomAliceBob])-1]).Get("origin_server_ts").Uint(), + Heroes: []internal.Hero{{ID: bob}, {ID: alice}}, + CanonicalAlias: "#alias", + UpgradedRoomID: &newRoomID, + }, + roomSpace: { + RoomID: roomSpace, + JoinCount: 1, + InviteCount: 1, + LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomSpace][len(roomIDToEventMap[roomSpace])-1]).Get("origin_server_ts").Uint(), + Heroes: []internal.Hero{{ID: bob}, {ID: alice}}, + RoomType: &spaceRoomType, + ChildSpaceRooms: map[string]struct{}{ + newRoomID: {}, + }, + }, + } + for roomID, want := range wantMetadata { + assertRoomMetadata(t, snapshot.GlobalMetadata[roomID], want) + } +} + +func assertRoomMetadata(t *testing.T, got, want internal.RoomMetadata) { + t.Helper() + assertValue(t, "CanonicalAlias", got.CanonicalAlias, want.CanonicalAlias) + assertValue(t, "ChildSpaceRooms", got.ChildSpaceRooms, want.ChildSpaceRooms) + assertValue(t, "Encrypted", got.Encrypted, want.Encrypted) + assertValue(t, "Heroes", sortHeroes(got.Heroes), sortHeroes(want.Heroes)) + assertValue(t, "InviteCount", got.InviteCount, want.InviteCount) + assertValue(t, "JoinCount", got.JoinCount, want.JoinCount) + assertValue(t, "LastMessageTimestamp", got.LastMessageTimestamp, want.LastMessageTimestamp) + assertValue(t, "NameEvent", got.NameEvent, want.NameEvent) + assertValue(t, "PredecessorRoomID", got.PredecessorRoomID, want.PredecessorRoomID) + assertValue(t, "RoomID", got.RoomID, want.RoomID) + assertValue(t, "RoomType", got.RoomType, want.RoomType) + assertValue(t, "TypingEvent", got.TypingEvent, want.TypingEvent) + assertValue(t, "UpgradedRoomID", got.UpgradedRoomID, want.UpgradedRoomID) +} + +func assertValue(t *testing.T, msg string, got, want interface{}) { + if !reflect.DeepEqual(got, want) { + t.Errorf("%s: got %v want %v", msg, got, want) + } +} + +func sortHeroes(heroes []internal.Hero) []internal.Hero { + sort.Slice(heroes, func(i, j int) bool { + return heroes[i].ID < heroes[j].ID + }) + return heroes +} + func verifyRange(t *testing.T, result map[string][][2]int64, roomID string, wantRanges [][2]int64) { t.Helper() gotRanges := result[roomID] diff --git a/state/to_device_table_test.go b/state/to_device_table_test.go index 174891f..7c9096c 100644 --- a/state/to_device_table_test.go +++ b/state/to_device_table_test.go @@ -5,15 +5,12 @@ import ( "encoding/json" "testing" - "github.com/jmoiron/sqlx" "github.com/tidwall/gjson" ) func TestToDeviceTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewToDeviceTable(db) deviceID := "FOO" var limit int64 = 999 @@ -22,6 +19,7 @@ func TestToDeviceTable(t *testing.T) { json.RawMessage(`{"sender":"bob","type":"something","content":{"foo":"bar2"}}`), } var lastPos int64 + var err error if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil { t.Fatalf("InsertMessages: %s", err) } @@ -119,10 +117,8 @@ func TestToDeviceTable(t *testing.T) { // Test that https://github.com/uhoreg/matrix-doc/blob/drop-stale-to-device/proposals/3944-drop-stale-to-device.md works for m.room_key_request func TestToDeviceTableDeleteCancels(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() sender := "SENDER" destination := "DEST" table := NewToDeviceTable(db) @@ -130,7 +126,7 @@ func TestToDeviceTableDeleteCancels(t *testing.T) { reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ "foo": "bar", }) - _, err = table.InsertMessages(destination, []json.RawMessage{reqEv1}) + _, err := table.InsertMessages(destination, []json.RawMessage{reqEv1}) assertNoError(t, err) gotMsgs, _, err := table.Messages(destination, 0, 10) assertNoError(t, err) @@ -189,10 +185,8 @@ func TestToDeviceTableDeleteCancels(t *testing.T) { // Test that unacked events are safe from deletion func TestToDeviceTableNoDeleteUnacks(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() sender := "SENDER2" destination := "DEST2" table := NewToDeviceTable(db) @@ -238,10 +232,8 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) { // Guard against possible message truncation? func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewToDeviceTable(db) testCases := []json.RawMessage{ json.RawMessage(`{}`), @@ -264,7 +256,7 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { pos = nextPos } // and all at once - _, err = table.InsertMessages("B", testCases) + _, err := table.InsertMessages("B", testCases) if err != nil { t.Fatalf("InsertMessages: %s", err) } diff --git a/state/txn_table_test.go b/state/txn_table_test.go index a6460fe..42e22e0 100644 --- a/state/txn_table_test.go +++ b/state/txn_table_test.go @@ -3,8 +3,6 @@ package state import ( "testing" "time" - - "github.com/jmoiron/sqlx" ) func assertTxns(t *testing.T, gotEventToTxn map[string]string, wantEventToTxn map[string]string) { @@ -25,10 +23,8 @@ func assertTxns(t *testing.T, gotEventToTxn map[string]string, wantEventToTxn ma } func TestTransactionTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() userID := "@alice:txns" eventA := "$A" eventB := "$B" diff --git a/state/typing_table_test.go b/state/typing_table_test.go index 0521bdf..0db295c 100644 --- a/state/typing_table_test.go +++ b/state/typing_table_test.go @@ -3,15 +3,11 @@ package state import ( "reflect" "testing" - - "github.com/jmoiron/sqlx" ) func TestTypingTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() userIDs := []string{ "@alice:localhost", "@bob:localhost", diff --git a/state/unread_table_test.go b/state/unread_table_test.go index ff3a068..ef256b8 100644 --- a/state/unread_table_test.go +++ b/state/unread_table_test.go @@ -2,15 +2,11 @@ package state import ( "testing" - - "github.com/jmoiron/sqlx" ) func TestUnreadTable(t *testing.T) { - db, err := sqlx.Open("postgres", postgresConnectionString) - if err != nil { - t.Fatalf("failed to open SQL db: %s", err) - } + db, close := connectToDB(t) + defer close() table := NewUnreadTable(db) userID := "@alice:localhost" roomA := "!TestUnreadTableA:localhost"