Update txns table

This commit is contained in:
David Robertson 2023-05-02 14:56:59 +01:00
parent e1bc972ff7
commit b428ede1ca
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
11 changed files with 40 additions and 36 deletions

View File

@ -8,7 +8,8 @@ import (
)
type txnRow struct {
DeviceID string `db:"user_id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
EventID string `db:"event_id"`
TxnID string `db:"txn_id"`
Timestamp int64 `db:"ts"`
@ -22,30 +23,32 @@ func NewTransactionsTable(db *sqlx.DB) *TransactionsTable {
// make sure tables are made
db.MustExec(`
CREATE TABLE IF NOT EXISTS syncv3_txns (
user_id TEXT NOT NULL, -- actually device_id
user_id TEXT NOT NULL, -- was actually device_id before migration
device_id TEXT NOT NULL,
event_id TEXT NOT NULL,
txn_id TEXT NOT NULL,
ts BIGINT NOT NULL,
UNIQUE(user_id, event_id)
UNIQUE(user_id, device_id, event_id)
);
`)
return &TransactionsTable{db}
}
func (t *TransactionsTable) Insert(deviceID string, eventIDToTxnID map[string]string) error {
func (t *TransactionsTable) Insert(userID, deviceID string, eventIDToTxnID map[string]string) error {
ts := time.Now()
rows := make([]txnRow, 0, len(eventIDToTxnID))
for eventID, txnID := range eventIDToTxnID {
rows = append(rows, txnRow{
EventID: eventID,
TxnID: txnID,
UserID: userID,
DeviceID: deviceID,
Timestamp: ts.UnixMilli(),
})
}
result, err := t.db.NamedQuery(`
INSERT INTO syncv3_txns (user_id, event_id, txn_id, ts)
VALUES (:user_id, :event_id, :txn_id, :ts)`, rows)
INSERT INTO syncv3_txns (user_id, device_id, event_id, txn_id, ts)
VALUES (:user_id, :device_id, :event_id, :txn_id, :ts)`, rows)
if err == nil {
result.Close()
}
@ -57,10 +60,10 @@ func (t *TransactionsTable) Clean(boundaryTime time.Time) error {
return err
}
func (t *TransactionsTable) Select(deviceID string, eventIDs []string) (map[string]string, error) {
func (t *TransactionsTable) Select(userID, deviceID string, eventIDs []string) (map[string]string, error) {
result := make(map[string]string, len(eventIDs))
var rows []txnRow
err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 and event_id=ANY($2)`, deviceID, pq.StringArray(eventIDs))
err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 AND device_id=$2 and event_id=ANY($3)`, userID, deviceID, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}

View File

@ -26,33 +26,34 @@ func TestTransactionTable(t *testing.T) {
db, close := connectToDB(t)
defer close()
userID := "@alice:txns"
deviceID := "alice_phone"
eventA := "$A"
eventB := "$B"
txnIDA := "txn_A"
txnIDB := "txn_B"
table := NewTransactionsTable(db)
// empty table select
gotTxns, err := table.Select(userID, []string{eventA})
gotTxns, err := table.Select(userID, deviceID, []string{eventA})
assertNoError(t, err)
assertTxns(t, gotTxns, nil)
// basic insert and select
err = table.Insert(userID, map[string]string{
err = table.Insert(userID, deviceID, map[string]string{
eventA: txnIDA,
})
assertNoError(t, err)
gotTxns, err = table.Select(userID, []string{eventA})
gotTxns, err = table.Select(userID, deviceID, []string{eventA})
assertNoError(t, err)
assertTxns(t, gotTxns, map[string]string{
eventA: txnIDA,
})
// multiple txns
err = table.Insert(userID, map[string]string{
err = table.Insert(userID, deviceID, map[string]string{
eventB: txnIDB,
})
assertNoError(t, err)
gotTxns, err = table.Select(userID, []string{eventA, eventB})
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
assertNoError(t, err)
assertTxns(t, gotTxns, map[string]string{
eventA: txnIDA,
@ -60,14 +61,14 @@ func TestTransactionTable(t *testing.T) {
})
// different user select
gotTxns, err = table.Select("@another", []string{eventA, eventB})
gotTxns, err = table.Select("@another", "another_device", []string{eventA, eventB})
assertNoError(t, err)
assertTxns(t, gotTxns, nil)
// no-op cleanup
err = table.Clean(time.Now().Add(-1 * time.Minute))
assertNoError(t, err)
gotTxns, err = table.Select(userID, []string{eventA, eventB})
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
assertNoError(t, err)
assertTxns(t, gotTxns, map[string]string{
eventA: txnIDA,
@ -77,7 +78,7 @@ func TestTransactionTable(t *testing.T) {
// real cleanup
err = table.Clean(time.Now())
assertNoError(t, err)
gotTxns, err = table.Select(userID, []string{eventA, eventB})
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
assertNoError(t, err)
assertTxns(t, gotTxns, nil)

View File

@ -206,7 +206,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
})
}
func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
func (h *Handler) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
// Remember any transaction IDs that may be unique to this user
eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
for _, e := range timeline {
@ -219,9 +219,9 @@ func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json
}
if len(eventIDToTxnID) > 0 {
// persist the txn IDs
err := h.Store.TransactionsTable.Insert(deviceID, eventIDToTxnID)
err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
if err != nil {
logger.Err(err).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
sentry.CaptureException(err)
}
}

View File

@ -27,7 +27,7 @@ type V2DataReceiver interface {
// Update the since token for this device. Called AFTER all other data in this sync response has been processed.
UpdateDeviceSince(userID, deviceID, since string)
// Accumulate data for this room. This means the timeline section of the v2 response.
Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
// Initialise the room, if it hasn't been already. This means the state section of the v2 response.
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
Initialise(roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
@ -201,7 +201,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
if needToWait {
poller.WaitUntilInitialSync()
} else {
logger.Info().Msg("a poller exists for this user; not waiting for this device to do an initial sync")
logger.Info().Str("user", poller.userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
}
}
@ -214,11 +214,11 @@ func (h *PollerMap) execute() {
func (h *PollerMap) UpdateDeviceSince(userID, deviceID, since string) {
h.callbacks.UpdateDeviceSince(userID, deviceID, since)
}
func (h *PollerMap) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
func (h *PollerMap) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.Accumulate(deviceID, roomID, prevBatch, timeline)
h.callbacks.Accumulate(userID, deviceID, roomID, prevBatch, timeline)
wg.Done()
}
wg.Wait()
@ -564,7 +564,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
if len(roomData.Timeline.Events) > 0 {
timelineCalls++
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
// process unread counts AFTER events so global caches have been updated by the time this metadata is added.
@ -580,7 +580,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
p.receiver.OnLeftRoom(p.userID, roomID)
}

View File

@ -460,7 +460,7 @@ type mockDataReceiver struct {
unblockProcess chan struct{}
}
func (a *mockDataReceiver) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) {
func (a *mockDataReceiver) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
a.timelines[roomID] = append(a.timelines[roomID], timeline...)
}
func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) []json.RawMessage {

View File

@ -24,7 +24,7 @@ type CacheFinder interface {
}
type TransactionIDFetcher interface {
TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string)
TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string)
}
type UserRoomData struct {
@ -448,7 +448,7 @@ func (c *UserCache) Invites() map[string]UserRoomData {
// events are globally scoped, so if Alice sends a message, Bob might receive it first on his v2 loop
// which would cause the transaction ID to be missing from the event. Instead, we always look for txn
// IDs in the v2 poller, and then set them appropriately at request time.
func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage {
func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID string, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage {
var eventIDs []string
eventIDToEvent := make(map[string]struct {
roomID string
@ -467,7 +467,7 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID str
}
}
}
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(deviceID, eventIDs)
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs)
for eventID, txnID := range eventIDToTxnID {
data, ok := eventIDToEvent[eventID]
if !ok {

View File

@ -14,7 +14,7 @@ type txnIDFetcher struct {
data map[string]string
}
func (t *txnIDFetcher) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
func (t *txnIDFetcher) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
eventIDToTxnID = make(map[string]string)
for _, eventID := range eventIDs {
txnID, ok := t.data[eventID]
@ -83,7 +83,7 @@ func TestAnnotateWithTransactionIDs(t *testing.T) {
data: tc.eventIDToTxnIDs,
}
uc := caches.NewUserCache(userID, nil, nil, fetcher)
got := uc.AnnotateWithTransactionIDs(context.Background(), "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
want := convertIDTxnToEventStub(tc.wantRoomIDToEvents)
if !reflect.DeepEqual(got, want) {
t.Errorf("%s : got %v want %v", tc.name, js(got), js(want))

View File

@ -458,7 +458,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
roomToUsersInTimeline[roomID] = userIDs
roomToTimeline[roomID] = urd.Timeline
}
roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, roomToTimeline)
roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, roomToTimeline)
rsm := roomSub.RequiredStateMap(s.userID)
roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.loadPosition, rsm, roomToUsersInTimeline)
if roomIDToState == nil { // e.g no required_state

View File

@ -210,7 +210,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
// - the initial:true room from BuildSubscriptions contains the latest live events in the timeline as it's pulled from the DB
// - we then process the live events in turn which adds them again.
if !advancedPastEvent {
roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, map[string][]json.RawMessage{
roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, map[string][]json.RawMessage{
roomEventUpdate.RoomID(): {roomEventUpdate.EventData.Event},
})
r.Timeline = append(r.Timeline, roomIDtoTimeline[roomEventUpdate.RoomID()]...)

View File

@ -33,7 +33,7 @@ func (t *NopJoinTracker) IsUserJoined(userID, roomID string) bool {
type NopTransactionFetcher struct{}
func (t *NopTransactionFetcher) TransactionIDForEvents(userID string, eventID []string) (eventIDToTxnID map[string]string) {
func (t *NopTransactionFetcher) TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
return
}

View File

@ -533,8 +533,8 @@ func (h *SyncLiveHandler) DeviceData(ctx context.Context, userID, deviceID strin
}
// Implements TransactionIDFetcher
func (h *SyncLiveHandler) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
eventIDToTxnID, err := h.Storage.TransactionsTable.Select(deviceID, eventIDs)
func (h *SyncLiveHandler) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
eventIDToTxnID, err := h.Storage.TransactionsTable.Select(userID, deviceID, eventIDs)
if err != nil {
logger.Warn().Str("err", err.Error()).Str("device", deviceID).Msg("failed to select txn IDs for events")
}