diff --git a/state/device_data_table.go b/state/device_data_table.go index 86388fe..2c53295 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -107,15 +107,15 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in } // Upsert combines what is in the database for this user|device with the partial entry `dd` -func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) { +func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { // Update device lists - if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil { + if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil { return err } // select what already exists var row DeviceDataRow - err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID) if err != nil && err != sql.ErrNoRows { return err } @@ -126,12 +126,12 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[ return err } } - if dd.FallbackKeyTypes != nil { - keyData.FallbackKeyTypes = dd.FallbackKeyTypes + if keys.FallbackKeyTypes != nil { + keyData.FallbackKeyTypes = keys.FallbackKeyTypes keyData.SetFallbackKeysChanged() } - if dd.OTKCounts != nil { - keyData.OTKCounts = dd.OTKCounts + if keys.OTKCounts != nil { + keyData.OTKCounts = keys.OTKCounts keyData.SetOTKCountChanged() } @@ -142,7 +142,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[ _, err = txn.Exec( `INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`, - dd.UserID, dd.DeviceID, data, + userID, deviceID, data, ) return err }) diff --git a/state/device_data_table_test.go b/state/device_data_table_test.go index eac7309..39e2d05 100644 --- a/state/device_data_table_test.go +++ b/state/device_data_table_test.go @@ -68,7 +68,7 @@ func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) { // apply them for _, dd := range deltas { - err := table.Upsert(&dd, nil) + err := table.Upsert(dd.UserID, dd.DeviceID, dd.DeviceKeyData, nil) assertNoError(t, err) } @@ -161,7 +161,7 @@ func TestDeviceDataTableBitset(t *testing.T) { }, } - err := table.Upsert(&otkUpdate, nil) + err := table.Upsert(otkUpdate.UserID, otkUpdate.DeviceID, otkUpdate.DeviceKeyData, nil) assertNoError(t, err) got, err := table.Select(userID, deviceID, true) assertNoError(t, err) @@ -173,7 +173,7 @@ func TestDeviceDataTableBitset(t *testing.T) { otkUpdate.ChangedBits = 0 assertDeviceData(t, *got, otkUpdate) // now same for fallback keys, but we won't swap them so it should return those diffs - err = table.Upsert(&fallbakKeyUpdate, nil) + err = table.Upsert(fallbakKeyUpdate.UserID, fallbakKeyUpdate.DeviceID, fallbakKeyUpdate.DeviceKeyData, nil) assertNoError(t, err) fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts got, err = table.Select(userID, deviceID, false) @@ -185,7 +185,7 @@ func TestDeviceDataTableBitset(t *testing.T) { fallbakKeyUpdate.SetFallbackKeysChanged() assertDeviceData(t, *got, fallbakKeyUpdate) // updating both works - err = table.Upsert(&bothUpdate, nil) + err = table.Upsert(bothUpdate.UserID, bothUpdate.DeviceID, bothUpdate.DeviceKeyData, nil) assertNoError(t, err) got, err = table.Select(userID, deviceID, true) assertNoError(t, err) diff --git a/state/device_list_table.go b/state/device_list_table.go index 889081e..350a606 100644 --- a/state/device_list_table.go +++ b/state/device_list_table.go @@ -109,15 +109,30 @@ func (t *DeviceListTable) SelectTx(txn *sqlx.Tx, userID, deviceID string, swap b if err != nil { return nil, err } - // grab any 'new' updates - result, err = t.selectDeviceListChangesInBucket(txn, userID, deviceID, BucketNew) + // grab any 'new' updates and atomically mark these as 'sent'. + // NB: we must not SELECT then UPDATE, because a 'new' row could be inserted after the SELECT and before the UPDATE, which + // would then be incorrectly moved to 'sent' without being returned to the client, dropping the data. This happens because + // the default transaction level is 'read committed', which /allows/ nonrepeatable reads which is: + // > A transaction re-reads data it has previously read and finds that data has been modified by another transaction (that committed since the initial read). + // We could change the isolation level but this incurs extra performance costs in addition to serialisation errors which + // need to be handled. It's easier to just use UPDATE .. RETURNING. Note that we don't require UPDATE .. RETURNING to be + // atomic in any way, it's just that we need to guarantee each things SELECTed is also UPDATEd (so in the scenario above, + // we don't care if the SELECT includes or excludes the 'new' row, but if it is SELECTed it MUST be UPDATEd). + rows, err := txn.Query(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4 RETURNING target_user_id, target_state`, BucketSent, userID, deviceID, BucketNew) if err != nil { return nil, err } - - // mark these 'new' updates as 'sent' - _, err = txn.Exec(`UPDATE syncv3_device_list_updates SET bucket=$1 WHERE user_id=$2 AND device_id=$3 AND bucket=$4`, BucketSent, userID, deviceID, BucketNew) - return result, err + defer rows.Close() + result = make(internal.MapStringInt) + var targetUserID string + var targetState int + for rows.Next() { + if err := rows.Scan(&targetUserID, &targetState); err != nil { + return nil, err + } + result[targetUserID] = targetState + } + return result, rows.Err() } func (t *DeviceListTable) selectDeviceListChangesInBucket(txn *sqlx.Tx, userID, deviceID string, bucket int) (result internal.MapStringInt, err error) { diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index b8f2975..7b6ecf5 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -232,16 +232,10 @@ func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCo wg.Add(1) h.e2eeWorkerPool.Queue(func() { defer wg.Done() - // some of these fields may be set - partialDD := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - DeviceKeyData: internal.DeviceKeyData{ - OTKCounts: otkCounts, - FallbackKeyTypes: fallbackKeyTypes, - }, - } - err := h.Store.DeviceDataTable.Upsert(&partialDD, deviceListChanges) + err := h.Store.DeviceDataTable.Upsert(userID, deviceID, internal.DeviceKeyData{ + OTKCounts: otkCounts, + FallbackKeyTypes: fallbackKeyTypes, + }, deviceListChanges) if err != nil { logger.Err(err).Str("user", userID).Msg("failed to upsert device data") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)