Some review comments; swap to UPDATE..RETURNING

This commit is contained in:
Kegan Dougal 2024-05-20 08:22:48 +01:00
parent 35c9fd4d95
commit fdbebaea68
4 changed files with 37 additions and 28 deletions

View File

@ -107,15 +107,15 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
}
// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) {
func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// Update device lists
if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil {
if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil {
return err
}
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
@ -126,12 +126,12 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[
return err
}
}
if dd.FallbackKeyTypes != nil {
keyData.FallbackKeyTypes = dd.FallbackKeyTypes
if keys.FallbackKeyTypes != nil {
keyData.FallbackKeyTypes = keys.FallbackKeyTypes
keyData.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
keyData.OTKCounts = dd.OTKCounts
if keys.OTKCounts != nil {
keyData.OTKCounts = keys.OTKCounts
keyData.SetOTKCountChanged()
}
@ -142,7 +142,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[
_, err = txn.Exec(
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
dd.UserID, dd.DeviceID, data,
userID, deviceID, data,
)
return err
})

View File

@ -68,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
// apply them
for _, dd := range deltas {
err := table.Upsert(&dd, nil)
err := table.Upsert(dd.UserID, dd.DeviceID, dd.DeviceKeyData, nil)
assertNoError(t, err)
}
@ -161,7 +161,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
},
}
err := table.Upsert(&otkUpdate, nil)
err := table.Upsert(otkUpdate.UserID, otkUpdate.DeviceID, otkUpdate.DeviceKeyData, nil)
assertNoError(t, err)
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
@ -173,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
otkUpdate.ChangedBits = 0
assertDeviceData(t, *got, otkUpdate)
// now same for fallback keys, but we won't swap them so it should return those diffs
err = table.Upsert(&fallbakKeyUpdate, nil)
err = table.Upsert(fallbakKeyUpdate.UserID, fallbakKeyUpdate.DeviceID, fallbakKeyUpdate.DeviceKeyData, nil)
assertNoError(t, err)
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
got, err = table.Select(userID, deviceID, false)
@ -185,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
fallbakKeyUpdate.SetFallbackKeysChanged()
assertDeviceData(t, *got, fallbakKeyUpdate)
// updating both works
err = table.Upsert(&bothUpdate, nil)
err = table.Upsert(bothUpdate.UserID, bothUpdate.DeviceID, bothUpdate.DeviceKeyData, nil)
assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)

View File

@ -109,15 +109,30 @@ func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap b
if err != nil {
return nil, err
}
// grab any 'new' updates
result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew)
// grab any 'new' updates and atomically mark these as 'sent'.
// NB: we must not SELECT then UPDATE, because a 'new' row could be inserted after the SELECT and before the UPDATE, which
// would then be incorrectly moved to 'sent' without being returned to the client, dropping the data. This happens because
// the default transaction level is 'read committed', which /allows/ nonrepeatable reads which is:
// > A transaction re-reads data it has previously read and finds that data has been modified by another transaction (that committed since the initial read).
// We could change the isolation level but this incurs extra performance costs in addition to serialisation errors which
// need to be handled. It's easier to just use UPDATE .. RETURNING. Note that we don't require UPDATE .. RETURNING to be
// atomic in any way, it's just that we need to guarantee each things SELECTed is also UPDATEd (so in the scenario above,
// we don't care if the SELECT includes or excludes the 'new' row, but if it is SELECTed it MUST be UPDATEd).
rows, err := txn.Query(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4 RETURNING target_user_id, target_state`, BucketSent, userID, deviceID, BucketNew)
if err != nil {
return nil, err
}
// mark these 'new' updates as 'sent'
_, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew)
return result, err
defer rows.Close()
result = make(internal.MapStringInt)
var targetUserID string
var targetState int
for rows.Next() {
if err := rows.Scan(&targetUserID, &targetState); err != nil {
return nil, err
}
result[targetUserID] = targetState
}
return result, rows.Err()
}
func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) {

View File

@ -232,16 +232,10 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo
wg.Add(1)
h.e2eeWorkerPool.Queue(func() {
defer wg.Done()
// some of these fields may be set
partialDD := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceKeyData: internal.DeviceKeyData{
err := h.Store.DeviceDataTable.Upsert(userID, deviceID, internal.DeviceKeyData{
OTKCounts: otkCounts,
FallbackKeyTypes: fallbackKeyTypes,
},
}
err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges)
}, deviceListChanges)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)