Add comprehensive regression test for GlobalSnapshot(); ensure we clear db conns when tests end

This commit is contained in:
Kegan Dougal 2023-01-18 14:54:26 +00:00
parent 0ec1088b39
commit a7eed93722
15 changed files with 306 additions and 230 deletions

View File

@ -5,7 +5,6 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/sync2"
) )
@ -43,10 +42,8 @@ func assertAccountDatasEqual(t *testing.T, msg string, gots, wants []AccountData
} }
func TestAccountData(t *testing.T) { func TestAccountData(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -175,10 +172,8 @@ func TestAccountData(t *testing.T) {
} }
func TestAccountDataIDIncrements(t *testing.T) { func TestAccountDataIDIncrements(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)

View File

@ -7,7 +7,6 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/testutils" "github.com/matrix-org/sliding-sync/testutils"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -21,10 +20,8 @@ func TestAccumulatorInitialise(t *testing.T) {
[]byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), []byte(`{"event_id":"C", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
} }
roomEventIDs := []string{"A", "B", "C"} roomEventIDs := []string{"A", "B", "C"}
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
added, initSnapID, err := accumulator.Initialise(roomID, roomEvents) added, initSnapID, err := accumulator.Initialise(roomID, roomEvents)
if err != nil { if err != nil {
@ -99,12 +96,10 @@ func TestAccumulatorAccumulate(t *testing.T) {
[]byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), []byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
} }
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
_, _, err = accumulator.Initialise(roomID, roomEvents) _, _, err := accumulator.Initialise(roomID, roomEvents)
if err != nil { if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err) t.Fatalf("failed to Initialise accumulator: %s", err)
} }
@ -197,12 +192,10 @@ func TestAccumulatorAccumulate(t *testing.T) {
func TestAccumulatorDelta(t *testing.T) { func TestAccumulatorDelta(t *testing.T) {
roomID := "!TestAccumulatorDelta:localhost" roomID := "!TestAccumulatorDelta:localhost"
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
_, _, err = accumulator.Initialise(roomID, nil) _, _, err := accumulator.Initialise(roomID, nil)
if err != nil { if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err) t.Fatalf("failed to Initialise accumulator: %s", err)
} }
@ -248,12 +241,10 @@ func TestAccumulatorDelta(t *testing.T) {
func TestAccumulatorMembershipLogs(t *testing.T) { func TestAccumulatorMembershipLogs(t *testing.T) {
roomID := "!TestAccumulatorMembershipLogs:localhost" roomID := "!TestAccumulatorMembershipLogs:localhost"
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
_, _, err = accumulator.Initialise(roomID, nil) _, _, err := accumulator.Initialise(roomID, nil)
if err != nil { if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err) t.Fatalf("failed to Initialise accumulator: %s", err)
} }
@ -389,13 +380,11 @@ func TestAccumulatorDupeEvents(t *testing.T) {
t.Fatalf("failed to unmarshal: %s", err) t.Fatalf("failed to unmarshal: %s", err)
} }
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
roomID := "!buggy:localhost" roomID := "!buggy:localhost"
_, _, err = accumulator.Initialise(roomID, joinRoom.State.Events) _, _, err := accumulator.Initialise(roomID, joinRoom.State.Events)
if err != nil { if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err) t.Fatalf("failed to Initialise accumulator: %s", err)
} }
@ -432,14 +421,12 @@ func TestAccumulatorMisorderedGraceful(t *testing.T) {
) )
t.Logf("A=member-alice, B=msg, C=create, D=member-bob") t.Logf("A=member-alice, B=msg, C=create, D=member-bob")
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
accumulator := NewAccumulator(db) accumulator := NewAccumulator(db)
roomID := "!TestAccumulatorStateReset:localhost" roomID := "!TestAccumulatorStateReset:localhost"
// Create a room with initial state A,C // Create a room with initial state A,C
_, _, err = accumulator.Initialise(roomID, []json.RawMessage{ _, _, err := accumulator.Initialise(roomID, []json.RawMessage{
eventA, eventC, eventA, eventC,
}) })
if err != nil { if err != nil {
@ -509,10 +496,8 @@ func TestCalculateNewSnapshotDupe(t *testing.T) {
t.Errorf("assertNIDsEqual: got %v want %v", a, b) t.Errorf("assertNIDsEqual: got %v want %v", a, b)
} }
} }
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
testCases := []struct { testCases := []struct {
input StrippedEvents input StrippedEvents
inputEvent Event inputEvent Event

View File

@ -4,7 +4,6 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/internal"
) )
@ -25,10 +24,8 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
} }
func TestDeviceDataTableSwaps(t *testing.T) { func TestDeviceDataTableSwaps(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewDeviceDataTable(db) table := NewDeviceDataTable(db)
userID := "@bob" userID := "@bob"
deviceID := "BOB" deviceID := "BOB"
@ -67,7 +64,7 @@ func TestDeviceDataTableSwaps(t *testing.T) {
}, },
} }
for _, dd := range deltas { for _, dd := range deltas {
_, err = table.Upsert(&dd) _, err := table.Upsert(&dd)
assertNoError(t, err) assertNoError(t, err)
} }
@ -173,10 +170,8 @@ func TestDeviceDataTableSwaps(t *testing.T) {
} }
func TestDeviceDataTableBitset(t *testing.T) { func TestDeviceDataTableBitset(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewDeviceDataTable(db) table := NewDeviceDataTable(db)
userID := "@bobTestDeviceDataTableBitset" userID := "@bobTestDeviceDataTableBitset"
deviceID := "BOBTestDeviceDataTableBitset" deviceID := "BOBTestDeviceDataTableBitset"
@ -202,7 +197,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
}, },
} }
_, err = table.Upsert(&otkUpdate) _, err := table.Upsert(&otkUpdate)
assertNoError(t, err) assertNoError(t, err)
got, err := table.Select(userID, deviceID, true) got, err := table.Select(userID, deviceID, true)
assertNoError(t, err) assertNoError(t, err)

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/sqlutil"
@ -14,10 +13,8 @@ import (
) )
func TestEventTable(t *testing.T) { func TestEventTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -199,10 +196,8 @@ func TestEventTable(t *testing.T) {
} }
func TestEventTableNullValue(t *testing.T) { func TestEventTableNullValue(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -241,10 +236,8 @@ func TestEventTableNullValue(t *testing.T) {
} }
func TestEventTableDupeInsert(t *testing.T) { func TestEventTableDupeInsert(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
// first insert // first insert
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
@ -303,10 +296,8 @@ func TestEventTableDupeInsert(t *testing.T) {
} }
func TestEventTableSelectEventsBetween(t *testing.T) { func TestEventTableSelectEventsBetween(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -347,85 +338,85 @@ func TestEventTableSelectEventsBetween(t *testing.T) {
} }
txn.Commit() txn.Commit()
t.Run("selecting multiple events known lower bound", func(t *testing.T) { t.Run("subgroup", func(t *testing.T) {
t.Parallel() t.Run("selecting multiple events known lower bound", func(t *testing.T) {
txn2, err := db.Beginx() t.Parallel()
if err != nil { txn2, err := db.Beginx()
t.Fatalf("failed to start txn: %s", err) if err != nil {
} t.Fatalf("failed to start txn: %s", err)
defer txn2.Rollback() }
events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]}) defer txn2.Rollback()
if err != nil || len(events) == 0 { events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]})
t.Fatalf("failed to extract event for lower bound: %s", err) if err != nil || len(events) == 0 {
} t.Fatalf("failed to extract event for lower bound: %s", err)
events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000) }
if err != nil { events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000)
t.Fatalf("failed to SelectEventsBetween: %s", err) if err != nil {
} t.Fatalf("failed to SelectEventsBetween: %s", err)
// 3 as 1 is from a different room }
if len(events) != 3 { // 3 as 1 is from a different room
t.Fatalf("wanted 3 events, got %d", len(events)) if len(events) != 3 {
} t.Fatalf("wanted 3 events, got %d", len(events))
}) }
t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) { })
t.Parallel() t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) {
txn3, err := db.Beginx() t.Parallel()
if err != nil { txn3, err := db.Beginx()
t.Fatalf("failed to start txn: %s", err) if err != nil {
} t.Fatalf("failed to start txn: %s", err)
defer txn3.Rollback() }
events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]}) defer txn3.Rollback()
if err != nil || len(events) == 0 { events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]})
t.Fatalf("failed to extract event for lower/upper bound: %s", err) if err != nil || len(events) == 0 {
} t.Fatalf("failed to extract event for lower/upper bound: %s", err)
events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000) }
if err != nil { events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000)
t.Fatalf("failed to SelectEventsBetween: %s", err) if err != nil {
} t.Fatalf("failed to SelectEventsBetween: %s", err)
// eventIDs[1] and eventIDs[2] }
if len(events) != 2 { // eventIDs[1] and eventIDs[2]
t.Fatalf("wanted 2 events, got %d", len(events)) if len(events) != 2 {
} t.Fatalf("wanted 2 events, got %d", len(events))
}) }
t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) { })
t.Parallel() t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) {
txn4, err := db.Beginx() t.Parallel()
if err != nil { txn4, err := db.Beginx()
t.Fatalf("failed to start txn: %s", err) if err != nil {
} t.Fatalf("failed to start txn: %s", err)
defer txn4.Rollback() }
gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000) defer txn4.Rollback()
if err != nil { gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000)
t.Fatalf("failed to SelectEventsBetween: %s", err) if err != nil {
} t.Fatalf("failed to SelectEventsBetween: %s", err)
// one less as one event is for a different room }
if len(gotEvents) != (len(events) - 1) { // one less as one event is for a different room
t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents)) if len(gotEvents) != (len(events) - 1) {
} t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents))
}) }
t.Run("selecting multiple events hitting the limit", func(t *testing.T) { })
t.Parallel() t.Run("selecting multiple events hitting the limit", func(t *testing.T) {
txn5, err := db.Beginx() t.Parallel()
if err != nil { txn5, err := db.Beginx()
t.Fatalf("failed to start txn: %s", err) if err != nil {
} t.Fatalf("failed to start txn: %s", err)
defer txn5.Rollback() }
limit := 2 defer txn5.Rollback()
gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit) limit := 2
if err != nil { gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit)
t.Fatalf("failed to SelectEventsBetween: %s", err) if err != nil {
} t.Fatalf("failed to SelectEventsBetween: %s", err)
if len(gotEvents) != limit { }
t.Fatalf("wanted %d events, got %d", limit, len(gotEvents)) if len(gotEvents) != limit {
} t.Fatalf("wanted %d events, got %d", limit, len(gotEvents))
}
})
}) })
} }
func TestEventTableMembershipDetection(t *testing.T) { func TestEventTableMembershipDetection(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -546,10 +537,8 @@ func TestChunkify(t *testing.T) {
} }
func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) { func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -645,10 +634,8 @@ func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) {
// Do a massive insert/select for event IDs (greater than postgres limit) and ensure it works. // Do a massive insert/select for event IDs (greater than postgres limit) and ensure it works.
func TestTortureEventTable(t *testing.T) { func TestTortureEventTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -699,10 +686,8 @@ func TestTortureEventTable(t *testing.T) {
// 3: SelectClosestPrevBatch with an event without a prev_batch returns the next newest (stream order) event with a prev_batch // 3: SelectClosestPrevBatch with an event without a prev_batch returns the next newest (stream order) event with a prev_batch
// 4: SelectClosestPrevBatch with an event without a prev_batch returns nothing if there are no newer events with a prev_batch // 4: SelectClosestPrevBatch with an event without a prev_batch returns nothing if there are no newer events with a prev_batch
func TestEventTablePrevBatch(t *testing.T) { func TestEventTablePrevBatch(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -827,10 +812,8 @@ func TestEventTablePrevBatch(t *testing.T) {
} }
func TestRemoveUnsignedTXNID(t *testing.T) { func TestRemoveUnsignedTXNID(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)

View File

@ -4,15 +4,11 @@ import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"testing" "testing"
"github.com/jmoiron/sqlx"
) )
func TestInviteTable(t *testing.T) { func TestInviteTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewInvitesTable(db) table := NewInvitesTable(db)
alice := "@alice:localhost" alice := "@alice:localhost"
bob := "@bob:localhost" bob := "@bob:localhost"

View File

@ -4,6 +4,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/testutils" "github.com/matrix-org/sliding-sync/testutils"
) )
@ -14,3 +15,13 @@ func TestMain(m *testing.M) {
exitCode := m.Run() exitCode := m.Run()
os.Exit(exitCode) os.Exit(exitCode)
} }
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
return db, func() {
db.Close()
}
}

View File

@ -6,7 +6,6 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/internal"
) )
@ -33,10 +32,8 @@ func parsedReceiptsEqual(t *testing.T, got, want []internal.Receipt) {
} }
func TestReceiptTable(t *testing.T) { func TestReceiptTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
roomA := "!A:ReceiptTable" roomA := "!A:ReceiptTable"
roomB := "!B:ReceiptTable" roomB := "!B:ReceiptTable"
edu := json.RawMessage(`{ edu := json.RawMessage(`{

View File

@ -7,11 +7,9 @@ import (
) )
func TestRoomsTable(t *testing.T) { func TestRoomsTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err) _, err := db.Exec(`DROP TABLE IF EXISTS syncv3_rooms`)
}
_, err = db.Exec(`DROP TABLE IF EXISTS syncv3_rooms`)
if err != nil { if err != nil {
t.Fatalf("failed to drop rooms table: %s", err) t.Fatalf("failed to drop rooms table: %s", err)
} }

View File

@ -4,15 +4,12 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
) )
func TestSnapshotTable(t *testing.T) { func TestSnapshotTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)

View File

@ -6,8 +6,6 @@ import (
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
"github.com/jmoiron/sqlx"
) )
func matchAnyOrder(t *testing.T, gots, wants []SpaceRelation) { func matchAnyOrder(t *testing.T, gots, wants []SpaceRelation) {
@ -35,10 +33,8 @@ func noError(t *testing.T, err error) {
} }
func TestSpacesTable(t *testing.T) { func TestSpacesTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)
@ -220,10 +216,8 @@ func TestNewSpaceRelationFromEvent(t *testing.T) {
} }
func TestHandleSpaceUpdates(t *testing.T) { func TestHandleSpaceUpdates(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
txn, err := db.Beginx() txn, err := db.Beginx()
if err != nil { if err != nil {
t.Fatalf("failed to start txn: %s", err) t.Fatalf("failed to start txn: %s", err)

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
@ -19,6 +20,7 @@ import (
func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) { func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
ctx := context.Background() ctx := context.Background()
store := NewStorage(postgresConnectionString) store := NewStorage(postgresConnectionString)
defer store.Teardown()
roomID := "!TestStorageRoomStateAfterEventPosition:localhost" roomID := "!TestStorageRoomStateAfterEventPosition:localhost"
alice := "@alice:localhost" alice := "@alice:localhost"
bob := "@bob:localhost" bob := "@bob:localhost"
@ -112,6 +114,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
func TestStorageJoinedRoomsAfterPosition(t *testing.T) { func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
store := NewStorage(postgresConnectionString) store := NewStorage(postgresConnectionString)
defer store.Teardown()
joinedRoomID := "!joined:bar" joinedRoomID := "!joined:bar"
invitedRoomID := "!invited:bar" invitedRoomID := "!invited:bar"
leftRoomID := "!left:bar" leftRoomID := "!left:bar"
@ -245,6 +248,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
// Test the examples on VisibleEventNIDsBetween docs // Test the examples on VisibleEventNIDsBetween docs
func TestVisibleEventNIDsBetween(t *testing.T) { func TestVisibleEventNIDsBetween(t *testing.T) {
store := NewStorage(postgresConnectionString) store := NewStorage(postgresConnectionString)
defer store.Teardown()
roomA := "!a:localhost" roomA := "!a:localhost"
roomB := "!b:localhost" roomB := "!b:localhost"
roomC := "!c:localhost" roomC := "!c:localhost"
@ -474,6 +478,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) { func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
store := NewStorage(postgresConnectionString) store := NewStorage(postgresConnectionString)
defer store.Teardown()
roomID := "!joined:bar" roomID := "!joined:bar"
alice := "@alice_TestStorageLatestEventsInRoomsPrevBatch:localhost" alice := "@alice_TestStorageLatestEventsInRoomsPrevBatch:localhost"
stateEvents := []json.RawMessage{ stateEvents := []json.RawMessage{
@ -557,6 +562,146 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
} }
} }
func TestGlobalSnapshot(t *testing.T) {
alice := "@TestGlobalSnapshot_alice:localhost"
bob := "@TestGlobalSnapshot_bob:localhost"
roomAlice := "!alice"
roomBob := "!bob"
roomAliceBob := "!alicebob"
roomSpace := "!space"
oldRoomID := "!old"
newRoomID := "!new"
roomType := "room_type_here"
spaceRoomType := "m.space"
roomIDToEventMap := map[string][]json.RawMessage{
roomAlice: {
testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{"creator": alice, "predecessor": map[string]string{
"room_id": oldRoomID,
"event_id": "$something",
}}),
testutils.NewJoinEvent(t, alice),
testutils.NewStateEvent(t, "m.room.encryption", "", alice, map[string]interface{}{"algorithm": "m.megolm.v1.aes-sha2"}),
},
roomBob: {
testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob, "type": roomType}),
testutils.NewJoinEvent(t, bob),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "My Room"}),
},
roomAliceBob: {
testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob}),
testutils.NewJoinEvent(t, bob),
testutils.NewJoinEvent(t, alice),
testutils.NewStateEvent(t, "m.room.canonical_alias", "", alice, map[string]interface{}{"alias": "#alias"}),
testutils.NewStateEvent(t, "m.room.tombstone", "", alice, map[string]interface{}{"replacement_room": newRoomID, "body": "yep"}),
},
roomSpace: {
testutils.NewStateEvent(t, "m.room.create", "", bob, map[string]interface{}{"creator": bob, "type": spaceRoomType}),
testutils.NewJoinEvent(t, bob),
testutils.NewStateEvent(t, "m.space.child", newRoomID, bob, map[string]interface{}{"via": []string{"somewhere"}}),
testutils.NewStateEvent(t, "m.space.child", "!no_via", bob, map[string]interface{}{}),
testutils.NewStateEvent(t, "m.room.member", alice, bob, map[string]interface{}{"membership": "invite"}),
},
}
// make a fresh DB which is unpolluted from other tests
db, close := connectToDB(t)
_, err := db.Exec(`
DROP TABLE IF EXISTS syncv3_rooms;
DROP TABLE IF EXISTS syncv3_invites;
DROP TABLE IF EXISTS syncv3_snapshots;
DROP TABLE IF EXISTS syncv3_spaces;`)
if err != nil {
t.Fatalf("failed to wipe DB: %s", err)
}
close()
store := NewStorage(postgresConnectionString)
defer store.Teardown()
for roomID, stateEvents := range roomIDToEventMap {
_, _, err := store.Initialise(roomID, stateEvents)
assertNoError(t, err)
}
snapshot, err := store.GlobalSnapshot()
assertNoError(t, err)
wantJoinedMembers := map[string][]string{
roomAlice: {alice},
roomBob: {bob},
roomAliceBob: {bob, alice}, // user IDs are ordered by event nid, and bob joined first so he is first
roomSpace: {bob},
}
if !reflect.DeepEqual(snapshot.AllJoinedMembers, wantJoinedMembers) {
t.Errorf("Snapshot.AllJoinedMembers:\ngot: %+v\nwant: %+v", snapshot.AllJoinedMembers, wantJoinedMembers)
}
wantMetadata := map[string]internal.RoomMetadata{
roomAlice: {
RoomID: roomAlice,
JoinCount: 1,
LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomAlice][len(roomIDToEventMap[roomAlice])-1]).Get("origin_server_ts").Uint(),
Heroes: []internal.Hero{{ID: alice}},
Encrypted: true,
PredecessorRoomID: &oldRoomID,
},
roomBob: {
RoomID: roomBob,
JoinCount: 1,
LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomBob][len(roomIDToEventMap[roomBob])-1]).Get("origin_server_ts").Uint(),
Heroes: []internal.Hero{{ID: bob}},
NameEvent: "My Room",
RoomType: &roomType,
},
roomAliceBob: {
RoomID: roomAliceBob,
JoinCount: 2,
LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomAliceBob][len(roomIDToEventMap[roomAliceBob])-1]).Get("origin_server_ts").Uint(),
Heroes: []internal.Hero{{ID: bob}, {ID: alice}},
CanonicalAlias: "#alias",
UpgradedRoomID: &newRoomID,
},
roomSpace: {
RoomID: roomSpace,
JoinCount: 1,
InviteCount: 1,
LastMessageTimestamp: gjson.ParseBytes(roomIDToEventMap[roomSpace][len(roomIDToEventMap[roomSpace])-1]).Get("origin_server_ts").Uint(),
Heroes: []internal.Hero{{ID: bob}, {ID: alice}},
RoomType: &spaceRoomType,
ChildSpaceRooms: map[string]struct{}{
newRoomID: {},
},
},
}
for roomID, want := range wantMetadata {
assertRoomMetadata(t, snapshot.GlobalMetadata[roomID], want)
}
}
func assertRoomMetadata(t *testing.T, got, want internal.RoomMetadata) {
t.Helper()
assertValue(t, "CanonicalAlias", got.CanonicalAlias, want.CanonicalAlias)
assertValue(t, "ChildSpaceRooms", got.ChildSpaceRooms, want.ChildSpaceRooms)
assertValue(t, "Encrypted", got.Encrypted, want.Encrypted)
assertValue(t, "Heroes", sortHeroes(got.Heroes), sortHeroes(want.Heroes))
assertValue(t, "InviteCount", got.InviteCount, want.InviteCount)
assertValue(t, "JoinCount", got.JoinCount, want.JoinCount)
assertValue(t, "LastMessageTimestamp", got.LastMessageTimestamp, want.LastMessageTimestamp)
assertValue(t, "NameEvent", got.NameEvent, want.NameEvent)
assertValue(t, "PredecessorRoomID", got.PredecessorRoomID, want.PredecessorRoomID)
assertValue(t, "RoomID", got.RoomID, want.RoomID)
assertValue(t, "RoomType", got.RoomType, want.RoomType)
assertValue(t, "TypingEvent", got.TypingEvent, want.TypingEvent)
assertValue(t, "UpgradedRoomID", got.UpgradedRoomID, want.UpgradedRoomID)
}
func assertValue(t *testing.T, msg string, got, want interface{}) {
if !reflect.DeepEqual(got, want) {
t.Errorf("%s: got %v want %v", msg, got, want)
}
}
func sortHeroes(heroes []internal.Hero) []internal.Hero {
sort.Slice(heroes, func(i, j int) bool {
return heroes[i].ID < heroes[j].ID
})
return heroes
}
func verifyRange(t *testing.T, result map[string][][2]int64, roomID string, wantRanges [][2]int64) { func verifyRange(t *testing.T, result map[string][][2]int64, roomID string, wantRanges [][2]int64) {
t.Helper() t.Helper()
gotRanges := result[roomID] gotRanges := result[roomID]

View File

@ -5,15 +5,12 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/jmoiron/sqlx"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
func TestToDeviceTable(t *testing.T) { func TestToDeviceTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
deviceID := "FOO" deviceID := "FOO"
var limit int64 = 999 var limit int64 = 999
@ -22,6 +19,7 @@ func TestToDeviceTable(t *testing.T) {
json.RawMessage(`{"sender":"bob","type":"something","content":{"foo":"bar2"}}`), json.RawMessage(`{"sender":"bob","type":"something","content":{"foo":"bar2"}}`),
} }
var lastPos int64 var lastPos int64
var err error
if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil { if lastPos, err = table.InsertMessages(deviceID, msgs); err != nil {
t.Fatalf("InsertMessages: %s", err) t.Fatalf("InsertMessages: %s", err)
} }
@ -119,10 +117,8 @@ func TestToDeviceTable(t *testing.T) {
// Test that https://github.com/uhoreg/matrix-doc/blob/drop-stale-to-device/proposals/3944-drop-stale-to-device.md works for m.room_key_request // Test that https://github.com/uhoreg/matrix-doc/blob/drop-stale-to-device/proposals/3944-drop-stale-to-device.md works for m.room_key_request
func TestToDeviceTableDeleteCancels(t *testing.T) { func TestToDeviceTableDeleteCancels(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
sender := "SENDER" sender := "SENDER"
destination := "DEST" destination := "DEST"
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
@ -130,7 +126,7 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{ reqEv1 := newRoomKeyEvent(t, "request", "1", sender, map[string]interface{}{
"foo": "bar", "foo": "bar",
}) })
_, err = table.InsertMessages(destination, []json.RawMessage{reqEv1}) _, err := table.InsertMessages(destination, []json.RawMessage{reqEv1})
assertNoError(t, err) assertNoError(t, err)
gotMsgs, _, err := table.Messages(destination, 0, 10) gotMsgs, _, err := table.Messages(destination, 0, 10)
assertNoError(t, err) assertNoError(t, err)
@ -189,10 +185,8 @@ func TestToDeviceTableDeleteCancels(t *testing.T) {
// Test that unacked events are safe from deletion // Test that unacked events are safe from deletion
func TestToDeviceTableNoDeleteUnacks(t *testing.T) { func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
sender := "SENDER2" sender := "SENDER2"
destination := "DEST2" destination := "DEST2"
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
@ -238,10 +232,8 @@ func TestToDeviceTableNoDeleteUnacks(t *testing.T) {
// Guard against possible message truncation? // Guard against possible message truncation?
func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) { func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewToDeviceTable(db) table := NewToDeviceTable(db)
testCases := []json.RawMessage{ testCases := []json.RawMessage{
json.RawMessage(`{}`), json.RawMessage(`{}`),
@ -264,7 +256,7 @@ func TestToDeviceTableBytesInEqualBytesOut(t *testing.T) {
pos = nextPos pos = nextPos
} }
// and all at once // and all at once
_, err = table.InsertMessages("B", testCases) _, err := table.InsertMessages("B", testCases)
if err != nil { if err != nil {
t.Fatalf("InsertMessages: %s", err) t.Fatalf("InsertMessages: %s", err)
} }

View File

@ -3,8 +3,6 @@ package state
import ( import (
"testing" "testing"
"time" "time"
"github.com/jmoiron/sqlx"
) )
func assertTxns(t *testing.T, gotEventToTxn map[string]string, wantEventToTxn map[string]string) { func assertTxns(t *testing.T, gotEventToTxn map[string]string, wantEventToTxn map[string]string) {
@ -25,10 +23,8 @@ func assertTxns(t *testing.T, gotEventToTxn map[string]string, wantEventToTxn ma
} }
func TestTransactionTable(t *testing.T) { func TestTransactionTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
userID := "@alice:txns" userID := "@alice:txns"
eventA := "$A" eventA := "$A"
eventB := "$B" eventB := "$B"

View File

@ -3,15 +3,11 @@ package state
import ( import (
"reflect" "reflect"
"testing" "testing"
"github.com/jmoiron/sqlx"
) )
func TestTypingTable(t *testing.T) { func TestTypingTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
userIDs := []string{ userIDs := []string{
"@alice:localhost", "@alice:localhost",
"@bob:localhost", "@bob:localhost",

View File

@ -2,15 +2,11 @@ package state
import ( import (
"testing" "testing"
"github.com/jmoiron/sqlx"
) )
func TestUnreadTable(t *testing.T) { func TestUnreadTable(t *testing.T) {
db, err := sqlx.Open("postgres", postgresConnectionString) db, close := connectToDB(t)
if err != nil { defer close()
t.Fatalf("failed to open SQL db: %s", err)
}
table := NewUnreadTable(db) table := NewUnreadTable(db)
userID := "@alice:localhost" userID := "@alice:localhost"
roomA := "!TestUnreadTableA:localhost" roomA := "!TestUnreadTableA:localhost"