mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Some review comments; swap to UPDATE..RETURNING
This commit is contained in:
parent
35c9fd4d95
commit
fdbebaea68
@ -107,15 +107,15 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
||||||
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) {
|
func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) {
|
||||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||||
// Update device lists
|
// Update device lists
|
||||||
if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil {
|
if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// select what already exists
|
// select what already exists
|
||||||
var row DeviceDataRow
|
var row DeviceDataRow
|
||||||
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
|
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -126,12 +126,12 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if dd.FallbackKeyTypes != nil {
|
if keys.FallbackKeyTypes != nil {
|
||||||
keyData.FallbackKeyTypes = dd.FallbackKeyTypes
|
keyData.FallbackKeyTypes = keys.FallbackKeyTypes
|
||||||
keyData.SetFallbackKeysChanged()
|
keyData.SetFallbackKeysChanged()
|
||||||
}
|
}
|
||||||
if dd.OTKCounts != nil {
|
if keys.OTKCounts != nil {
|
||||||
keyData.OTKCounts = dd.OTKCounts
|
keyData.OTKCounts = keys.OTKCounts
|
||||||
keyData.SetOTKCountChanged()
|
keyData.SetOTKCountChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[
|
|||||||
_, err = txn.Exec(
|
_, err = txn.Exec(
|
||||||
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
|
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
|
||||||
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
|
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
|
||||||
dd.UserID, dd.DeviceID, data,
|
userID, deviceID, data,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -68,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
|
|||||||
|
|
||||||
// apply them
|
// apply them
|
||||||
for _, dd := range deltas {
|
for _, dd := range deltas {
|
||||||
err := table.Upsert(&dd, nil)
|
err := table.Upsert(dd.UserID, dd.DeviceID, dd.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := table.Upsert(&otkUpdate, nil)
|
err := table.Upsert(otkUpdate.UserID, otkUpdate.DeviceID, otkUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
got, err := table.Select(userID, deviceID, true)
|
got, err := table.Select(userID, deviceID, true)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
@ -173,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
otkUpdate.ChangedBits = 0
|
otkUpdate.ChangedBits = 0
|
||||||
assertDeviceData(t, *got, otkUpdate)
|
assertDeviceData(t, *got, otkUpdate)
|
||||||
// now same for fallback keys, but we won't swap them so it should return those diffs
|
// now same for fallback keys, but we won't swap them so it should return those diffs
|
||||||
err = table.Upsert(&fallbakKeyUpdate, nil)
|
err = table.Upsert(fallbakKeyUpdate.UserID, fallbakKeyUpdate.DeviceID, fallbakKeyUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
||||||
got, err = table.Select(userID, deviceID, false)
|
got, err = table.Select(userID, deviceID, false)
|
||||||
@ -185,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
|||||||
fallbakKeyUpdate.SetFallbackKeysChanged()
|
fallbakKeyUpdate.SetFallbackKeysChanged()
|
||||||
assertDeviceData(t, *got, fallbakKeyUpdate)
|
assertDeviceData(t, *got, fallbakKeyUpdate)
|
||||||
// updating both works
|
// updating both works
|
||||||
err = table.Upsert(&bothUpdate, nil)
|
err = table.Upsert(bothUpdate.UserID, bothUpdate.DeviceID, bothUpdate.DeviceKeyData, nil)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
got, err = table.Select(userID, deviceID, true)
|
got, err = table.Select(userID, deviceID, true)
|
||||||
assertNoError(t, err)
|
assertNoError(t, err)
|
||||||
|
@ -109,15 +109,30 @@ func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap b
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// grab any 'new' updates
|
// grab any 'new' updates and atomically mark these as 'sent'.
|
||||||
result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew)
|
// NB: we must not SELECT then UPDATE, because a 'new' row could be inserted after the SELECT and before the UPDATE, which
|
||||||
|
// would then be incorrectly moved to 'sent' without being returned to the client, dropping the data. This happens because
|
||||||
|
// the default transaction level is 'read committed', which /allows/ nonrepeatable reads which is:
|
||||||
|
// > A transaction re-reads data it has previously read and finds that data has been modified by another transaction (that committed since the initial read).
|
||||||
|
// We could change the isolation level but this incurs extra performance costs in addition to serialisation errors which
|
||||||
|
// need to be handled. It's easier to just use UPDATE .. RETURNING. Note that we don't require UPDATE .. RETURNING to be
|
||||||
|
// atomic in any way, it's just that we need to guarantee each things SELECTed is also UPDATEd (so in the scenario above,
|
||||||
|
// we don't care if the SELECT includes or excludes the 'new' row, but if it is SELECTed it MUST be UPDATEd).
|
||||||
|
rows, err := txn.Query(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4 RETURNING target_user_id, target_state`, BucketSent, userID, deviceID, BucketNew)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close()
|
||||||
// mark these 'new' updates as 'sent'
|
result = make(internal.MapStringInt)
|
||||||
_, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew)
|
var targetUserID string
|
||||||
return result, err
|
var targetState int
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&targetUserID, &targetState); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[targetUserID] = targetState
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) {
|
func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) {
|
||||||
|
@ -232,16 +232,10 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
h.e2eeWorkerPool.Queue(func() {
|
h.e2eeWorkerPool.Queue(func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// some of these fields may be set
|
err := h.Store.DeviceDataTable.Upsert(userID, deviceID, internal.DeviceKeyData{
|
||||||
partialDD := internal.DeviceData{
|
OTKCounts: otkCounts,
|
||||||
UserID: userID,
|
FallbackKeyTypes: fallbackKeyTypes,
|
||||||
DeviceID: deviceID,
|
}, deviceListChanges)
|
||||||
DeviceKeyData: internal.DeviceKeyData{
|
|
||||||
OTKCounts: otkCounts,
|
|
||||||
FallbackKeyTypes: fallbackKeyTypes,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
||||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user