mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Merge pull request #437 from matrix-org/kegan/device-data-table
Refactor device data
This commit is contained in:
commit
1551ccd7c9
@ -1,9 +1,5 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bitOTKCount int = iota
|
bitOTKCount int = iota
|
||||||
bitFallbackKeyTypes
|
bitFallbackKeyTypes
|
||||||
@ -18,9 +14,22 @@ func isBitSet(n int, bit int) bool {
|
|||||||
return val > 0
|
return val > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceData contains useful data for this user's device. This list can be expanded without prompting
|
// DeviceData contains useful data for this user's device.
|
||||||
// schema changes. These values are upserted into the database and persisted forever.
|
|
||||||
type DeviceData struct {
|
type DeviceData struct {
|
||||||
|
DeviceListChanges
|
||||||
|
DeviceKeyData
|
||||||
|
UserID string
|
||||||
|
DeviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is calculated from device_lists table
|
||||||
|
type DeviceListChanges struct {
|
||||||
|
DeviceListChanged []string
|
||||||
|
DeviceListLeft []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// This gets serialised as CBOR in device_data table
|
||||||
|
type DeviceKeyData struct {
|
||||||
// Contains the latest device_one_time_keys_count values.
|
// Contains the latest device_one_time_keys_count values.
|
||||||
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
||||||
OTKCounts MapStringInt `json:"otk"`
|
OTKCounts MapStringInt `json:"otk"`
|
||||||
@ -28,95 +37,22 @@ type DeviceData struct {
|
|||||||
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
||||||
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
|
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
|
||||||
FallbackKeyTypes []string `json:"fallback"`
|
FallbackKeyTypes []string `json:"fallback"`
|
||||||
|
|
||||||
DeviceLists DeviceLists `json:"dl"`
|
|
||||||
|
|
||||||
// bitset for which device data changes are present. They accumulate until they get swapped over
|
// bitset for which device data changes are present. They accumulate until they get swapped over
|
||||||
// when they get reset
|
// when they get reset
|
||||||
ChangedBits int `json:"c"`
|
ChangedBits int `json:"c"`
|
||||||
|
|
||||||
UserID string
|
|
||||||
DeviceID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dd *DeviceData) SetOTKCountChanged() {
|
func (dd *DeviceKeyData) SetOTKCountChanged() {
|
||||||
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
|
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dd *DeviceData) SetFallbackKeysChanged() {
|
func (dd *DeviceKeyData) SetFallbackKeysChanged() {
|
||||||
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
|
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dd *DeviceData) OTKCountChanged() bool {
|
func (dd *DeviceKeyData) OTKCountChanged() bool {
|
||||||
return isBitSet(dd.ChangedBits, bitOTKCount)
|
return isBitSet(dd.ChangedBits, bitOTKCount)
|
||||||
}
|
}
|
||||||
func (dd *DeviceData) FallbackKeysChanged() bool {
|
func (dd *DeviceKeyData) FallbackKeysChanged() bool {
|
||||||
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
|
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserDeviceKey struct {
|
|
||||||
UserID string
|
|
||||||
DeviceID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type DeviceDataMap struct {
|
|
||||||
deviceDataMu *sync.Mutex
|
|
||||||
deviceDataMap map[UserDeviceKey]*DeviceData
|
|
||||||
Pos int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap {
|
|
||||||
ddm := &DeviceDataMap{
|
|
||||||
deviceDataMu: &sync.Mutex{},
|
|
||||||
deviceDataMap: make(map[UserDeviceKey]*DeviceData),
|
|
||||||
Pos: startPos,
|
|
||||||
}
|
|
||||||
for i, dd := range devices {
|
|
||||||
ddm.deviceDataMap[UserDeviceKey{
|
|
||||||
UserID: dd.UserID,
|
|
||||||
DeviceID: dd.DeviceID,
|
|
||||||
}] = &devices[i]
|
|
||||||
}
|
|
||||||
return ddm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData {
|
|
||||||
key := UserDeviceKey{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
}
|
|
||||||
d.deviceDataMu.Lock()
|
|
||||||
defer d.deviceDataMu.Unlock()
|
|
||||||
dd, ok := d.deviceDataMap[key]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return dd
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DeviceDataMap) Update(dd DeviceData) DeviceData {
|
|
||||||
key := UserDeviceKey{
|
|
||||||
UserID: dd.UserID,
|
|
||||||
DeviceID: dd.DeviceID,
|
|
||||||
}
|
|
||||||
d.deviceDataMu.Lock()
|
|
||||||
defer d.deviceDataMu.Unlock()
|
|
||||||
existing, ok := d.deviceDataMap[key]
|
|
||||||
if !ok {
|
|
||||||
existing = &DeviceData{
|
|
||||||
UserID: dd.UserID,
|
|
||||||
DeviceID: dd.DeviceID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dd.OTKCounts != nil {
|
|
||||||
existing.OTKCounts = dd.OTKCounts
|
|
||||||
}
|
|
||||||
if dd.FallbackKeyTypes != nil {
|
|
||||||
existing.FallbackKeyTypes = dd.FallbackKeyTypes
|
|
||||||
}
|
|
||||||
existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists)
|
|
||||||
|
|
||||||
d.deviceDataMap[key] = existing
|
|
||||||
|
|
||||||
return *existing
|
|
||||||
}
|
|
||||||
|
@ -15,13 +15,14 @@ type DeviceDataRow struct {
|
|||||||
ID int64 `db:"id"`
|
ID int64 `db:"id"`
|
||||||
UserID string `db:"user_id"`
|
UserID string `db:"user_id"`
|
||||||
DeviceID string `db:"device_id"`
|
DeviceID string `db:"device_id"`
|
||||||
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
|
// This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't
|
||||||
// need to perform searches on this data.
|
// need to perform searches on this data.
|
||||||
Data []byte `db:"data"`
|
KeyData []byte `db:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceDataTable struct {
|
type DeviceDataTable struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
|
deviceListTable *DeviceListTable
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
||||||
@ -37,7 +38,8 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
|||||||
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
|
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
|
||||||
`)
|
`)
|
||||||
return &DeviceDataTable{
|
return &DeviceDataTable{
|
||||||
db: db,
|
db: db,
|
||||||
|
deviceListTable: NewDeviceListTable(db),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,6 +47,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
|||||||
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
|
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
|
||||||
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
|
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
|
||||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
|
// grab otk counts and fallback key types
|
||||||
var row DeviceDataRow
|
var row DeviceDataRow
|
||||||
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
|
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -54,32 +57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
result = &internal.DeviceData{}
|
||||||
|
var keyData *internal.DeviceKeyData
|
||||||
// unmarshal to swap
|
// unmarshal to swap
|
||||||
opts := cbor.DecOptions{
|
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
|
||||||
MaxMapPairs: 1000000000, // 1 billion :(
|
|
||||||
}
|
|
||||||
decMode, err := opts.DecMode()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = decMode.Unmarshal(row.Data, &result); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
result.UserID = userID
|
result.UserID = userID
|
||||||
result.DeviceID = deviceID
|
result.DeviceID = deviceID
|
||||||
|
if keyData != nil {
|
||||||
|
result.DeviceKeyData = *keyData
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for targetUserID, targetState := range deviceListChanges {
|
||||||
|
switch targetState {
|
||||||
|
case internal.DeviceListChanged:
|
||||||
|
result.DeviceListChanged = append(result.DeviceListChanged, targetUserID)
|
||||||
|
case internal.DeviceListLeft:
|
||||||
|
result.DeviceListLeft = append(result.DeviceListLeft, targetUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
if !swap {
|
if !swap {
|
||||||
return nil // don't swap
|
return nil // don't swap
|
||||||
}
|
}
|
||||||
// the caller will only look at sent, so make sure what is new is now in sent
|
|
||||||
result.DeviceLists.Sent = result.DeviceLists.New
|
|
||||||
|
|
||||||
// swap over the fields
|
// swap over the fields
|
||||||
writeBack := *result
|
writeBack := *keyData
|
||||||
writeBack.DeviceLists.Sent = result.DeviceLists.New
|
|
||||||
writeBack.DeviceLists.New = make(map[string]int)
|
|
||||||
writeBack.ChangedBits = 0
|
writeBack.ChangedBits = 0
|
||||||
|
|
||||||
if reflect.DeepEqual(result, &writeBack) {
|
if reflect.DeepEqual(keyData, &writeBack) {
|
||||||
// The update to the DB would be a no-op; don't bother with it.
|
// The update to the DB would be a no-op; don't bother with it.
|
||||||
// This helps reduce write usage and the contention on the unique index for
|
// This helps reduce write usage and the contention on the unique index for
|
||||||
// the device_data table.
|
// the device_data table.
|
||||||
@ -97,52 +106,43 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
|
|
||||||
_, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
||||||
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
|
func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) {
|
||||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
|
// Update device lists
|
||||||
|
if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// select what already exists
|
// select what already exists
|
||||||
var row DeviceDataRow
|
var row DeviceDataRow
|
||||||
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
|
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// unmarshal and combine
|
// unmarshal and combine
|
||||||
var tempDD internal.DeviceData
|
var keyData internal.DeviceKeyData
|
||||||
if len(row.Data) > 0 {
|
if len(row.KeyData) > 0 {
|
||||||
opts := cbor.DecOptions{
|
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
|
||||||
MaxMapPairs: 1000000000, // 1 billion :(
|
|
||||||
}
|
|
||||||
decMode, err := opts.DecMode()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = decMode.Unmarshal(row.Data, &tempDD); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if dd.FallbackKeyTypes != nil {
|
if keys.FallbackKeyTypes != nil {
|
||||||
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
|
keyData.FallbackKeyTypes = keys.FallbackKeyTypes
|
||||||
tempDD.SetFallbackKeysChanged()
|
keyData.SetFallbackKeysChanged()
|
||||||
}
|
}
|
||||||
if dd.OTKCounts != nil {
|
if keys.OTKCounts != nil {
|
||||||
tempDD.OTKCounts = dd.OTKCounts
|
keyData.OTKCounts = keys.OTKCounts
|
||||||
tempDD.SetOTKCountChanged()
|
keyData.SetOTKCountChanged()
|
||||||
}
|
}
|
||||||
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)
|
|
||||||
|
|
||||||
data, err := cbor.Marshal(tempDD)
|
data, err := cbor.Marshal(keyData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = txn.Exec(
|
_, err = txn.Exec(
|
||||||
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
|
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
|
||||||
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
|
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
|
||||||
dd.UserID, dd.DeviceID, data,
|
userID, deviceID, data,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -22,9 +22,6 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
|
|||||||
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
|
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
|
||||||
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
|
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
|
||||||
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
|
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
|
||||||
if w.DeviceLists.Sent != nil {
|
|
||||||
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests OTKCounts and FallbackKeyTypes behaviour
|
// Tests OTKCounts and FallbackKeyTypes behaviour
|
||||||
@ -40,21 +37,27 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
{
|
{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 100,
|
OTKCounts: map[string]int{
|
||||||
"bar": 92,
|
"foo": 100,
|
||||||
|
"bar": 92,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
FallbackKeyTypes: []string{"foobar"},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 99,
|
FallbackKeyTypes: []string{"foobar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
|
OTKCounts: map[string]int{
|
||||||
|
"foo": 99,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -65,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
|
|
||||||
// apply them
|
// apply them
|
||||||
for _, dd := range deltas {
|
for _, dd := range deltas {
|
||||||
err := table.Upsert(&dd)
|
err := table.Upsert(dd.UserID, dd.DeviceID, dd.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,10 +82,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
want := internal.DeviceData{
|
want := internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 99,
|
OTKCounts: map[string]int{
|
||||||
|
"foo": 99,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"foobar"},
|
||||||
},
|
},
|
||||||
FallbackKeyTypes: []string{"foobar"},
|
|
||||||
}
|
}
|
||||||
want.SetFallbackKeysChanged()
|
want.SetFallbackKeysChanged()
|
||||||
want.SetOTKCountChanged()
|
want.SetOTKCountChanged()
|
||||||
@ -95,10 +100,12 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
want := internal.DeviceData{
|
want := internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 99,
|
OTKCounts: map[string]int{
|
||||||
|
"foo": 99,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"foobar"},
|
||||||
},
|
},
|
||||||
FallbackKeyTypes: []string{"foobar"},
|
|
||||||
}
|
}
|
||||||
want.SetFallbackKeysChanged()
|
want.SetFallbackKeysChanged()
|
||||||
want.SetOTKCountChanged()
|
want.SetOTKCountChanged()
|
||||||
@ -110,168 +117,16 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
want = internal.DeviceData{
|
want = internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 99,
|
OTKCounts: map[string]int{
|
||||||
|
"foo": 99,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"foobar"},
|
||||||
},
|
},
|
||||||
FallbackKeyTypes: []string{"foobar"},
|
|
||||||
}
|
}
|
||||||
assertDeviceData(t, *got, want)
|
assertDeviceData(t, *got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests the DeviceLists field
|
|
||||||
func TestDeviceDataTableDeviceList(t *testing.T) {
|
|
||||||
db, close := connectToDB(t)
|
|
||||||
defer close()
|
|
||||||
table := NewDeviceDataTable(db)
|
|
||||||
userID := "@TestDeviceDataTableDeviceList"
|
|
||||||
deviceID := "BOB"
|
|
||||||
|
|
||||||
// these are individual updates from Synapse from /sync v2
|
|
||||||
deltas := []internal.DeviceData{
|
|
||||||
{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// apply them
|
|
||||||
for _, dd := range deltas {
|
|
||||||
err := table.Upsert(&dd)
|
|
||||||
assertNoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check we can read-only select. This doesn't modify any fields.
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
got, err := table.Select(userID, deviceID, false)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
// now swap-er-roo, which shifts everything from New into Sent.
|
|
||||||
got, err := table.Select(userID, deviceID, true)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// this is permanent, read-only views show this too.
|
|
||||||
got, err = table.Select(userID, deviceID, false)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// We now expect empty DeviceLists, as we swapped twice.
|
|
||||||
got, err = table.Select(userID, deviceID, true)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.MapStringInt{},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// get back the original state
|
|
||||||
assertNoError(t, err)
|
|
||||||
for _, dd := range deltas {
|
|
||||||
err = table.Upsert(&dd)
|
|
||||||
assertNoError(t, err)
|
|
||||||
}
|
|
||||||
// Move original state to Sent by swapping
|
|
||||||
got, err = table.Select(userID, deviceID, true)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
// Add new entries to New before acknowledging Sent
|
|
||||||
err = table.Upsert(&internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
assertNoError(t, err)
|
|
||||||
|
|
||||||
// Reading without swapping does not move New->Sent, so returns the previous value
|
|
||||||
got, err = table.Select(userID, deviceID, false)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Append even more items to New
|
|
||||||
err = table.Upsert(&internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
New: internal.ToDeviceListChangesMap([]string{"dave"}, []string{"dave"}),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
assertNoError(t, err)
|
|
||||||
|
|
||||||
// Now swap: all the combined items in New go into Sent
|
|
||||||
got, err = table.Select(userID, deviceID, true)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Swapping again clears Sent out, and since nothing is in New we get an empty list
|
|
||||||
got, err = table.Select(userID, deviceID, true)
|
|
||||||
assertNoError(t, err)
|
|
||||||
assertDeviceData(t, *got, internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
DeviceLists: internal.DeviceLists{
|
|
||||||
Sent: internal.MapStringInt{},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// delete everything, no data returned
|
|
||||||
assertNoError(t, table.DeleteDevice(userID, deviceID))
|
|
||||||
got, err = table.Select(userID, deviceID, false)
|
|
||||||
assertNoError(t, err)
|
|
||||||
if got != nil {
|
|
||||||
t.Errorf("wanted no data, got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeviceDataTableBitset(t *testing.T) {
|
func TestDeviceDataTableBitset(t *testing.T) {
|
||||||
db, close := connectToDB(t)
|
db, close := connectToDB(t)
|
||||||
defer close()
|
defer close()
|
||||||
@ -281,29 +136,32 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
otkUpdate := internal.DeviceData{
|
otkUpdate := internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
OTKCounts: map[string]int{
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
"foo": 100,
|
OTKCounts: map[string]int{
|
||||||
"bar": 92,
|
"foo": 100,
|
||||||
|
"bar": 92,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
|
||||||
}
|
}
|
||||||
fallbakKeyUpdate := internal.DeviceData{
|
fallbakKeyUpdate := internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
FallbackKeyTypes: []string{"foo", "bar"},
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
FallbackKeyTypes: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
bothUpdate := internal.DeviceData{
|
bothUpdate := internal.DeviceData{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
FallbackKeyTypes: []string{"both"},
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
OTKCounts: map[string]int{
|
FallbackKeyTypes: []string{"both"},
|
||||||
"both": 100,
|
OTKCounts: map[string]int{
|
||||||
|
"both": 100,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := table.Upsert(&otkUpdate)
|
err := table.Upsert(otkUpdate.UserID, otkUpdate.DeviceID, otkUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
got, err := table.Select(userID, deviceID, true)
|
got, err := table.Select(userID, deviceID, true)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
@ -315,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
otkUpdate.ChangedBits = 0
|
otkUpdate.ChangedBits = 0
|
||||||
assertDeviceData(t, *got, otkUpdate)
|
assertDeviceData(t, *got, otkUpdate)
|
||||||
// now same for fallback keys, but we won't swap them so it should return those diffs
|
// now same for fallback keys, but we won't swap them so it should return those diffs
|
||||||
err = table.Upsert(&fallbakKeyUpdate)
|
err = table.Upsert(fallbakKeyUpdate.UserID, fallbakKeyUpdate.DeviceID, fallbakKeyUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
||||||
got, err = table.Select(userID, deviceID, false)
|
got, err = table.Select(userID, deviceID, false)
|
||||||
@ -327,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
fallbakKeyUpdate.SetFallbackKeysChanged()
|
fallbakKeyUpdate.SetFallbackKeysChanged()
|
||||||
assertDeviceData(t, *got, fallbakKeyUpdate)
|
assertDeviceData(t, *got, fallbakKeyUpdate)
|
||||||
// updating both works
|
// updating both works
|
||||||
err = table.Upsert(&bothUpdate)
|
err = table.Upsert(bothUpdate.UserID, bothUpdate.DeviceID, bothUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
got, err = table.Select(userID, deviceID, true)
|
got, err = table.Select(userID, deviceID, true)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
|
163
state/device_list_table.go
Normal file
163
state/device_list_table.go
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/matrix-org/sliding-sync/internal"
|
||||||
|
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BucketNew = 1
|
||||||
|
BucketSent = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeviceListRow struct {
|
||||||
|
UserID string `db:"user_id"`
|
||||||
|
DeviceID string `db:"device_id"`
|
||||||
|
TargetUserID string `db:"target_user_id"`
|
||||||
|
TargetState int `db:"target_state"`
|
||||||
|
Bucket int `db:"bucket"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceListTable struct {
|
||||||
|
db *sqlx.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeviceListTable(db *sqlx.DB) *DeviceListTable {
|
||||||
|
db.MustExec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS syncv3_device_list_updates (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
target_user_id TEXT NOT NULL,
|
||||||
|
target_state SMALLINT NOT NULL,
|
||||||
|
bucket SMALLINT NOT NULL,
|
||||||
|
UNIQUE(user_id, device_id, target_user_id, bucket)
|
||||||
|
);
|
||||||
|
-- make an index so selecting all the rows is faster
|
||||||
|
CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket);
|
||||||
|
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
|
||||||
|
-- change the data, not anything indexed like the id)
|
||||||
|
ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90);
|
||||||
|
`)
|
||||||
|
return &DeviceListTable{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DeviceListTable) Upsert(userID, deviceID string, deviceListChanges map[string]int) (err error) {
|
||||||
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
|
return t.UpsertTx(txn, userID, deviceID, deviceListChanges)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upsert new device list changes.
|
||||||
|
func (t *DeviceListTable) UpsertTx(txn *sqlx.Tx, userID, deviceID string, deviceListChanges map[string]int) (err error) {
|
||||||
|
if len(deviceListChanges) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var deviceListRows []DeviceListRow
|
||||||
|
for targetUserID, targetState := range deviceListChanges {
|
||||||
|
if targetState != internal.DeviceListChanged && targetState != internal.DeviceListLeft {
|
||||||
|
sentry.CaptureException(fmt.Errorf("DeviceListTable.Upsert invalid target_state: %d this is a programming error", targetState))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deviceListRows = append(deviceListRows, DeviceListRow{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
TargetUserID: targetUserID,
|
||||||
|
TargetState: targetState,
|
||||||
|
Bucket: BucketNew,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
chunks := sqlutil.Chunkify(5, MaxPostgresParameters, DeviceListChunker(deviceListRows))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
_, err := txn.NamedExec(`
|
||||||
|
INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket)
|
||||||
|
VALUES(:user_id, :device_id, :target_user_id, :target_state, :bucket)
|
||||||
|
ON CONFLICT (user_id, device_id, target_user_id, bucket) DO UPDATE SET target_state = EXCLUDED.target_state`, chunk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DeviceListTable) Select(userID, deviceID string, swap bool) (result internal.MapStringInt, err error) {
|
||||||
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
|
result, err = t.SelectTx(txn, userID, deviceID, swap)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select device list changes for this client. Returns a map of user_id => change enum.
|
||||||
|
func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap bool) (result internal.MapStringInt, err error) {
|
||||||
|
if !swap {
|
||||||
|
// read only view, just return what we previously sent and don't do anything else.
|
||||||
|
return t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketSent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the now acknowledged 'sent' data
|
||||||
|
_, err = txn.Exec(`DELETE FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, BucketSent)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// grab any 'new' updates and atomically mark these as 'sent'.
|
||||||
|
// NB: we must not SELECT then UPDATE, because a 'new' row could be inserted after the SELECT and before the UPDATE, which
|
||||||
|
// would then be incorrectly moved to 'sent' without being returned to the client, dropping the data. This happens because
|
||||||
|
// the default transaction level is 'read committed', which /allows/ nonrepeatable reads which is:
|
||||||
|
// > A transaction re-reads data it has previously read and finds that data has been modified by another transaction (that committed since the initial read).
|
||||||
|
// We could change the isolation level but this incurs extra performance costs in addition to serialisation errors which
|
||||||
|
// need to be handled. It's easier to just use UPDATE .. RETURNING. Note that we don't require UPDATE .. RETURNING to be
|
||||||
|
// atomic in any way, it's just that we need to guarantee each things SELECTed is also UPDATEd (so in the scenario above,
|
||||||
|
// we don't care if the SELECT includes or excludes the 'new' row, but if it is SELECTed it MUST be UPDATEd).
|
||||||
|
rows, err := txn.Query(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4 RETURNING target_user_id, target_state`, BucketSent, userID, deviceID, BucketNew)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
result = make(internal.MapStringInt)
|
||||||
|
var targetUserID string
|
||||||
|
var targetState int
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&targetUserID, &targetState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[targetUserID] = targetState
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) {
|
||||||
|
rows, err := txn.Query(`SELECT target_user_id, target_state FROM syncv3_device_list_updates WHERE user_id=$1 AND device_id=$2 AND bucket=$3`, userID, deviceID, bucket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
result = make(internal.MapStringInt)
|
||||||
|
var targetUserID string
|
||||||
|
var targetState int
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&targetUserID, &targetState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[targetUserID] = targetState
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceListChunker []DeviceListRow
|
||||||
|
|
||||||
|
func (c DeviceListChunker) Len() int {
|
||||||
|
return len(c)
|
||||||
|
}
|
||||||
|
func (c DeviceListChunker) Subslice(i, j int) sqlutil.Chunker {
|
||||||
|
return c[i:j]
|
||||||
|
}
|
120
state/device_list_table_test.go
Normal file
120
state/device_list_table_test.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/sliding-sync/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tests the DeviceLists table
|
||||||
|
func TestDeviceListTable(t *testing.T) {
|
||||||
|
db, close := connectToDB(t)
|
||||||
|
defer close()
|
||||||
|
table := NewDeviceListTable(db)
|
||||||
|
userID := "@TestDeviceListTable"
|
||||||
|
deviceID := "BOB"
|
||||||
|
|
||||||
|
// these are individual updates from Synapse from /sync v2
|
||||||
|
deltas := []internal.MapStringInt{
|
||||||
|
{
|
||||||
|
"alice": internal.DeviceListChanged,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// apply them
|
||||||
|
for _, dd := range deltas {
|
||||||
|
err := table.Upsert(userID, deviceID, dd)
|
||||||
|
assertNoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check we can read-only select. This doesn't modify any fields.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
got, err := table.Select(userID, deviceID, false)
|
||||||
|
assertNoError(t, err)
|
||||||
|
// until we "swap" we don't consume the New entries
|
||||||
|
assertVal(t, "unexpected data on swapless select", got, internal.MapStringInt{})
|
||||||
|
}
|
||||||
|
// now swap-er-roo, which shifts everything from New into Sent.
|
||||||
|
got, err := table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "did not select what was upserted on swap select", got, internal.MapStringInt{
|
||||||
|
"alice": internal.DeviceListChanged,
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
// this is permanent, read-only views show this too.
|
||||||
|
got, err = table.Select(userID, deviceID, false)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swapless select did not return the same data as before", got, internal.MapStringInt{
|
||||||
|
"alice": internal.DeviceListChanged,
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
// We now expect empty DeviceLists, as we swapped twice.
|
||||||
|
got, err = table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swap select did not return nothing", got, internal.MapStringInt{})
|
||||||
|
|
||||||
|
// get back the original state
|
||||||
|
assertNoError(t, err)
|
||||||
|
for _, dd := range deltas {
|
||||||
|
err = table.Upsert(userID, deviceID, dd)
|
||||||
|
assertNoError(t, err)
|
||||||
|
}
|
||||||
|
// Move original state to Sent by swapping
|
||||||
|
got, err = table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "did not select what was upserted on swap select", got, internal.MapStringInt{
|
||||||
|
"alice": internal.DeviceListChanged,
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
})
|
||||||
|
// Add new entries to New before acknowledging Sent
|
||||||
|
err = table.Upsert(userID, deviceID, internal.MapStringInt{
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
"charlie": internal.DeviceListLeft,
|
||||||
|
})
|
||||||
|
assertNoError(t, err)
|
||||||
|
|
||||||
|
// Reading without swapping does not move New->Sent, so returns the previous value
|
||||||
|
got, err = table.Select(userID, deviceID, false)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swapless select did not return the same data as before", got, internal.MapStringInt{
|
||||||
|
"alice": internal.DeviceListChanged,
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Append even more items to New
|
||||||
|
err = table.Upsert(userID, deviceID, internal.MapStringInt{
|
||||||
|
"charlie": internal.DeviceListChanged, // we previously said "left" for charlie, so as "changed" is newer, we should see "changed"
|
||||||
|
"dave": internal.DeviceListLeft,
|
||||||
|
})
|
||||||
|
assertNoError(t, err)
|
||||||
|
|
||||||
|
// Now swap: all the combined items in New go into Sent
|
||||||
|
got, err = table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{
|
||||||
|
"💣": internal.DeviceListChanged,
|
||||||
|
"charlie": internal.DeviceListChanged,
|
||||||
|
"dave": internal.DeviceListLeft,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Swapping again clears Sent out, and since nothing is in New we get an empty list
|
||||||
|
got, err = table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swap select did not return combined new items", got, internal.MapStringInt{})
|
||||||
|
|
||||||
|
// large updates work (chunker)
|
||||||
|
largeUpdate := internal.MapStringInt{}
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
largeUpdate[fmt.Sprintf("user_%d", i)] = internal.DeviceListChanged
|
||||||
|
}
|
||||||
|
err = table.Upsert(userID, deviceID, largeUpdate)
|
||||||
|
assertNoError(t, err)
|
||||||
|
got, err = table.Select(userID, deviceID, true)
|
||||||
|
assertNoError(t, err)
|
||||||
|
assertVal(t, "swap select did not return large items", got, largeUpdate)
|
||||||
|
}
|
@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/matrix-org/sliding-sync/internal"
|
|
||||||
"github.com/matrix-org/sliding-sync/testutils"
|
"github.com/matrix-org/sliding-sync/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,8 +47,8 @@ func TestJSONBMigration(t *testing.T) {
|
|||||||
defer tx.Commit()
|
defer tx.Commit()
|
||||||
|
|
||||||
// insert some "invalid" data
|
// insert some "invalid" data
|
||||||
dd := internal.DeviceData{
|
dd := OldDeviceData{
|
||||||
DeviceLists: internal.DeviceLists{
|
DeviceLists: OldDeviceLists{
|
||||||
New: map[string]int{"@💣:localhost": 1},
|
New: map[string]int{"@💣:localhost": 1},
|
||||||
Sent: map[string]int{},
|
Sent: map[string]int{},
|
||||||
},
|
},
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/matrix-org/sliding-sync/internal"
|
|
||||||
"github.com/matrix-org/sliding-sync/sync2"
|
"github.com/matrix-org/sliding-sync/sync2"
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
)
|
)
|
||||||
@ -59,7 +58,7 @@ func upCborDeviceData(ctx context.Context, tx *sql.Tx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for dd, jsonBytes := range deviceDatas {
|
for dd, jsonBytes := range deviceDatas {
|
||||||
var data internal.DeviceData
|
var data OldDeviceData
|
||||||
if err := json.Unmarshal(jsonBytes, &data); err != nil {
|
if err := json.Unmarshal(jsonBytes, &data); err != nil {
|
||||||
return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err)
|
return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err)
|
||||||
}
|
}
|
||||||
@ -115,7 +114,7 @@ func downCborDeviceData(ctx context.Context, tx *sql.Tx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for dd, cborBytes := range deviceDatas {
|
for dd, cborBytes := range deviceDatas {
|
||||||
var data internal.DeviceData
|
var data OldDeviceData
|
||||||
if err := cbor.Unmarshal(cborBytes, &data); err != nil {
|
if err := cbor.Unmarshal(cborBytes, &data); err != nil {
|
||||||
return fmt.Errorf("failed to unmarshal CBOR: %v", err)
|
return fmt.Errorf("failed to unmarshal CBOR: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,15 @@ package migrations
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/matrix-org/sliding-sync/internal"
|
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||||
"github.com/matrix-org/sliding-sync/state"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCBORBMigration(t *testing.T) {
|
func TestCBORBMigration(t *testing.T) {
|
||||||
@ -30,9 +32,9 @@ func TestCBORBMigration(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rowData := []internal.DeviceData{
|
rowData := []OldDeviceData{
|
||||||
{
|
{
|
||||||
DeviceLists: internal.DeviceLists{
|
DeviceLists: OldDeviceLists{
|
||||||
New: map[string]int{"@bob:localhost": 2},
|
New: map[string]int{"@bob:localhost": 2},
|
||||||
Sent: map[string]int{},
|
Sent: map[string]int{},
|
||||||
},
|
},
|
||||||
@ -43,7 +45,7 @@ func TestCBORBMigration(t *testing.T) {
|
|||||||
UserID: "@alice:localhost",
|
UserID: "@alice:localhost",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
DeviceLists: internal.DeviceLists{
|
DeviceLists: OldDeviceLists{
|
||||||
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
|
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
|
||||||
Sent: map[string]int{"@sent:localhost": 1},
|
Sent: map[string]int{"@sent:localhost": 1},
|
||||||
},
|
},
|
||||||
@ -78,9 +80,8 @@ func TestCBORBMigration(t *testing.T) {
|
|||||||
tx.Commit()
|
tx.Commit()
|
||||||
|
|
||||||
// ensure we can now select it
|
// ensure we can now select it
|
||||||
table := state.NewDeviceDataTable(db)
|
|
||||||
for _, want := range rowData {
|
for _, want := range rowData {
|
||||||
got, err := table.Select(want.UserID, want.DeviceID, false)
|
got, err := OldDeviceDataTableSelect(db, want.UserID, want.DeviceID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -101,7 +102,7 @@ func TestCBORBMigration(t *testing.T) {
|
|||||||
|
|
||||||
// ensure it is what we originally inserted
|
// ensure it is what we originally inserted
|
||||||
for _, want := range rowData {
|
for _, want := range rowData {
|
||||||
var got internal.DeviceData
|
var got OldDeviceData
|
||||||
var gotBytes []byte
|
var gotBytes []byte
|
||||||
err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes)
|
err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -119,3 +120,66 @@ func TestCBORBMigration(t *testing.T) {
|
|||||||
|
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OldDeviceDataRow struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
UserID string `db:"user_id"`
|
||||||
|
DeviceID string `db:"device_id"`
|
||||||
|
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
|
||||||
|
// need to perform searches on this data.
|
||||||
|
Data []byte `db:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func OldDeviceDataTableSelect(db *sqlx.DB, userID, deviceID string, swap bool) (result *OldDeviceData, err error) {
|
||||||
|
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||||
|
var row OldDeviceDataRow
|
||||||
|
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// if there is no device data for this user, it's not an error.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// unmarshal to swap
|
||||||
|
opts := cbor.DecOptions{
|
||||||
|
MaxMapPairs: 1000000000, // 1 billion :(
|
||||||
|
}
|
||||||
|
decMode, err := opts.DecMode()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = decMode.Unmarshal(row.Data, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result.UserID = userID
|
||||||
|
result.DeviceID = deviceID
|
||||||
|
if !swap {
|
||||||
|
return nil // don't swap
|
||||||
|
}
|
||||||
|
// the caller will only look at sent, so make sure what is new is now in sent
|
||||||
|
result.DeviceLists.Sent = result.DeviceLists.New
|
||||||
|
|
||||||
|
// swap over the fields
|
||||||
|
writeBack := *result
|
||||||
|
writeBack.DeviceLists.Sent = result.DeviceLists.New
|
||||||
|
writeBack.DeviceLists.New = make(map[string]int)
|
||||||
|
writeBack.ChangedBits = 0
|
||||||
|
|
||||||
|
if reflect.DeepEqual(result, &writeBack) {
|
||||||
|
// The update to the DB would be a no-op; don't bother with it.
|
||||||
|
// This helps reduce write usage and the contention on the unique index for
|
||||||
|
// the device_data table.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// re-marshal and write
|
||||||
|
data, err := cbor.Marshal(writeBack)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -151,6 +151,7 @@ func TestClearStuckInvites(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
// users in room B (bob) and F (doris) should be reset.
|
// users in room B (bob) and F (doris) should be reset.
|
||||||
tokens, err := tokensTable.TokenForEachDevice(tx)
|
tokens, err := tokensTable.TokenForEachDevice(tx)
|
||||||
|
190
state/migrations/20240517104423_device_list_table.go
Normal file
190
state/migrations/20240517104423_device_list_table.go
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/matrix-org/sliding-sync/internal"
|
||||||
|
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||||
|
"github.com/matrix-org/sliding-sync/state"
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OldDeviceData struct {
|
||||||
|
// Contains the latest device_one_time_keys_count values.
|
||||||
|
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
||||||
|
OTKCounts internal.MapStringInt `json:"otk"`
|
||||||
|
// Contains the latest device_unused_fallback_key_types value
|
||||||
|
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
|
||||||
|
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
|
||||||
|
FallbackKeyTypes []string `json:"fallback"`
|
||||||
|
|
||||||
|
DeviceLists OldDeviceLists `json:"dl"`
|
||||||
|
|
||||||
|
// bitset for which device data changes are present. They accumulate until they get swapped over
|
||||||
|
// when they get reset
|
||||||
|
ChangedBits int `json:"c"`
|
||||||
|
|
||||||
|
UserID string
|
||||||
|
DeviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OldDeviceLists struct {
|
||||||
|
// map user_id -> DeviceList enum
|
||||||
|
New internal.MapStringInt `json:"n"`
|
||||||
|
Sent internal.MapStringInt `json:"s"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
goose.AddMigrationContext(upDeviceListTable, downDeviceListTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
// create the table. It's a bit gross we need to dupe the schema here, but this is the first migration to
|
||||||
|
// add a new table like this.
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS syncv3_device_list_updates (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
target_user_id TEXT NOT NULL,
|
||||||
|
target_state SMALLINT NOT NULL,
|
||||||
|
bucket SMALLINT NOT NULL,
|
||||||
|
UNIQUE(user_id, device_id, target_user_id, bucket)
|
||||||
|
);`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err = tx.QueryRow(`SELECT count(*) FROM syncv3_device_data`).Scan(&count); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Info().Int("count", count).Msg("transferring device list data for devices")
|
||||||
|
|
||||||
|
// scan for existing CBOR (streaming as the CBOR with cursors as it can be large)
|
||||||
|
_, err = tx.Exec(`DECLARE device_data_migration_cursor CURSOR FOR SELECT user_id, device_id, data FROM syncv3_device_data`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Exec("CLOSE device_data_migration_cursor")
|
||||||
|
var userID string
|
||||||
|
var deviceID string
|
||||||
|
var data []byte
|
||||||
|
// every N seconds log an update
|
||||||
|
updateFrequency := time.Second * 2
|
||||||
|
lastUpdate := time.Now()
|
||||||
|
i := 0
|
||||||
|
for {
|
||||||
|
// logging
|
||||||
|
i++
|
||||||
|
if time.Since(lastUpdate) > updateFrequency {
|
||||||
|
logger.Info().Msgf("%d/%d process device list data", i, count)
|
||||||
|
lastUpdate = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.QueryRow(
|
||||||
|
`FETCH NEXT FROM device_data_migration_cursor`,
|
||||||
|
).Scan(&userID, &deviceID, &data); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// End of rows.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// * deserialise the CBOR
|
||||||
|
result, err := deserialiseCBOR(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// * transfer the device lists to the new device lists table
|
||||||
|
var deviceListRows []state.DeviceListRow
|
||||||
|
for targetUser, targetState := range result.DeviceLists.New {
|
||||||
|
deviceListRows = append(deviceListRows, state.DeviceListRow{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
TargetUserID: targetUser,
|
||||||
|
TargetState: targetState,
|
||||||
|
Bucket: state.BucketNew,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for targetUser, targetState := range result.DeviceLists.Sent {
|
||||||
|
deviceListRows = append(deviceListRows, state.DeviceListRow{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
TargetUserID: targetUser,
|
||||||
|
TargetState: targetState,
|
||||||
|
Bucket: state.BucketSent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(deviceListRows) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chunks := sqlutil.Chunkify(5, state.MaxPostgresParameters, state.DeviceListChunker(deviceListRows))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
var placeholders []string
|
||||||
|
var vals []interface{}
|
||||||
|
listChunk := chunk.(state.DeviceListChunker)
|
||||||
|
for i, deviceListRow := range listChunk {
|
||||||
|
placeholders = append(placeholders, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d)",
|
||||||
|
i*5+1,
|
||||||
|
i*5+2,
|
||||||
|
i*5+3,
|
||||||
|
i*5+4,
|
||||||
|
i*5+5,
|
||||||
|
))
|
||||||
|
vals = append(vals, deviceListRow.UserID, deviceListRow.DeviceID, deviceListRow.TargetUserID, deviceListRow.TargetState, deviceListRow.Bucket)
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
`INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES %s`,
|
||||||
|
strings.Join(placeholders, ","),
|
||||||
|
)
|
||||||
|
_, err = tx.ExecContext(ctx, query, vals...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to bulk insert: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// * delete the device lists from the CBOR and update
|
||||||
|
result.DeviceLists = OldDeviceLists{
|
||||||
|
New: make(internal.MapStringInt),
|
||||||
|
Sent: make(internal.MapStringInt),
|
||||||
|
}
|
||||||
|
data, err := cbor.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.ExecContext(ctx, `UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func downDeviceListTable(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
// no-op: we'll drop the device list updates but still work correctly as new/sent are still in the cbor but are empty.
|
||||||
|
// This will lose some device list updates.
|
||||||
|
_, err := tx.Exec(`DROP TABLE IF EXISTS syncv3_device_list_updates`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func deserialiseCBOR(data []byte) (*OldDeviceData, error) {
|
||||||
|
opts := cbor.DecOptions{
|
||||||
|
MaxMapPairs: 1000000000, // 1 billion :(
|
||||||
|
}
|
||||||
|
decMode, err := opts.DecMode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result *OldDeviceData
|
||||||
|
if err = decMode.Unmarshal(data, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
159
state/migrations/20240517104423_device_list_table_test.go
Normal file
159
state/migrations/20240517104423_device_list_table_test.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/matrix-org/sliding-sync/internal"
|
||||||
|
"github.com/matrix-org/sliding-sync/state"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeviceListTableMigration(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
db, close := connectToDB(t)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// Create the table in the old format (data = JSONB instead of BYTEA)
|
||||||
|
// and insert some data: we'll make sure that this data is preserved
|
||||||
|
// after migrating.
|
||||||
|
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS syncv3_device_data (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
data BYTEA NOT NULL,
|
||||||
|
UNIQUE(user_id, device_id)
|
||||||
|
);`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create table: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert old data
|
||||||
|
rowData := []OldDeviceData{
|
||||||
|
{
|
||||||
|
DeviceLists: OldDeviceLists{
|
||||||
|
New: map[string]int{"@bob:localhost": 2},
|
||||||
|
Sent: map[string]int{},
|
||||||
|
},
|
||||||
|
ChangedBits: 2,
|
||||||
|
OTKCounts: map[string]int{"bar": 42},
|
||||||
|
FallbackKeyTypes: []string{"narp"},
|
||||||
|
DeviceID: "ALICE",
|
||||||
|
UserID: "@alice:localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeviceLists: OldDeviceLists{
|
||||||
|
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
|
||||||
|
Sent: map[string]int{"@sent:localhost": 1},
|
||||||
|
},
|
||||||
|
OTKCounts: map[string]int{"foo": 100},
|
||||||
|
FallbackKeyTypes: []string{"yep"},
|
||||||
|
DeviceID: "BOB",
|
||||||
|
UserID: "@bob:localhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, data := range rowData {
|
||||||
|
blob, err := cbor.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, data.UserID, data.DeviceID, blob)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// now migrate and ensure we didn't lose any data
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = upDeviceListTable(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
wantSents := []internal.DeviceData{
|
||||||
|
{
|
||||||
|
UserID: "@alice:localhost",
|
||||||
|
DeviceID: "ALICE",
|
||||||
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
|
OTKCounts: internal.MapStringInt{
|
||||||
|
"bar": 42,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"narp"},
|
||||||
|
ChangedBits: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserID: "@bob:localhost",
|
||||||
|
DeviceID: "BOB",
|
||||||
|
DeviceListChanges: internal.DeviceListChanges{
|
||||||
|
DeviceListChanged: []string{"@sent:localhost"},
|
||||||
|
},
|
||||||
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
|
OTKCounts: internal.MapStringInt{
|
||||||
|
"foo": 100,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"yep"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
table := state.NewDeviceDataTable(db)
|
||||||
|
for _, wantSent := range wantSents {
|
||||||
|
gotSent, err := table.Select(wantSent.UserID, wantSent.DeviceID, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertVal(t, "'sent' data was corrupted during the migration", *gotSent, wantSent)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantNews := []internal.DeviceData{
|
||||||
|
{
|
||||||
|
UserID: "@alice:localhost",
|
||||||
|
DeviceID: "ALICE",
|
||||||
|
DeviceListChanges: internal.DeviceListChanges{
|
||||||
|
DeviceListLeft: []string{"@bob:localhost"},
|
||||||
|
},
|
||||||
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
|
OTKCounts: internal.MapStringInt{
|
||||||
|
"bar": 42,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"narp"},
|
||||||
|
ChangedBits: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserID: "@bob:localhost",
|
||||||
|
DeviceID: "BOB",
|
||||||
|
DeviceListChanges: internal.DeviceListChanges{
|
||||||
|
DeviceListChanged: []string{"@💣:localhost"},
|
||||||
|
DeviceListLeft: []string{"@bomb:localhost"},
|
||||||
|
},
|
||||||
|
DeviceKeyData: internal.DeviceKeyData{
|
||||||
|
OTKCounts: internal.MapStringInt{
|
||||||
|
"foo": 100,
|
||||||
|
},
|
||||||
|
FallbackKeyTypes: []string{"yep"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, wantNew := range wantNews {
|
||||||
|
gotNew, err := table.Select(wantNew.UserID, wantNew.DeviceID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertVal(t, "'new' data was corrupted during the migration", *gotNew, wantNew)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertVal(t *testing.T, msg string, got, want interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("%s: got\n%#v\nwant\n%#v", msg, got, want)
|
||||||
|
}
|
||||||
|
}
|
@ -232,17 +232,10 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
h.e2eeWorkerPool.Queue(func() {
|
h.e2eeWorkerPool.Queue(func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// some of these fields may be set
|
err := h.Store.DeviceDataTable.Upsert(userID, deviceID, internal.DeviceKeyData{
|
||||||
partialDD := internal.DeviceData{
|
|
||||||
UserID: userID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
OTKCounts: otkCounts,
|
OTKCounts: otkCounts,
|
||||||
FallbackKeyTypes: fallbackKeyTypes,
|
FallbackKeyTypes: fallbackKeyTypes,
|
||||||
DeviceLists: internal.DeviceLists{
|
}, deviceListChanges)
|
||||||
New: deviceListChanges,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err := h.Store.DeviceDataTable.Upsert(&partialDD)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||||
|
@ -70,11 +70,10 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx
|
|||||||
extRes.OTKCounts = dd.OTKCounts
|
extRes.OTKCounts = dd.OTKCounts
|
||||||
hasUpdates = true
|
hasUpdates = true
|
||||||
}
|
}
|
||||||
changed, left := internal.DeviceListChangesArrays(dd.DeviceLists.Sent)
|
if len(dd.DeviceListChanged) > 0 || len(dd.DeviceListLeft) > 0 {
|
||||||
if len(changed) > 0 || len(left) > 0 {
|
|
||||||
extRes.DeviceLists = &E2EEDeviceList{
|
extRes.DeviceLists = &E2EEDeviceList{
|
||||||
Changed: changed,
|
Changed: dd.DeviceListChanged,
|
||||||
Left: left,
|
Left: dd.DeviceListLeft,
|
||||||
}
|
}
|
||||||
hasUpdates = true
|
hasUpdates = true
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package testutils
|
package testutils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Quiet = false
|
var Quiet = false
|
||||||
@ -64,7 +66,9 @@ func PrepareDBConnectionString() (connStr string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
_, err = db.Exec(`DROP SCHEMA public CASCADE;
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err = db.ExecContext(ctx, `DROP SCHEMA public CASCADE;
|
||||||
CREATE SCHEMA public;`)
|
CREATE SCHEMA public;`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user