Merge pull request #432 from matrix-org/kegan/device-list-updates

Ensure device list updates are robust to race conditions and network failures
This commit is contained in:
Kegan Dougal 2024-05-10 10:35:09 +01:00 committed by GitHub
commit 0d22cf1da5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 198 additions and 79 deletions

View File

@ -46,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) { func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var row DeviceDataRow var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID) err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error. // if there is no device data for this user, it's not an error.
@ -70,6 +70,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
if !swap { if !swap {
return nil // don't swap return nil // don't swap
} }
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New
// swap over the fields // swap over the fields
writeBack := *result writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New writeBack.DeviceLists.Sent = result.DeviceLists.New
@ -104,7 +107,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists // select what already exists
var row DeviceDataRow var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID) err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }

View File

@ -22,17 +22,20 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes) assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts) assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits) assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists) if w.DeviceLists.Sent != nil {
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
}
} }
func TestDeviceDataTableSwaps(t *testing.T) { // Tests OTKCounts and FallbackKeyTypes behaviour
func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
db, close := connectToDB(t) db, close := connectToDB(t)
defer close() defer close()
table := NewDeviceDataTable(db) table := NewDeviceDataTable(db)
userID := "@bob" userID := "@TestDeviceDataTableOTKCountAndFallbackKeyTypes"
deviceID := "BOB" deviceID := "BOB"
// test accumulating deltas // these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{ deltas := []internal.DeviceData{
{ {
UserID: userID, UserID: userID,
@ -46,9 +49,6 @@ func TestDeviceDataTableSwaps(t *testing.T) {
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
FallbackKeyTypes: []string{"foobar"}, FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
}, },
{ {
UserID: userID, UserID: userID,
@ -60,16 +60,38 @@ func TestDeviceDataTableSwaps(t *testing.T) {
{ {
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
}, },
} }
// apply them
for _, dd := range deltas { for _, dd := range deltas {
err := table.Upsert(&dd) err := table.Upsert(&dd)
assertNoError(t, err) assertNoError(t, err)
} }
// read them without swap, it should have replaced them correctly.
// Because sync v2 sends the complete OTK count and complete fallback key types
// every time, we always use the latest values. Because we aren't swapping, repeated
// reads produce the same result.
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
mustNotError(t, err)
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
assertDeviceData(t, *got, want)
}
// now we swap the data. This still returns the same values, but the changed bits are no longer set
// on subsequent reads.
got, err := table.Select(userID, deviceID, true)
mustNotError(t, err)
want := internal.DeviceData{ want := internal.DeviceData{
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
@ -77,68 +99,118 @@ func TestDeviceDataTableSwaps(t *testing.T) {
"foo": 99, "foo": 99,
}, },
FallbackKeyTypes: []string{"foobar"}, FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
Sent: map[string]int{},
},
} }
want.SetFallbackKeysChanged() want.SetFallbackKeysChanged()
want.SetOTKCountChanged() want.SetOTKCountChanged()
// check we can read-only select assertDeviceData(t, *got, want)
// subsequent read
got, err = table.Select(userID, deviceID, false)
mustNotError(t, err)
want = internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
assertDeviceData(t, *got, want)
}
// Tests the DeviceLists field
func TestDeviceDataTableDeviceList(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewDeviceDataTable(db)
userID := "@TestDeviceDataTableDeviceList"
deviceID := "BOB"
// these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
},
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}
// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
assertNoError(t, err)
}
// check we can read-only select. This doesn't modify any fields.
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false) got, err := table.Select(userID, deviceID, false)
assertNoError(t, err) assertNoError(t, err)
assertDeviceData(t, *got, want) assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries
},
})
} }
// now swap-er-roo, at this point we still expect the "old" data, // now swap-er-roo, which shifts everything from New into Sent.
// as it is the first time we swap
got, err := table.Select(userID, deviceID, true) got, err := table.Select(userID, deviceID, true)
assertNoError(t, err) assertNoError(t, err)
assertDeviceData(t, *got, want) assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
// changed bits were reset when we swapped DeviceID: deviceID,
want2 := want DeviceLists: internal.DeviceLists{
want2.DeviceLists = internal.DeviceLists{ Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil), },
New: map[string]int{}, })
}
want2.ChangedBits = 0
want.ChangedBits = 0
// this is permanent, read-only views show this too. // this is permanent, read-only views show this too.
// Since we have swapped previously, we now expect New to be empty
// and Sent to be set. Swap again to clear Sent.
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, want2)
// We now expect empty DeviceLists, as we swapped twice.
got, err = table.Select(userID, deviceID, false) got, err = table.Select(userID, deviceID, false)
assertNoError(t, err) assertNoError(t, err)
want3 := want2 assertDeviceData(t, *got, internal.DeviceData{
want3.DeviceLists = internal.DeviceLists{ UserID: userID,
Sent: map[string]int{}, DeviceID: deviceID,
New: map[string]int{}, DeviceLists: internal.DeviceLists{
} Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
assertDeviceData(t, *got, want3) },
})
// We now expect empty DeviceLists, as we swapped twice.
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})
// get back the original state // get back the original state
//err = table.DeleteDevice(userID, deviceID)
assertNoError(t, err) assertNoError(t, err)
for _, dd := range deltas { for _, dd := range deltas {
err = table.Upsert(&dd) err = table.Upsert(&dd)
assertNoError(t, err) assertNoError(t, err)
} }
want.SetFallbackKeysChanged() // Move original state to Sent by swapping
want.SetOTKCountChanged() got, err = table.Select(userID, deviceID, true)
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want)
// swap once then add once so both sent and new are populated
// Moves Alice and Bob to Sent
_, err = table.Select(userID, deviceID, true)
assertNoError(t, err) assertNoError(t, err)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})
// Add new entries to New before acknowledging Sent
err = table.Upsert(&internal.DeviceData{ err = table.Upsert(&internal.DeviceData{
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
@ -148,20 +220,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
}) })
assertNoError(t, err) assertNoError(t, err)
want.ChangedBits = 0 // Reading without swapping does not move New->Sent, so returns the previous value
want4 := want
want4.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
}
// Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New
got, err = table.Select(userID, deviceID, false) got, err = table.Select(userID, deviceID, false)
assertNoError(t, err) assertNoError(t, err)
assertDeviceData(t, *got, want4) assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})
// another append then consume // Append even more items to New
// This results in dave to be added to New
err = table.Upsert(&internal.DeviceData{ err = table.Upsert(&internal.DeviceData{
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
@ -170,24 +240,28 @@ func TestDeviceDataTableSwaps(t *testing.T) {
}, },
}) })
assertNoError(t, err) assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 := want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
}
assertDeviceData(t, *got, want5)
// Swapping again clears New // Now swap: all the combined items in New go into Sent
got, err = table.Select(userID, deviceID, true) got, err = table.Select(userID, deviceID, true)
assertNoError(t, err) assertNoError(t, err)
want5 = want4 assertDeviceData(t, *got, internal.DeviceData{
want5.DeviceLists = internal.DeviceLists{ UserID: userID,
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}), DeviceID: deviceID,
New: map[string]int{}, DeviceLists: internal.DeviceLists{
} Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}),
assertDeviceData(t, *got, want5) },
})
// Swapping again clears Sent out, and since nothing is in New we get an empty list
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})
// delete everything, no data returned // delete everything, no data returned
assertNoError(t, table.DeleteDevice(userID, deviceID)) assertNoError(t, table.DeleteDevice(userID, deviceID))

View File

@ -193,6 +193,48 @@ func TestExtensionE2EE(t *testing.T) {
if time.Since(start) >= (500 * time.Millisecond) { if time.Since(start) >= (500 * time.Millisecond) {
t.Fatalf("sync request did not return immediately with OTK counts") t.Fatalf("sync request did not return immediately with OTK counts")
} }
// check that if we lose a device list update and restart from nothing, we see the same update
v2.queueResponse(alice, sync2.SyncResponse{
DeviceLists: struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
}{
Changed: wantChanged,
Left: wantLeft,
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
// we actually lost this update: start again and we should see it.
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
} }
// Checks that to-device messages are passed from v2 to v3 // Checks that to-device messages are passed from v2 to v3