From b428ede1ca36e1ea61f188499e693f646346072a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 2 May 2023 14:56:59 +0100 Subject: [PATCH] Update txns table --- state/txn_table.go | 19 +++++++++++-------- state/txn_table_test.go | 17 +++++++++-------- sync2/handler2/handler.go | 6 +++--- sync2/poller.go | 12 ++++++------ sync2/poller_test.go | 2 +- sync3/caches/user.go | 6 +++--- sync3/caches/user_test.go | 4 ++-- sync3/handler/connstate.go | 2 +- sync3/handler/connstate_live.go | 2 +- sync3/handler/connstate_test.go | 2 +- sync3/handler/handler.go | 4 ++-- 11 files changed, 40 insertions(+), 36 deletions(-) diff --git a/state/txn_table.go b/state/txn_table.go index 77c1241..452b264 100644 --- a/state/txn_table.go +++ b/state/txn_table.go @@ -8,7 +8,8 @@ import ( ) type txnRow struct { - DeviceID string `db:"user_id"` + UserID string `db:"user_id"` + DeviceID string `db:"device_id"` EventID string `db:"event_id"` TxnID string `db:"txn_id"` Timestamp int64 `db:"ts"` @@ -22,30 +23,32 @@ func NewTransactionsTable(db *sqlx.DB) *TransactionsTable { // make sure tables are made db.MustExec(` CREATE TABLE IF NOT EXISTS syncv3_txns ( - user_id TEXT NOT NULL, -- actually device_id + user_id TEXT NOT NULL, -- was actually device_id before migration + device_id TEXT NOT NULL, event_id TEXT NOT NULL, txn_id TEXT NOT NULL, ts BIGINT NOT NULL, - UNIQUE(user_id, event_id) + UNIQUE(user_id, device_id, event_id) ); `) return &TransactionsTable{db} } -func (t *TransactionsTable) Insert(deviceID string, eventIDToTxnID map[string]string) error { +func (t *TransactionsTable) Insert(userID, deviceID string, eventIDToTxnID map[string]string) error { ts := time.Now() rows := make([]txnRow, 0, len(eventIDToTxnID)) for eventID, txnID := range eventIDToTxnID { rows = append(rows, txnRow{ EventID: eventID, TxnID: txnID, + UserID: userID, DeviceID: deviceID, Timestamp: ts.UnixMilli(), }) } result, err := t.db.NamedQuery(` - INSERT INTO syncv3_txns (user_id, event_id, txn_id, ts) - VALUES (:user_id, :event_id, :txn_id, :ts)`, rows) + INSERT INTO syncv3_txns (user_id, device_id, event_id, txn_id, ts) + VALUES (:user_id, :device_id, :event_id, :txn_id, :ts)`, rows) if err == nil { result.Close() } @@ -57,10 +60,10 @@ func (t *TransactionsTable) Clean(boundaryTime time.Time) error { return err } -func (t *TransactionsTable) Select(deviceID string, eventIDs []string) (map[string]string, error) { +func (t *TransactionsTable) Select(userID, deviceID string, eventIDs []string) (map[string]string, error) { result := make(map[string]string, len(eventIDs)) var rows []txnRow - err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 and event_id=ANY($2)`, deviceID, pq.StringArray(eventIDs)) + err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 AND device_id=$2 and event_id=ANY($3)`, userID, deviceID, pq.StringArray(eventIDs)) if err != nil { return nil, err } diff --git a/state/txn_table_test.go b/state/txn_table_test.go index 42e22e0..85c0f25 100644 --- a/state/txn_table_test.go +++ b/state/txn_table_test.go @@ -26,33 +26,34 @@ func TestTransactionTable(t *testing.T) { db, close := connectToDB(t) defer close() userID := "@alice:txns" + deviceID := "alice_phone" eventA := "$A" eventB := "$B" txnIDA := "txn_A" txnIDB := "txn_B" table := NewTransactionsTable(db) // empty table select - gotTxns, err := table.Select(userID, []string{eventA}) + gotTxns, err := table.Select(userID, deviceID, []string{eventA}) assertNoError(t, err) assertTxns(t, gotTxns, nil) // basic insert and select - err = table.Insert(userID, map[string]string{ + err = table.Insert(userID, deviceID, map[string]string{ eventA: txnIDA, }) assertNoError(t, err) - gotTxns, err = table.Select(userID, []string{eventA}) + gotTxns, err = table.Select(userID, deviceID, []string{eventA}) assertNoError(t, err) assertTxns(t, gotTxns, map[string]string{ eventA: txnIDA, }) // multiple txns - err = table.Insert(userID, map[string]string{ + err = table.Insert(userID, deviceID, map[string]string{ eventB: txnIDB, }) assertNoError(t, err) - gotTxns, err = table.Select(userID, []string{eventA, eventB}) + gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB}) assertNoError(t, err) assertTxns(t, gotTxns, map[string]string{ eventA: txnIDA, @@ -60,14 +61,14 @@ func TestTransactionTable(t *testing.T) { }) // different user select - gotTxns, err = table.Select("@another", []string{eventA, eventB}) + gotTxns, err = table.Select("@another", "another_device", []string{eventA, eventB}) assertNoError(t, err) assertTxns(t, gotTxns, nil) // no-op cleanup err = table.Clean(time.Now().Add(-1 * time.Minute)) assertNoError(t, err) - gotTxns, err = table.Select(userID, []string{eventA, eventB}) + gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB}) assertNoError(t, err) assertTxns(t, gotTxns, map[string]string{ eventA: txnIDA, @@ -77,7 +78,7 @@ func TestTransactionTable(t *testing.T) { // real cleanup err = table.Clean(time.Now()) assertNoError(t, err) - gotTxns, err = table.Select(userID, []string{eventA, eventB}) + gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB}) assertNoError(t, err) assertTxns(t, gotTxns, nil) diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 9b42306..1e57461 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -206,7 +206,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int, }) } -func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) { +func (h *Handler) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) { // Remember any transaction IDs that may be unique to this user eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id for _, e := range timeline { @@ -219,9 +219,9 @@ func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json } if len(eventIDToTxnID) > 0 { // persist the txn IDs - err := h.Store.TransactionsTable.Insert(deviceID, eventIDToTxnID) + err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID) if err != nil { - logger.Err(err).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user") + logger.Err(err).Str("user", userID).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user") sentry.CaptureException(err) } } diff --git a/sync2/poller.go b/sync2/poller.go index c2c9b67..a3298f1 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -27,7 +27,7 @@ type V2DataReceiver interface { // Update the since token for this device. Called AFTER all other data in this sync response has been processed. UpdateDeviceSince(userID, deviceID, since string) // Accumulate data for this room. This means the timeline section of the v2 response. - Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries + Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries // Initialise the room, if it hasn't been already. This means the state section of the v2 response. // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. Initialise(roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID? @@ -201,7 +201,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS if needToWait { poller.WaitUntilInitialSync() } else { - logger.Info().Msg("a poller exists for this user; not waiting for this device to do an initial sync") + logger.Info().Str("user", poller.userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync") } } @@ -214,11 +214,11 @@ func (h *PollerMap) execute() { func (h *PollerMap) UpdateDeviceSince(userID, deviceID, since string) { h.callbacks.UpdateDeviceSince(userID, deviceID, since) } -func (h *PollerMap) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) { +func (h *PollerMap) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - h.callbacks.Accumulate(deviceID, roomID, prevBatch, timeline) + h.callbacks.Accumulate(userID, deviceID, roomID, prevBatch, timeline) wg.Done() } wg.Wait() @@ -564,7 +564,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) { if len(roomData.Timeline.Events) > 0 { timelineCalls++ p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited) - p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) + p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) } // process unread counts AFTER events so global caches have been updated by the time this metadata is added. @@ -580,7 +580,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) { // TODO: do we care about state? if len(roomData.Timeline.Events) > 0 { p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited) - p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) + p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events) } p.receiver.OnLeftRoom(p.userID, roomID) } diff --git a/sync2/poller_test.go b/sync2/poller_test.go index 73f89ee..043d077 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -460,7 +460,7 @@ type mockDataReceiver struct { unblockProcess chan struct{} } -func (a *mockDataReceiver) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) { +func (a *mockDataReceiver) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) { a.timelines[roomID] = append(a.timelines[roomID], timeline...) } func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) []json.RawMessage { diff --git a/sync3/caches/user.go b/sync3/caches/user.go index 552900a..0b6ecc3 100644 --- a/sync3/caches/user.go +++ b/sync3/caches/user.go @@ -24,7 +24,7 @@ type CacheFinder interface { } type TransactionIDFetcher interface { - TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) + TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) } type UserRoomData struct { @@ -448,7 +448,7 @@ func (c *UserCache) Invites() map[string]UserRoomData { // events are globally scoped, so if Alice sends a message, Bob might receive it first on his v2 loop // which would cause the transaction ID to be missing from the event. Instead, we always look for txn // IDs in the v2 poller, and then set them appropriately at request time. -func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage { +func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID string, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage { var eventIDs []string eventIDToEvent := make(map[string]struct { roomID string @@ -467,7 +467,7 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID str } } } - eventIDToTxnID := c.txnIDs.TransactionIDForEvents(deviceID, eventIDs) + eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs) for eventID, txnID := range eventIDToTxnID { data, ok := eventIDToEvent[eventID] if !ok { diff --git a/sync3/caches/user_test.go b/sync3/caches/user_test.go index 0c9f459..a1e5ec8 100644 --- a/sync3/caches/user_test.go +++ b/sync3/caches/user_test.go @@ -14,7 +14,7 @@ type txnIDFetcher struct { data map[string]string } -func (t *txnIDFetcher) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) { +func (t *txnIDFetcher) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) { eventIDToTxnID = make(map[string]string) for _, eventID := range eventIDs { txnID, ok := t.data[eventID] @@ -83,7 +83,7 @@ func TestAnnotateWithTransactionIDs(t *testing.T) { data: tc.eventIDToTxnIDs, } uc := caches.NewUserCache(userID, nil, nil, fetcher) - got := uc.AnnotateWithTransactionIDs(context.Background(), "DEVICE", convertIDToEventStub(tc.roomIDToEvents)) + got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents)) want := convertIDTxnToEventStub(tc.wantRoomIDToEvents) if !reflect.DeepEqual(got, want) { t.Errorf("%s : got %v want %v", tc.name, js(got), js(want)) diff --git a/sync3/handler/connstate.go b/sync3/handler/connstate.go index d2921c3..3d514ed 100644 --- a/sync3/handler/connstate.go +++ b/sync3/handler/connstate.go @@ -458,7 +458,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu roomToUsersInTimeline[roomID] = userIDs roomToTimeline[roomID] = urd.Timeline } - roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, roomToTimeline) + roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, roomToTimeline) rsm := roomSub.RequiredStateMap(s.userID) roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.loadPosition, rsm, roomToUsersInTimeline) if roomIDToState == nil { // e.g no required_state diff --git a/sync3/handler/connstate_live.go b/sync3/handler/connstate_live.go index 76e0d5a..f4dcde2 100644 --- a/sync3/handler/connstate_live.go +++ b/sync3/handler/connstate_live.go @@ -210,7 +210,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, // - the initial:true room from BuildSubscriptions contains the latest live events in the timeline as it's pulled from the DB // - we then process the live events in turn which adds them again. if !advancedPastEvent { - roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, map[string][]json.RawMessage{ + roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, map[string][]json.RawMessage{ roomEventUpdate.RoomID(): {roomEventUpdate.EventData.Event}, }) r.Timeline = append(r.Timeline, roomIDtoTimeline[roomEventUpdate.RoomID()]...) diff --git a/sync3/handler/connstate_test.go b/sync3/handler/connstate_test.go index 530d3cd..07caa4c 100644 --- a/sync3/handler/connstate_test.go +++ b/sync3/handler/connstate_test.go @@ -33,7 +33,7 @@ func (t *NopJoinTracker) IsUserJoined(userID, roomID string) bool { type NopTransactionFetcher struct{} -func (t *NopTransactionFetcher) TransactionIDForEvents(userID string, eventID []string) (eventIDToTxnID map[string]string) { +func (t *NopTransactionFetcher) TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) { return } diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index f1e8097..7c1091f 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -533,8 +533,8 @@ func (h *SyncLiveHandler) DeviceData(ctx context.Context, userID, deviceID strin } // Implements TransactionIDFetcher -func (h *SyncLiveHandler) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) { - eventIDToTxnID, err := h.Storage.TransactionsTable.Select(deviceID, eventIDs) +func (h *SyncLiveHandler) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) { + eventIDToTxnID, err := h.Storage.TransactionsTable.Select(userID, deviceID, eventIDs) if err != nil { logger.Warn().Str("err", err.Error()).Str("device", deviceID).Msg("failed to select txn IDs for events") }