mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Update txns table
This commit is contained in:
parent
e1bc972ff7
commit
b428ede1ca
@ -8,7 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type txnRow struct {
|
||||
DeviceID string `db:"user_id"`
|
||||
UserID string `db:"user_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
EventID string `db:"event_id"`
|
||||
TxnID string `db:"txn_id"`
|
||||
Timestamp int64 `db:"ts"`
|
||||
@ -22,30 +23,32 @@ func NewTransactionsTable(db *sqlx.DB) *TransactionsTable {
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE TABLE IF NOT EXISTS syncv3_txns (
|
||||
user_id TEXT NOT NULL, -- actually device_id
|
||||
user_id TEXT NOT NULL, -- was actually device_id before migration
|
||||
device_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
txn_id TEXT NOT NULL,
|
||||
ts BIGINT NOT NULL,
|
||||
UNIQUE(user_id, event_id)
|
||||
UNIQUE(user_id, device_id, event_id)
|
||||
);
|
||||
`)
|
||||
return &TransactionsTable{db}
|
||||
}
|
||||
|
||||
func (t *TransactionsTable) Insert(deviceID string, eventIDToTxnID map[string]string) error {
|
||||
func (t *TransactionsTable) Insert(userID, deviceID string, eventIDToTxnID map[string]string) error {
|
||||
ts := time.Now()
|
||||
rows := make([]txnRow, 0, len(eventIDToTxnID))
|
||||
for eventID, txnID := range eventIDToTxnID {
|
||||
rows = append(rows, txnRow{
|
||||
EventID: eventID,
|
||||
TxnID: txnID,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
Timestamp: ts.UnixMilli(),
|
||||
})
|
||||
}
|
||||
result, err := t.db.NamedQuery(`
|
||||
INSERT INTO syncv3_txns (user_id, event_id, txn_id, ts)
|
||||
VALUES (:user_id, :event_id, :txn_id, :ts)`, rows)
|
||||
INSERT INTO syncv3_txns (user_id, device_id, event_id, txn_id, ts)
|
||||
VALUES (:user_id, :device_id, :event_id, :txn_id, :ts)`, rows)
|
||||
if err == nil {
|
||||
result.Close()
|
||||
}
|
||||
@ -57,10 +60,10 @@ func (t *TransactionsTable) Clean(boundaryTime time.Time) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *TransactionsTable) Select(deviceID string, eventIDs []string) (map[string]string, error) {
|
||||
func (t *TransactionsTable) Select(userID, deviceID string, eventIDs []string) (map[string]string, error) {
|
||||
result := make(map[string]string, len(eventIDs))
|
||||
var rows []txnRow
|
||||
err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 and event_id=ANY($2)`, deviceID, pq.StringArray(eventIDs))
|
||||
err := t.db.Select(&rows, `SELECT event_id, txn_id FROM syncv3_txns WHERE user_id=$1 AND device_id=$2 and event_id=ANY($3)`, userID, deviceID, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -26,33 +26,34 @@ func TestTransactionTable(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
userID := "@alice:txns"
|
||||
deviceID := "alice_phone"
|
||||
eventA := "$A"
|
||||
eventB := "$B"
|
||||
txnIDA := "txn_A"
|
||||
txnIDB := "txn_B"
|
||||
table := NewTransactionsTable(db)
|
||||
// empty table select
|
||||
gotTxns, err := table.Select(userID, []string{eventA})
|
||||
gotTxns, err := table.Select(userID, deviceID, []string{eventA})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, nil)
|
||||
|
||||
// basic insert and select
|
||||
err = table.Insert(userID, map[string]string{
|
||||
err = table.Insert(userID, deviceID, map[string]string{
|
||||
eventA: txnIDA,
|
||||
})
|
||||
assertNoError(t, err)
|
||||
gotTxns, err = table.Select(userID, []string{eventA})
|
||||
gotTxns, err = table.Select(userID, deviceID, []string{eventA})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, map[string]string{
|
||||
eventA: txnIDA,
|
||||
})
|
||||
|
||||
// multiple txns
|
||||
err = table.Insert(userID, map[string]string{
|
||||
err = table.Insert(userID, deviceID, map[string]string{
|
||||
eventB: txnIDB,
|
||||
})
|
||||
assertNoError(t, err)
|
||||
gotTxns, err = table.Select(userID, []string{eventA, eventB})
|
||||
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, map[string]string{
|
||||
eventA: txnIDA,
|
||||
@ -60,14 +61,14 @@ func TestTransactionTable(t *testing.T) {
|
||||
})
|
||||
|
||||
// different user select
|
||||
gotTxns, err = table.Select("@another", []string{eventA, eventB})
|
||||
gotTxns, err = table.Select("@another", "another_device", []string{eventA, eventB})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, nil)
|
||||
|
||||
// no-op cleanup
|
||||
err = table.Clean(time.Now().Add(-1 * time.Minute))
|
||||
assertNoError(t, err)
|
||||
gotTxns, err = table.Select(userID, []string{eventA, eventB})
|
||||
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, map[string]string{
|
||||
eventA: txnIDA,
|
||||
@ -77,7 +78,7 @@ func TestTransactionTable(t *testing.T) {
|
||||
// real cleanup
|
||||
err = table.Clean(time.Now())
|
||||
assertNoError(t, err)
|
||||
gotTxns, err = table.Select(userID, []string{eventA, eventB})
|
||||
gotTxns, err = table.Select(userID, deviceID, []string{eventA, eventB})
|
||||
assertNoError(t, err)
|
||||
assertTxns(t, gotTxns, nil)
|
||||
|
||||
|
@ -206,7 +206,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
func (h *Handler) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
// Remember any transaction IDs that may be unique to this user
|
||||
eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
|
||||
for _, e := range timeline {
|
||||
@ -219,9 +219,9 @@ func (h *Handler) Accumulate(deviceID, roomID, prevBatch string, timeline []json
|
||||
}
|
||||
if len(eventIDToTxnID) > 0 {
|
||||
// persist the txn IDs
|
||||
err := h.Store.TransactionsTable.Insert(deviceID, eventIDToTxnID)
|
||||
err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
|
||||
logger.Err(err).Str("user", userID).Str("device", deviceID).Int("num_txns", len(eventIDToTxnID)).Msg("failed to persist txn IDs for user")
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ type V2DataReceiver interface {
|
||||
// Update the since token for this device. Called AFTER all other data in this sync response has been processed.
|
||||
UpdateDeviceSince(userID, deviceID, since string)
|
||||
// Accumulate data for this room. This means the timeline section of the v2 response.
|
||||
Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
|
||||
Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) // latest pos with event nids of timeline entries
|
||||
// Initialise the room, if it hasn't been already. This means the state section of the v2 response.
|
||||
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
|
||||
Initialise(roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
|
||||
@ -201,7 +201,7 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
|
||||
if needToWait {
|
||||
poller.WaitUntilInitialSync()
|
||||
} else {
|
||||
logger.Info().Msg("a poller exists for this user; not waiting for this device to do an initial sync")
|
||||
logger.Info().Str("user", poller.userID).Msg("a poller exists for this user; not waiting for this device to do an initial sync")
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,11 +214,11 @@ func (h *PollerMap) execute() {
|
||||
func (h *PollerMap) UpdateDeviceSince(userID, deviceID, since string) {
|
||||
h.callbacks.UpdateDeviceSince(userID, deviceID, since)
|
||||
}
|
||||
func (h *PollerMap) Accumulate(deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
func (h *PollerMap) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
h.executor <- func() {
|
||||
h.callbacks.Accumulate(deviceID, roomID, prevBatch, timeline)
|
||||
h.callbacks.Accumulate(userID, deviceID, roomID, prevBatch, timeline)
|
||||
wg.Done()
|
||||
}
|
||||
wg.Wait()
|
||||
@ -564,7 +564,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
|
||||
if len(roomData.Timeline.Events) > 0 {
|
||||
timelineCalls++
|
||||
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
|
||||
p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
|
||||
p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
|
||||
}
|
||||
|
||||
// process unread counts AFTER events so global caches have been updated by the time this metadata is added.
|
||||
@ -580,7 +580,7 @@ func (p *poller) parseRoomsResponse(res *SyncResponse) {
|
||||
// TODO: do we care about state?
|
||||
if len(roomData.Timeline.Events) > 0 {
|
||||
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
|
||||
p.receiver.Accumulate(p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
|
||||
p.receiver.Accumulate(p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
|
||||
}
|
||||
p.receiver.OnLeftRoom(p.userID, roomID)
|
||||
}
|
||||
|
@ -460,7 +460,7 @@ type mockDataReceiver struct {
|
||||
unblockProcess chan struct{}
|
||||
}
|
||||
|
||||
func (a *mockDataReceiver) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
func (a *mockDataReceiver) Accumulate(userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
a.timelines[roomID] = append(a.timelines[roomID], timeline...)
|
||||
}
|
||||
func (a *mockDataReceiver) Initialise(roomID string, state []json.RawMessage) []json.RawMessage {
|
||||
|
@ -24,7 +24,7 @@ type CacheFinder interface {
|
||||
}
|
||||
|
||||
type TransactionIDFetcher interface {
|
||||
TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string)
|
||||
TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string)
|
||||
}
|
||||
|
||||
type UserRoomData struct {
|
||||
@ -448,7 +448,7 @@ func (c *UserCache) Invites() map[string]UserRoomData {
|
||||
// events are globally scoped, so if Alice sends a message, Bob might receive it first on his v2 loop
|
||||
// which would cause the transaction ID to be missing from the event. Instead, we always look for txn
|
||||
// IDs in the v2 poller, and then set them appropriately at request time.
|
||||
func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage {
|
||||
func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID string, deviceID string, roomIDToEvents map[string][]json.RawMessage) map[string][]json.RawMessage {
|
||||
var eventIDs []string
|
||||
eventIDToEvent := make(map[string]struct {
|
||||
roomID string
|
||||
@ -467,7 +467,7 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, deviceID str
|
||||
}
|
||||
}
|
||||
}
|
||||
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(deviceID, eventIDs)
|
||||
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs)
|
||||
for eventID, txnID := range eventIDToTxnID {
|
||||
data, ok := eventIDToEvent[eventID]
|
||||
if !ok {
|
||||
|
@ -14,7 +14,7 @@ type txnIDFetcher struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (t *txnIDFetcher) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
|
||||
func (t *txnIDFetcher) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
|
||||
eventIDToTxnID = make(map[string]string)
|
||||
for _, eventID := range eventIDs {
|
||||
txnID, ok := t.data[eventID]
|
||||
@ -83,7 +83,7 @@ func TestAnnotateWithTransactionIDs(t *testing.T) {
|
||||
data: tc.eventIDToTxnIDs,
|
||||
}
|
||||
uc := caches.NewUserCache(userID, nil, nil, fetcher)
|
||||
got := uc.AnnotateWithTransactionIDs(context.Background(), "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
|
||||
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
|
||||
want := convertIDTxnToEventStub(tc.wantRoomIDToEvents)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("%s : got %v want %v", tc.name, js(got), js(want))
|
||||
|
@ -458,7 +458,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
roomToUsersInTimeline[roomID] = userIDs
|
||||
roomToTimeline[roomID] = urd.Timeline
|
||||
}
|
||||
roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, roomToTimeline)
|
||||
roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, roomToTimeline)
|
||||
rsm := roomSub.RequiredStateMap(s.userID)
|
||||
roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.loadPosition, rsm, roomToUsersInTimeline)
|
||||
if roomIDToState == nil { // e.g no required_state
|
||||
|
@ -210,7 +210,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
|
||||
// - the initial:true room from BuildSubscriptions contains the latest live events in the timeline as it's pulled from the DB
|
||||
// - we then process the live events in turn which adds them again.
|
||||
if !advancedPastEvent {
|
||||
roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.deviceID, map[string][]json.RawMessage{
|
||||
roomIDtoTimeline := s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, map[string][]json.RawMessage{
|
||||
roomEventUpdate.RoomID(): {roomEventUpdate.EventData.Event},
|
||||
})
|
||||
r.Timeline = append(r.Timeline, roomIDtoTimeline[roomEventUpdate.RoomID()]...)
|
||||
|
@ -33,7 +33,7 @@ func (t *NopJoinTracker) IsUserJoined(userID, roomID string) bool {
|
||||
|
||||
type NopTransactionFetcher struct{}
|
||||
|
||||
func (t *NopTransactionFetcher) TransactionIDForEvents(userID string, eventID []string) (eventIDToTxnID map[string]string) {
|
||||
func (t *NopTransactionFetcher) TransactionIDForEvents(userID, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -533,8 +533,8 @@ func (h *SyncLiveHandler) DeviceData(ctx context.Context, userID, deviceID strin
|
||||
}
|
||||
|
||||
// Implements TransactionIDFetcher
|
||||
func (h *SyncLiveHandler) TransactionIDForEvents(deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
|
||||
eventIDToTxnID, err := h.Storage.TransactionsTable.Select(deviceID, eventIDs)
|
||||
func (h *SyncLiveHandler) TransactionIDForEvents(userID string, deviceID string, eventIDs []string) (eventIDToTxnID map[string]string) {
|
||||
eventIDToTxnID, err := h.Storage.TransactionsTable.Select(userID, deviceID, eventIDs)
|
||||
if err != nil {
|
||||
logger.Warn().Str("err", err.Error()).Str("device", deviceID).Msg("failed to select txn IDs for events")
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user