Add migrations and refactor internal structs

Kegan Dougal 2024-05-17 13:45:14 +01:00
parent 2cd9a81ab2
commit b383ed0d82
12 changed files with 445 additions and 223 deletions

View File

@ -1,9 +1,5 @@
package internal
import (
"sync"
)
const (
bitOTKCount int = iota
bitFallbackKeyTypes
@ -18,9 +14,22 @@ func isBitSet(n int, bit int) bool {
return val > 0
}
// DeviceData contains useful data for this user's device. This list can be expanded without prompting
// schema changes. These values are upserted into the database and persisted forever.
// DeviceData contains useful data for this user's device.
type DeviceData struct {
DeviceListChanges
DeviceKeyData
UserID string
DeviceID string
}
// This is calculated from device_lists table
type DeviceListChanges struct {
DeviceListChanged []string
DeviceListLeft []string
}
// This gets serialised as CBOR in device_data table
type DeviceKeyData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts MapStringInt `json:"otk"`
@ -28,95 +37,22 @@ type DeviceData struct {
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`
DeviceLists DeviceLists `json:"dl"`
// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`
UserID string
DeviceID string
}
func (dd *DeviceData) SetOTKCountChanged() {
func (dd *DeviceKeyData) SetOTKCountChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) SetFallbackKeysChanged() {
func (dd *DeviceKeyData) SetFallbackKeysChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
}
func (dd *DeviceData) OTKCountChanged() bool {
func (dd *DeviceKeyData) OTKCountChanged() bool {
return isBitSet(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) FallbackKeysChanged() bool {
func (dd *DeviceKeyData) FallbackKeysChanged() bool {
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
}
type UserDeviceKey struct {
UserID string
DeviceID string
}
type DeviceDataMap struct {
deviceDataMu *sync.Mutex
deviceDataMap map[UserDeviceKey]*DeviceData
Pos int64
}
func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap {
ddm := &DeviceDataMap{
deviceDataMu: &sync.Mutex{},
deviceDataMap: make(map[UserDeviceKey]*DeviceData),
Pos: startPos,
}
for i, dd := range devices {
ddm.deviceDataMap[UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}] = &devices[i]
}
return ddm
}
func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData {
key := UserDeviceKey{
UserID: userID,
DeviceID: deviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
dd, ok := d.deviceDataMap[key]
if !ok {
return nil
}
return dd
}
func (d *DeviceDataMap) Update(dd DeviceData) DeviceData {
key := UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
existing, ok := d.deviceDataMap[key]
if !ok {
existing = &DeviceData{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
}
if dd.OTKCounts != nil {
existing.OTKCounts = dd.OTKCounts
}
if dd.FallbackKeyTypes != nil {
existing.FallbackKeyTypes = dd.FallbackKeyTypes
}
existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists)
d.deviceDataMap[key] = existing
return *existing
}
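
Taken together, this file splits the old monolithic DeviceData into two embedded pieces: key material tracked by a ChangedBits bitset (DeviceKeyData) and device-list changes computed from a table (DeviceListChanges). The following is a minimal, self-contained sketch of how the pieces compose; the setBit implementation is assumed (the diff only shows it being called), and the example IDs, key algorithm name and counts are made up.

package main

import "fmt"

const (
    bitOTKCount int = iota
    bitFallbackKeyTypes
)

// Assumed helper implementations; the hunk above only shows isBitSet's body in part.
func setBit(n, bit int) int    { return n | (1 << bit) }
func isBitSet(n, bit int) bool { return n&(1<<bit) > 0 }

// DeviceKeyData holds the key-related fields and is what gets serialised as CBOR.
type DeviceKeyData struct {
    OTKCounts        map[string]int
    FallbackKeyTypes []string
    ChangedBits      int
}

// DeviceListChanges is derived from the device_lists table rather than stored CBOR.
type DeviceListChanges struct {
    DeviceListChanged []string
    DeviceListLeft    []string
}

// DeviceData is now just the composition of the two, plus identifiers.
type DeviceData struct {
    DeviceListChanges
    DeviceKeyData
    UserID   string
    DeviceID string
}

func main() {
    dd := DeviceData{UserID: "@alice:localhost", DeviceID: "DEVICE"}
    // Key updates land on the embedded DeviceKeyData and flip a bit in ChangedBits.
    dd.OTKCounts = map[string]int{"signed_curve25519": 50}
    dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
    // Device list changes land on the embedded DeviceListChanges as plain slices.
    dd.DeviceListChanged = []string{"@bob:localhost"}

    fmt.Println(isBitSet(dd.ChangedBits, bitOTKCount))         // true
    fmt.Println(isBitSet(dd.ChangedBits, bitFallbackKeyTypes)) // false
}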

View File

@ -15,9 +15,9 @@ type DeviceDataRow struct {
ID int64 `db:"id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
// This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't
// need to perform searches on this data.
Data []byte `db:"data"`
KeyData []byte `db:"data"`
}
type DeviceDataTable struct {
@ -47,6 +47,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// grab otk counts and fallback key types
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
@ -56,32 +57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
}
return err
}
result = &internal.DeviceData{}
var keyData *internal.DeviceKeyData
// unmarshal to swap
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &result); err != nil {
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
result.UserID = userID
result.DeviceID = deviceID
if keyData != nil {
result.DeviceKeyData = *keyData
}
deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap)
if err != nil {
return err
}
for targetUserID, targetState := range deviceListChanges {
switch targetState {
case internal.DeviceListChanged:
result.DeviceListChanged = append(result.DeviceListChanged, targetUserID)
case internal.DeviceListLeft:
result.DeviceListLeft = append(result.DeviceListLeft, targetUserID)
}
}
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New
// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
writeBack.DeviceLists.New = make(map[string]int)
writeBack := *keyData
writeBack.ChangedBits = 0
if reflect.DeepEqual(result, &writeBack) {
if reflect.DeepEqual(keyData, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
@ -99,14 +106,13 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
return
}
func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// Update device lists
if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil {
return err
}
// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
@ -114,30 +120,22 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
return err
}
// unmarshal and combine
var tempDD internal.DeviceData
if len(row.Data) > 0 {
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &tempDD); err != nil {
var keyData internal.DeviceKeyData
if len(row.KeyData) > 0 {
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
}
if dd.FallbackKeyTypes != nil {
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
tempDD.SetFallbackKeysChanged()
keyData.FallbackKeyTypes = dd.FallbackKeyTypes
keyData.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
tempDD.OTKCounts = dd.OTKCounts
tempDD.SetOTKCountChanged()
keyData.OTKCounts = dd.OTKCounts
keyData.SetOTKCountChanged()
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)
data, err := cbor.Marshal(tempDD)
data, err := cbor.Marshal(keyData)
if err != nil {
return err
}
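
The rewritten Upsert above treats the stored CBOR blob as the source of truth and merges a partial update into it: a nil OTKCounts or FallbackKeyTypes means "no change", a non-nil value replaces what is stored and flips the matching changed bit, and the result is re-marshalled. Below is a rough standalone sketch of just that merge step, with a simplified struct and made-up values; it is not the table code itself, and it relies on fxamacker/cbor falling back to the json tags for key names.

package main

import (
    "fmt"

    "github.com/fxamacker/cbor/v2"
)

const (
    bitOTKCount int = iota
    bitFallbackKeyTypes
)

// Trimmed-down stand-in for internal.DeviceKeyData, keeping the "otk"/"fallback"/"c" keys.
type deviceKeyData struct {
    OTKCounts        map[string]int `json:"otk"`
    FallbackKeyTypes []string       `json:"fallback"`
    ChangedBits      int            `json:"c"`
}

// merge follows the Upsert hunk above: nil incoming fields are "no change",
// non-nil fields replace the stored value and set the matching changed bit.
func merge(existing []byte, incoming deviceKeyData) ([]byte, error) {
    var keyData deviceKeyData
    if len(existing) > 0 {
        if err := cbor.Unmarshal(existing, &keyData); err != nil {
            return nil, err
        }
    }
    if incoming.OTKCounts != nil {
        keyData.OTKCounts = incoming.OTKCounts
        keyData.ChangedBits |= 1 << bitOTKCount
    }
    if incoming.FallbackKeyTypes != nil {
        keyData.FallbackKeyTypes = incoming.FallbackKeyTypes
        keyData.ChangedBits |= 1 << bitFallbackKeyTypes
    }
    return cbor.Marshal(keyData)
}

func main() {
    first, err := merge(nil, deviceKeyData{OTKCounts: map[string]int{"signed_curve25519": 50}})
    if err != nil {
        panic(err)
    }
    // A later partial update with nil OTKCounts leaves the stored counts intact.
    second, err := merge(first, deviceKeyData{FallbackKeyTypes: []string{"signed_curve25519"}})
    if err != nil {
        panic(err)
    }
    var out deviceKeyData
    if err := cbor.Unmarshal(second, &out); err != nil {
        panic(err)
    }
    fmt.Printf("%+v\n", out) // both fields populated, ChangedBits == 3
}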

View File

@ -22,9 +22,6 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
if w.DeviceLists.Sent != nil {
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
}
}
// Tests OTKCounts and FallbackKeyTypes behaviour
@ -40,23 +37,29 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 100,
"bar": 92,
},
},
{
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"foobar"},
},
{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
FallbackKeyTypes: []string{"foobar"},
},
},
{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 99,
},
},
},
{
UserID: userID,
DeviceID: deviceID,
@ -65,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
err := table.Upsert(&dd, nil)
assertNoError(t, err)
}
@ -79,10 +82,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
@ -95,10 +100,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
@ -110,10 +117,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
want = internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
},
}
assertDeviceData(t, *got, want)
}
@ -127,29 +136,32 @@ func TestDeviceDataTableBitset(t *testing.T) {
otkUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: map[string]int{
"foo": 100,
"bar": 92,
},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
},
}
fallbakKeyUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
FallbackKeyTypes: []string{"foo", "bar"},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
},
}
bothUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
FallbackKeyTypes: []string{"both"},
OTKCounts: map[string]int{
"both": 100,
},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
},
}
err := table.Upsert(&otkUpdate)
err := table.Upsert(&otkUpdate, nil)
assertNoError(t, err)
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
@ -161,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
otkUpdate.ChangedBits = 0
assertDeviceData(t, *got, otkUpdate)
// now same for fallback keys, but we won't swap them so it should return those diffs
err = table.Upsert(&fallbakKeyUpdate)
err = table.Upsert(&fallbakKeyUpdate, nil)
assertNoError(t, err)
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
got, err = table.Select(userID, deviceID, false)
@ -173,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
fallbakKeyUpdate.SetFallbackKeysChanged()
assertDeviceData(t, *got, fallbakKeyUpdate)
// updating both works
err = table.Upsert(&bothUpdate)
err = table.Upsert(&bothUpdate, nil)
assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)

View File

@ -14,6 +14,14 @@ const (
BucketSent = 2
)
type DeviceListRow struct {
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
TargetUserID string `db:"target_user_id"`
TargetState int `db:"target_state"`
Bucket int `db:"bucket"`
}
type DeviceListTable struct {
db *sqlx.DB
}
@ -39,24 +47,9 @@ func NewDeviceListTable(db *sqlx.DB) *DeviceListTable {
}
}
// Upsert new device list changes.
func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
for targetUserID, targetState := range deviceListChanges {
if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft {
sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState))
continue
}
_, err = txn.Exec(
`INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES($1,$2,$3,$4,$5)
ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state=$4`,
userID, deviceID, targetUserID, targetState, BucketNew,
)
if err != nil {
return err
}
}
return nil
return t.UpsertTx(txn, userID, deviceID, deviceListChanges)
})
if err != nil {
sentry.CaptureException(err)
@ -64,31 +57,68 @@ func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[
return
}
// Select device list changes for this client. Returns a map of user_id => change enum.
// Upsert new device list changes.
func (t *DeviceListTable) UpsertTx(txn *sqlx.Tx, userID, deviceID string, deviceListChanges map[string]int) (err error) {
if len(deviceListChanges) == 0 {
return nil
}
var deviceListRows []DeviceListRow
for targetUserID, targetState := range deviceListChanges {
if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft {
sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState))
continue
}
deviceListRows = append(deviceListRows, DeviceListRow{
UserID: userID,
DeviceID: deviceID,
TargetUserID: targetUserID,
TargetState: targetState,
Bucket: BucketNew,
})
}
chunks := sqlutil.Chunkify(5, MaxPostgresParameters, DeviceListChunker(deviceListRows))
for _, chunk := range chunks {
_, err := txn.NamedExec(`
INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket)
VALUES(:user_id, :device_id, :target_user_id, :target_state, :bucket)
ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state = EXCLUDED.target_state`, chunk)
if err != nil {
return err
}
}
return nil
return
}
func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
result, err = t.SelectTx(txn, userID, deviceID, swap)
return err
})
return
}
// Select device list changes for this client. Returns a map of user_id => change enum.
func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap bool) (result internal.MapStringInt, err error) {
if !swap {
// read only view, just return what we previously sent and don't do anything else.
result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent)
return err
return t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent)
}
// delete the now acknowledged 'sent' data
_, err = txn.Exec(`DELETE FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, BucketSent)
if err != nil {
return err
return nil, err
}
// grab any 'new' updates
result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew)
if err != nil {
return err
return nil, err
}
// mark these 'new' updates as 'sent'
_, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew)
return err
})
return
return result, err
}
func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) {
@ -108,3 +138,12 @@ func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID,
}
return result, rows.Err()
}
type DeviceListChunker []DeviceListRow
func (c DeviceListChunker) Len() int {
return len(c)
}
func (c DeviceListChunker) Subslice(i, j int) sqlutil.Chunker {
return c[i:j]
}
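
UpsertTx batches rows through sqlutil.Chunkify because each row contributes 5 bind parameters and Postgres caps the number of parameters in a single statement (commonly 65535, which is assumed here). A hand-rolled sketch of the same chunking arithmetic, independent of the sqlutil helper:

package main

import "fmt"

// chunk splits rows so no batch exceeds maxParams bind parameters, mirroring
// what Chunkify does with DeviceListChunker's Len/Subslice.
func chunk[T any](rows []T, paramsPerRow, maxParams int) [][]T {
    rowsPerChunk := maxParams / paramsPerRow
    var out [][]T
    for start := 0; start < len(rows); start += rowsPerChunk {
        end := start + rowsPerChunk
        if end > len(rows) {
            end = len(rows)
        }
        out = append(out, rows[start:end])
    }
    return out
}

func main() {
    rows := make([]int, 100000) // e.g. the 100,000-user update exercised in the test below
    batches := chunk(rows, 5, 65535)
    fmt.Println(len(batches), len(batches[0])) // 8 batches of at most 13107 rows
}

Because all chunks run on the same transaction inside UpsertTx, an oversized update still either lands in full or rolls back.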

View File

@ -1,6 +1,7 @@
package state
import (
"fmt"
"testing"
"github.com/matrix-org/sliding-sync/internal"
@ -105,4 +106,15 @@ func TestDeviceListTable(t *testing.T) {
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{})
// large updates work (chunker)
largeUpdate := internal.MapStringInt{}
for i := 0; i < 100000; i++ {
largeUpdate[fmt.Sprintf("user_%d", i)] = internal.DeviceListChanged
}
err = table.Upsert(userID, deviceID, largeUpdate)
assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertVal(t, "swap select did not return large items", got, largeUpdate)
}

View File

@ -7,7 +7,6 @@ import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/testutils"
)
@ -48,8 +47,8 @@ func TestJSONBMigration(t *testing.T) {
defer tx.Commit()
// insert some "invalid" data
dd := internal.DeviceData{
DeviceLists: internal.DeviceLists{
dd := OldDeviceData{
DeviceLists: OldDeviceLists{
New: map[string]int{"@💣:localhost": 1},
Sent: map[string]int{},
},

View File

@ -9,7 +9,6 @@ import (
"strings"
"github.com/fxamacker/cbor/v2"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/pressly/goose/v3"
)
@ -59,7 +58,7 @@ func upCborDeviceData(ctx context.Context, tx *sql.Tx) error {
}
for dd, jsonBytes := range deviceDatas {
var data internal.DeviceData
var data OldDeviceData
if err := json.Unmarshal(jsonBytes, &data); err != nil {
return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err)
}
@ -115,7 +114,7 @@ func downCborDeviceData(ctx context.Context, tx *sql.Tx) error {
}
for dd, cborBytes := range deviceDatas {
var data internal.DeviceData
var data OldDeviceData
if err := cbor.Unmarshal(cborBytes, &data); err != nil {
return fmt.Errorf("failed to unmarshal CBOR: %v", err)
}

View File

@ -2,13 +2,15 @@ package migrations
import (
"context"
"database/sql"
"encoding/json"
"reflect"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sqlutil"
)
func TestCBORBMigration(t *testing.T) {
@ -30,9 +32,9 @@ func TestCBORBMigration(t *testing.T) {
t.Fatal(err)
}
rowData := []internal.DeviceData{
rowData := []OldDeviceData{
{
DeviceLists: internal.DeviceLists{
DeviceLists: OldDeviceLists{
New: map[string]int{"@bob:localhost": 2},
Sent: map[string]int{},
},
@ -43,7 +45,7 @@ func TestCBORBMigration(t *testing.T) {
UserID: "@alice:localhost",
},
{
DeviceLists: internal.DeviceLists{
DeviceLists: OldDeviceLists{
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
Sent: map[string]int{"@sent:localhost": 1},
},
@ -78,9 +80,8 @@ func TestCBORBMigration(t *testing.T) {
tx.Commit()
// ensure we can now select it
table := state.NewDeviceDataTable(db)
for _, want := range rowData {
got, err := table.Select(want.UserID, want.DeviceID, false)
got, err := OldDeviceDataTableSelect(db, want.UserID, want.DeviceID, false)
if err != nil {
t.Fatal(err)
}
@ -101,7 +102,7 @@ func TestCBORBMigration(t *testing.T) {
// ensure it is what we originally inserted
for _, want := range rowData {
var got internal.DeviceData
var got OldDeviceData
var gotBytes []byte
err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes)
if err != nil {
@ -119,3 +120,66 @@ func TestCBORBMigration(t *testing.T) {
tx.Commit()
}
type OldDeviceDataRow struct {
ID int64 `db:"id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
// need to perform searches on this data.
Data []byte `db:"data"`
}
func OldDeviceDataTableSelect(db *sqlx.DB, userID, deviceID string, swap bool) (result *OldDeviceData, err error) {
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
var row OldDeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error.
return nil
}
return err
}
// unmarshal to swap
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &result); err != nil {
return err
}
result.UserID = userID
result.DeviceID = deviceID
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New
// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
writeBack.DeviceLists.New = make(map[string]int)
writeBack.ChangedBits = 0
if reflect.DeepEqual(result, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
return nil
}
// re-marshal and write
data, err := cbor.Marshal(writeBack)
if err != nil {
return err
}
_, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
return err
})
return
}

View File

@ -0,0 +1,158 @@
package migrations
import (
"context"
"database/sql"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/state"
"github.com/pressly/goose/v3"
)
type OldDeviceData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts internal.MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`
DeviceLists OldDeviceLists `json:"dl"`
// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`
UserID string
DeviceID string
}
type OldDeviceLists struct {
// map user_id -> DeviceList enum
New internal.MapStringInt `json:"n"`
Sent internal.MapStringInt `json:"s"`
}
func init() {
goose.AddMigrationContext(upDeviceListTable, downDeviceListTable)
}
func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
// create the table. It's a bit gross we need to dupe the schema here, but this is the first migration to
// add a new table like this.
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS syncv3_device_list_updates (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_state SMALLINT NOT NULL,
bucket SMALLINT NOT NULL,
UNIQUE(user_id, device_id, target_user_id, bucket)
);
-- make an index so selecting all the rows is faster
CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket);
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
-- change the data, not anything indexed like the id)
ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90);
`)
if err != nil {
return err
}
var count int
if err = tx.QueryRow(`SELECT count(*) FROM syncv3_device_data`).Scan(&count); err != nil {
return err
}
logger.Info().Int("count", count).Msg("transferring device list data for devices")
// scan for existing CBOR (streaming as the CBOR can be large) and for each row:
rows, err := tx.Query(`SELECT user_id, device_id, data FROM syncv3_device_data`)
if err != nil {
return err
}
defer rows.Close()
var userID string
var deviceID string
var data []byte
// every N seconds log an update
updateFrequency := time.Second * 2
lastUpdate := time.Now()
i := 0
for rows.Next() {
i++
if time.Since(lastUpdate) > updateFrequency {
logger.Info().Msgf("%d/%d process device list data", i, count)
lastUpdate = time.Now()
}
// * deserialise the CBOR
if err := rows.Scan(&userID, &deviceID, &data); err != nil {
return err
}
result, err := deserialiseCBOR(data)
if err != nil {
return err
}
// * transfer the device lists to the new device lists table
// uses a bulk copy that lib/pq supports
stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates", "user_id", "device_id", "target_user_id", "target_state", "bucket"))
if err != nil {
return err
}
for targetUser, targetState := range result.DeviceLists.New {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketNew); err != nil {
return err
}
}
for targetUser, targetState := range result.DeviceLists.Sent {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketSent); err != nil {
return err
}
}
if _, err = stmt.Exec(); err != nil {
return err
}
if err = stmt.Close(); err != nil {
return err
}
// * delete the device lists from the CBOR and update
result.DeviceLists = OldDeviceLists{
New: make(internal.MapStringInt),
Sent: make(internal.MapStringInt),
}
data, err := cbor.Marshal(result)
if err != nil {
return err
}
_, err = tx.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
if err != nil {
return err
}
}
return rows.Err()
}
func downDeviceListTable(ctx context.Context, tx *sql.Tx) error {
// no-op: we'll drop the device list updates but still work correctly as new/sent are still in the cbor but are empty
return nil
}
func deserialiseCBOR(data []byte) (*OldDeviceData, error) {
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return nil, err
}
var result *OldDeviceData
if err = decMode.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
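
The migration above moves existing device-list entries with lib/pq's COPY support rather than row-by-row INSERTs: tx.Prepare(pq.CopyIn(table, cols...)) opens a COPY stream, each stmt.Exec(values...) buffers one row, and a final argument-less stmt.Exec() flushes the stream before Close. A minimal standalone sketch of that pattern follows; the DSN, row values and the bucket value 1 are placeholders/assumptions.

package main

import (
    "database/sql"
    "log"

    "github.com/lib/pq"
)

func main() {
    // Placeholder DSN; any reachable Postgres works for this sketch.
    db, err := sql.Open("postgres", "postgres://user:pass@localhost/db?sslmode=disable")
    if err != nil {
        log.Fatal(err)
    }
    tx, err := db.Begin()
    if err != nil {
        log.Fatal(err)
    }
    // Equivalent to: COPY syncv3_device_list_updates (user_id, device_id, target_user_id, target_state, bucket) FROM STDIN
    stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates",
        "user_id", "device_id", "target_user_id", "target_state", "bucket"))
    if err != nil {
        log.Fatal(err)
    }
    rows := [][]interface{}{
        {"@alice:localhost", "DEVICE", "@bob:localhost", 1, 1}, // target_state/bucket as small ints; 1 assumed for the "new" bucket
    }
    for _, r := range rows {
        if _, err := stmt.Exec(r...); err != nil { // buffers one row into the COPY stream
            log.Fatal(err)
        }
    }
    if _, err := stmt.Exec(); err != nil { // empty Exec flushes the stream
        log.Fatal(err)
    }
    if err := stmt.Close(); err != nil {
        log.Fatal(err)
    }
    if err := tx.Commit(); err != nil {
        log.Fatal(err)
    }
}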

View File

@ -0,0 +1,7 @@
package migrations
import "testing"
func TestDeviceListTableMigration(t *testing.T) {
}

View File

@ -236,13 +236,12 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
partialDD := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
OTKCounts: otkCounts,
FallbackKeyTypes: fallbackKeyTypes,
DeviceLists: internal.DeviceLists{
New: deviceListChanges,
},
}
err := h.Store.DeviceDataTable.Upsert(&partialDD)
err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)

View File

@ -70,11 +70,10 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx
extRes.OTKCounts = dd.OTKCounts
hasUpdates = true
}
changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent)
if len(changed) > 0 || len(left) > 0 {
if len(dd.DeviceListChanged) > 0 || len(dd.DeviceListLeft) > 0 {
extRes.DeviceLists = &E2EEDeviceList{
Changed: changed,
Left: left,
Changed: dd.DeviceListChanged,
Left: dd.DeviceListLeft,
}
hasUpdates = true
}