Mirror of https://github.com/matrix-org/sliding-sync.git (synced 2025-03-10 13:37:11 +00:00)
Merge branch 'main' into kegan/poll-retry-loop-bad-create-event
Commit 94a4789287
@@ -1,7 +1,6 @@
package main

import (
    "context"
    "flag"
    "fmt"
    "log"
@@ -42,18 +41,19 @@ const (
    EnvSecret = "SYNCV3_SECRET"

    // Optional fields
    EnvBindAddr = "SYNCV3_BINDADDR"
    EnvTLSCert = "SYNCV3_TLS_CERT"
    EnvTLSKey = "SYNCV3_TLS_KEY"
    EnvPPROF = "SYNCV3_PPROF"
    EnvPrometheus = "SYNCV3_PROM"
    EnvDebug = "SYNCV3_DEBUG"
    EnvOTLP = "SYNCV3_OTLP_URL"
    EnvOTLPUsername = "SYNCV3_OTLP_USERNAME"
    EnvOTLPPassword = "SYNCV3_OTLP_PASSWORD"
    EnvSentryDsn = "SYNCV3_SENTRY_DSN"
    EnvLogLevel = "SYNCV3_LOG_LEVEL"
    EnvMaxConns = "SYNCV3_MAX_DB_CONN"
    EnvBindAddr = "SYNCV3_BINDADDR"
    EnvTLSCert = "SYNCV3_TLS_CERT"
    EnvTLSKey = "SYNCV3_TLS_KEY"
    EnvPPROF = "SYNCV3_PPROF"
    EnvPrometheus = "SYNCV3_PROM"
    EnvDebug = "SYNCV3_DEBUG"
    EnvOTLP = "SYNCV3_OTLP_URL"
    EnvOTLPUsername = "SYNCV3_OTLP_USERNAME"
    EnvOTLPPassword = "SYNCV3_OTLP_PASSWORD"
    EnvSentryDsn = "SYNCV3_SENTRY_DSN"
    EnvLogLevel = "SYNCV3_LOG_LEVEL"
    EnvMaxConns = "SYNCV3_MAX_DB_CONN"
    EnvIdleTimeoutSecs = "SYNCV3_DB_IDLE_TIMEOUT_SECS"
)

var helpMsg = fmt.Sprintf(`
@@ -72,8 +72,9 @@ Environment var
%s Default: unset. The Sentry DSN to report events to e.g https://sliding-sync@sentry.example.com/123 - if unset does not send sentry events.
%s Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
%s Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
%s Default: 3600. The maximum amount of time a database connection may be idle, in seconds. 0 means no limit.
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvOTLP, EnvOTLPUsername, EnvOTLPPassword,
    EnvSentryDsn, EnvLogLevel, EnvMaxConns)
    EnvSentryDsn, EnvLogLevel, EnvMaxConns, EnvIdleTimeoutSecs)

func defaulting(in, dft string) string {
    if in == "" {
@@ -83,7 +84,6 @@ func defaulting(in, dft string) string {
}

func main() {
    ctx := context.Background()
    fmt.Printf("Sync v3 [%s] (%s)\n", version, GitCommit)
    sync2.ProxyVersion = version
    syncv3.Version = fmt.Sprintf("%s (%s)", version, GitCommit)
@@ -94,21 +94,22 @@ func main() {
    }

    args := map[string]string{
        EnvServer: os.Getenv(EnvServer),
        EnvDB: os.Getenv(EnvDB),
        EnvSecret: os.Getenv(EnvSecret),
        EnvBindAddr: defaulting(os.Getenv(EnvBindAddr), "0.0.0.0:8008"),
        EnvTLSCert: os.Getenv(EnvTLSCert),
        EnvTLSKey: os.Getenv(EnvTLSKey),
        EnvPPROF: os.Getenv(EnvPPROF),
        EnvPrometheus: os.Getenv(EnvPrometheus),
        EnvDebug: os.Getenv(EnvDebug),
        EnvOTLP: os.Getenv(EnvOTLP),
        EnvOTLPUsername: os.Getenv(EnvOTLPUsername),
        EnvOTLPPassword: os.Getenv(EnvOTLPPassword),
        EnvSentryDsn: os.Getenv(EnvSentryDsn),
        EnvLogLevel: os.Getenv(EnvLogLevel),
        EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
        EnvServer: os.Getenv(EnvServer),
        EnvDB: os.Getenv(EnvDB),
        EnvSecret: os.Getenv(EnvSecret),
        EnvBindAddr: defaulting(os.Getenv(EnvBindAddr), "0.0.0.0:8008"),
        EnvTLSCert: os.Getenv(EnvTLSCert),
        EnvTLSKey: os.Getenv(EnvTLSKey),
        EnvPPROF: os.Getenv(EnvPPROF),
        EnvPrometheus: os.Getenv(EnvPrometheus),
        EnvDebug: os.Getenv(EnvDebug),
        EnvOTLP: os.Getenv(EnvOTLP),
        EnvOTLPUsername: os.Getenv(EnvOTLPUsername),
        EnvOTLPPassword: os.Getenv(EnvOTLPPassword),
        EnvSentryDsn: os.Getenv(EnvSentryDsn),
        EnvLogLevel: os.Getenv(EnvLogLevel),
        EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
        EnvIdleTimeoutSecs: defaulting(os.Getenv(EnvIdleTimeoutSecs), "3600"),
    }
    requiredEnvVars := []string{EnvServer, EnvDB, EnvSecret, EnvBindAddr}
    for _, requiredEnvVar := range requiredEnvVars {
@@ -187,19 +188,18 @@ func main() {
        }
    }

    err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], true)
    if err != nil {
        panic(err)
    }

    maxConnsInt, err := strconv.Atoi(args[EnvMaxConns])
    if err != nil {
        panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns])
    }
    idleTimeSecs, err := strconv.Atoi(args[EnvIdleTimeoutSecs])
    if err != nil {
        panic("invalid value for " + EnvIdleTimeoutSecs + ": " + args[EnvIdleTimeoutSecs])
    }
    h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{
        AddPrometheusMetrics: args[EnvPrometheus] != "",
        DBMaxConns: maxConnsInt,
        DBConnMaxIdleTime: time.Hour,
        DBConnMaxIdleTime: time.Duration(idleTimeSecs) * time.Second,
        MaxTransactionIDDelay: time.Second,
    })
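Note: the new SYNCV3_DB_IDLE_TIMEOUT_SECS value flows into the DBConnMaxIdleTime option above. The sketch below only illustrates how such values are typically applied to a database/sql pool; the proxy's real wiring lives inside syncv3.Setup and may differ, and applyPoolOpts is a hypothetical helper.

package main

import (
    "database/sql"
    "time"
)

// applyPoolOpts shows how a max-connection count and an idle timeout are
// usually applied to a database/sql pool. Zero values mean "no limit",
// matching the defaults documented in helpMsg.
func applyPoolOpts(db *sql.DB, maxConns int, idleTimeout time.Duration) {
    db.SetMaxOpenConns(maxConns)       // 0 => unlimited open connections
    db.SetConnMaxIdleTime(idleTimeout) // 0 => idle connections are never reaped
}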
@@ -1,32 +0,0 @@
package main

import (
    "context"
    "github.com/matrix-org/sliding-sync/sync2"
    "os"
)

const (
    // Required fields
    EnvServer = "SYNCV3_SERVER"
    EnvDB = "SYNCV3_DB"
    EnvSecret = "SYNCV3_SECRET"

    // Migration test only
    EnvMigrationCommit = "SYNCV3_TEST_MIGRATION_COMMIT"
)

func main() {
    ctx := context.Background()
    args := map[string]string{
        EnvServer: os.Getenv(EnvServer),
        EnvDB: os.Getenv(EnvDB),
        EnvSecret: os.Getenv(EnvSecret),
        EnvMigrationCommit: os.Getenv(EnvMigrationCommit),
    }

    err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], args[EnvMigrationCommit] != "")
    if err != nil {
        panic(err)
    }
}
@@ -24,6 +24,7 @@ type V2Listener interface {
    OnReceipt(p *V2Receipt)
    OnDeviceMessages(p *V2DeviceMessages)
    OnExpiredToken(p *V2ExpiredToken)
    OnInvalidateRoom(p *V2InvalidateRoom)
}

type V2Initialise struct {
@@ -129,6 +130,12 @@ type V2ExpiredToken struct {

func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" }

type V2InvalidateRoom struct {
    RoomID string
}

func (*V2InvalidateRoom) Type() string { return "V2InvalidateRoom" }

type V2Sub struct {
    listener Listener
    receiver V2Listener
@@ -173,6 +180,8 @@ func (v *V2Sub) onMessage(p Payload) {
        v.receiver.OnDeviceMessages(pl)
    case *V2ExpiredToken:
        v.receiver.OnExpiredToken(pl)
    case *V2InvalidateRoom:
        v.receiver.OnInvalidateRoom(pl)
    default:
        logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type")
    }
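Note: because OnInvalidateRoom is added to the V2Listener interface itself, every existing receiver has to grow the method (even as a no-op) to keep satisfying the interface. An illustrative stub, showing only the methods visible in the hunks above; the real interface has further methods not shown here:

type noopReceiver struct{}

func (noopReceiver) OnReceipt(p *V2Receipt)               {}
func (noopReceiver) OnDeviceMessages(p *V2DeviceMessages) {}
func (noopReceiver) OnExpiredToken(p *V2ExpiredToken)     {}
func (noopReceiver) OnInvalidateRoom(p *V2InvalidateRoom) {}
// ...remaining V2Listener methods omitted.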
@@ -318,6 +318,19 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
    return res, err
}

type AccumulateResult struct {
    // NumNew is the number of events in timeline NIDs that were not previously known
    // to the proxy.
    NumNew int
    // TimelineNIDs is the list of event nids seen in a sync v2 timeline. Some of these
    // may already be known to the proxy.
    TimelineNIDs []int64
    // RequiresReload is set to true when we have accumulated a non-incremental state
    // change (typically a redaction) that requires consumers to reload the room state
    // from the latest snapshot.
    RequiresReload bool
}

// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
// received from the server. Returns the number of new events in the timeline, the new timeline event NIDs
// or an error.
@@ -329,7 +342,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (AccumulateResult, error) {
    // The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
    // we expect:
    // - there to be no duplicate events
@@ -338,10 +351,10 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
    dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
    if err != nil {
        err = fmt.Errorf("filterTimelineEvents: %w", err)
        return
        return AccumulateResult{}, err
    }
    if len(dedupedEvents) == 0 {
        return 0, nil, err // nothing to do
        return AccumulateResult{}, nil // nothing to do
    }

    // Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
@@ -355,7 +368,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
    // do NOT assign the new state event in the snapshot so as to represent the state before the event.
    snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
    if err != nil {
        return 0, nil, err
        return AccumulateResult{}, err
    }

    // The only situation where no prior snapshot should exist is if this timeline is
@@ -392,19 +405,22 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
            })
            // the HS gave us bad data so there's no point retrying
            // by not returning an error, we are telling the poller it is fine to not retry this request.
            return 0, nil, nil
            return AccumulateResult{}, nil
        }
    }

    eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
    if err != nil {
        return 0, nil, err
        return AccumulateResult{}, err
    }
    if len(eventIDToNID) == 0 {
        // nothing to do, we already know about these events
        return 0, nil, nil
        return AccumulateResult{}, nil
    }

    result := AccumulateResult{
        NumNew: len(eventIDToNID),
    }
    numNew = len(eventIDToNID)

    var latestNID int64
    newEvents := make([]Event, 0, len(eventIDToNID))
@@ -436,7 +452,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
            }
        }
        newEvents = append(newEvents, ev)
        timelineNIDs = append(timelineNIDs, ev.NID)
        result.TimelineNIDs = append(result.TimelineNIDs, ev.NID)
        }
    }

@@ -446,7 +462,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
    if len(redactTheseEventIDs) > 0 {
        createEventJSON, err := a.eventsTable.SelectCreateEvent(txn, roomID)
        if err != nil {
            return 0, nil, fmt.Errorf("SelectCreateEvent: %w", err)
            return AccumulateResult{}, fmt.Errorf("SelectCreateEvent: %w", err)
        }
        roomVersion = gjson.GetBytes(createEventJSON, "content.room_version").Str
        if roomVersion == "" {
@@ -457,7 +473,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
            )
        }
        if err = a.eventsTable.Redact(txn, roomVersion, redactTheseEventIDs); err != nil {
            return 0, nil, err
            return AccumulateResult{}, err
        }
    }

@@ -473,12 +489,12 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
        if snapID != 0 {
            oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
            if err != nil {
                return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
                return AccumulateResult{}, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
            }
        }
        newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
        if err != nil {
            return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
            return AccumulateResult{}, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
        }
        replacesNID = replacedNID
        memNIDs, otherNIDs := newStripped.NIDs()
@@ -488,7 +504,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
            OtherEvents: otherNIDs,
        }
        if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
            return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
            return AccumulateResult{}, fmt.Errorf("failed to insert new snapshot: %w", err)
        }
        if a.snapshotMemberCountVec != nil {
            logger.Trace().Str("room_id", roomID).Int("members", len(memNIDs)).Msg("Inserted new snapshot")
@@ -497,20 +513,40 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
            snapID = newSnapshot.SnapshotID
        }
        if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
            return 0, nil, err
            return AccumulateResult{}, err
        }
    }

    if len(redactTheseEventIDs) > 0 {
        // We need to emit a cache invalidation if we have redacted some state in the
        // current snapshot ID. Note that we run this _after_ persisting any new snapshots.
        redactedEventIDs := make([]string, 0, len(redactTheseEventIDs))
        for eventID := range redactTheseEventIDs {
            redactedEventIDs = append(redactedEventIDs, eventID)
        }
        var currentStateRedactions int
        err = txn.Get(&currentStateRedactions, `
            SELECT COUNT(*)
            FROM syncv3_events
            JOIN syncv3_snapshots ON event_nid = ANY (ARRAY_CAT(events, membership_events))
            WHERE snapshot_id = $1 AND event_id = ANY($2)
        `, snapID, pq.StringArray(redactedEventIDs))
        if err != nil {
            return AccumulateResult{}, err
        }
        result.RequiresReload = currentStateRedactions > 0
    }

    if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
        return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
        return AccumulateResult{}, fmt.Errorf("HandleSpaceUpdates: %s", err)
    }

    // the last fetched snapshot ID is the current one
    info := a.roomInfoDelta(roomID, newEvents)
    if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
        return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
        return AccumulateResult{}, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
    }
    return numNew, timelineNIDs, nil
    return result, nil
}

// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:
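Note: callers migrate from the old three-value return to the single AccumulateResult. A rough before/after fragment; notifyNewEvents and invalidateRoom are hypothetical placeholders for whatever the caller does with the NIDs and the reload signal:

// Before: numNew, timelineNIDs, err := a.Accumulate(txn, userID, roomID, prevBatch, timeline)
res, err := a.Accumulate(txn, userID, roomID, prevBatch, timeline)
if err != nil {
    return err
}
if res.NumNew > 0 {
    notifyNewEvents(res.TimelineNIDs) // formerly numNew / timelineNIDs
}
if res.RequiresReload {
    // New: current-state events were redacted, so reload from the latest
    // snapshot rather than applying the timeline incrementally.
    invalidateRoom(roomID)
}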
@ -3,6 +3,7 @@ package state
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
"reflect"
|
||||
"sort"
|
||||
@ -135,25 +136,24 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
// new state event should be added to the snapshot
|
||||
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
|
||||
}
|
||||
var numNew int
|
||||
var latestNIDs []int64
|
||||
var result AccumulateResult
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
result, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
if numNew != len(newEvents) {
|
||||
t.Fatalf("got %d new events, want %d", numNew, len(newEvents))
|
||||
if result.NumNew != len(newEvents) {
|
||||
t.Fatalf("got %d new events, want %d", result.NumNew, len(newEvents))
|
||||
}
|
||||
// latest nid should match
|
||||
wantLatestNID, err := accumulator.eventsTable.SelectHighestNID()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to check latest NID from Accumulate: %s", err)
|
||||
}
|
||||
if latestNIDs[len(latestNIDs)-1] != wantLatestNID {
|
||||
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", latestNIDs[len(latestNIDs)-1], wantLatestNID)
|
||||
if result.TimelineNIDs[len(result.TimelineNIDs)-1] != wantLatestNID {
|
||||
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", result.TimelineNIDs[len(result.TimelineNIDs)-1], wantLatestNID)
|
||||
}
|
||||
|
||||
// Begin assertions
|
||||
@ -212,7 +212,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
|
||||
// subsequent calls do nothing and are not an error
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -220,6 +220,80 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulatorPromptsCacheInvalidation(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
accumulator := NewAccumulator(db)
|
||||
|
||||
t.Log("Initialise the room state, including a room name.")
|
||||
roomID := fmt.Sprintf("!%s:localhost", t.Name())
|
||||
stateBlock := []json.RawMessage{
|
||||
[]byte(`{"event_id":"$a", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost", "room_version": "10"}}`),
|
||||
[]byte(`{"event_id":"$b", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
|
||||
[]byte(`{"event_id":"$c", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
[]byte(`{"event_id":"$d", "type":"m.room.name", "state_key":"", "content":{"name":"Barry Cryer Appreciation Society"}}`),
|
||||
}
|
||||
_, err := accumulator.Initialise(roomID, stateBlock)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Initialise accumulator: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Accumulate a second room name, a message, then a third room name.")
|
||||
timeline := []json.RawMessage{
|
||||
[]byte(`{"event_id":"$e", "type":"m.room.name", "state_key":"", "content":{"name":"Jeremy Hardy Appreciation Society"}}`),
|
||||
[]byte(`{"event_id":"$f", "type":"m.room.message", "content": {"body":"Hello, world!", "msgtype":"m.text"}}`),
|
||||
[]byte(`{"event_id":"$g", "type":"m.room.name", "state_key":"", "content":{"name":"Humphrey Lyttelton Appreciation Society"}}`),
|
||||
}
|
||||
var accResult AccumulateResult
|
||||
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch", timeline)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Accumulate: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We expect 3 new events and no reload required.")
|
||||
assertValue(t, "accResult.NumNew", accResult.NumNew, 3)
|
||||
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 3)
|
||||
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
|
||||
|
||||
t.Log("Redact the old state event and the message.")
|
||||
timeline = []json.RawMessage{
|
||||
[]byte(`{"event_id":"$h", "type":"m.room.redaction", "content":{"redacts":"$e"}}`),
|
||||
[]byte(`{"event_id":"$i", "type":"m.room.redaction", "content":{"redacts":"$f"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch2", timeline)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Accumulate: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We expect 2 new events and no reload required.")
|
||||
assertValue(t, "accResult.NumNew", accResult.NumNew, 2)
|
||||
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 2)
|
||||
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
|
||||
|
||||
t.Log("Redact the latest state event.")
|
||||
timeline = []json.RawMessage{
|
||||
[]byte(`{"event_id":"$j", "type":"m.room.redaction", "content":{"redacts":"$g"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch3", timeline)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Accumulate: %s", err)
|
||||
}
|
||||
|
||||
t.Log("We expect 1 new event and a reload required.")
|
||||
assertValue(t, "accResult.NumNew", accResult.NumNew, 1)
|
||||
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 1)
|
||||
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, true)
|
||||
}
|
||||
|
||||
func TestAccumulatorMembershipLogs(t *testing.T) {
|
||||
roomID := "!TestAccumulatorMembershipLogs:localhost"
|
||||
db, close := connectToDB(t)
|
||||
@ -248,7 +322,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
|
||||
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -384,7 +458,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
|
||||
}
|
||||
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
|
||||
_, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -584,8 +658,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
|
||||
defer wg.Done()
|
||||
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
|
||||
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
|
||||
totalNumNew += numNew
|
||||
result, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
|
||||
totalNumNew += result.NumNew
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
@@ -306,6 +306,77 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
    return nil
}

// ResetMetadataState updates the given metadata in-place to reflect the current state
// of the room. This is only safe to call from the subscriber goroutine; it is not safe
// to call from the connection goroutines.
// TODO: could have this create a new RoomMetadata and get the caller to assign it.
func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
    var events []Event
    err := s.DB.Select(&events, `
        WITH snapshot(events, membership_events) AS (
            SELECT events, membership_events
            FROM syncv3_snapshots
            JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
            WHERE syncv3_rooms.room_id = $1
        )
        SELECT event_id, event_type, state_key, event, membership
        FROM syncv3_events JOIN snapshot ON (
            event_nid = ANY (ARRAY_CAT(events, membership_events))
        )
        WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias') AND state_key = '')
            OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite'))
        ORDER BY event_nid ASC
    ;`, metadata.RoomID)
    if err != nil {
        return fmt.Errorf("ResetMetadataState[%s]: %w", metadata.RoomID, err)
    }

    heroMemberships := circularSlice[*Event]{max: 6}
    metadata.JoinCount = 0
    metadata.InviteCount = 0
    metadata.ChildSpaceRooms = make(map[string]struct{})

    for i, ev := range events {
        switch ev.Type {
        case "m.room.name":
            metadata.NameEvent = gjson.GetBytes(ev.JSON, "content.name").Str
        case "m.room.avatar":
            metadata.AvatarEvent = gjson.GetBytes(ev.JSON, "content.avatar_url").Str
        case "m.room.canonical_alias":
            metadata.CanonicalAlias = gjson.GetBytes(ev.JSON, "content.alias").Str
        case "m.room.member":
            heroMemberships.append(&events[i])
            switch ev.Membership {
            case "join":
                fallthrough
            case "_join":
                metadata.JoinCount++
            case "invite":
                fallthrough
            case "_invite":
                metadata.InviteCount++
            }
        case "m.space.child":
            metadata.ChildSpaceRooms[ev.StateKey] = struct{}{}
        }
    }

    metadata.Heroes = make([]internal.Hero, 0, len(heroMemberships.vals))
    for _, ev := range heroMemberships.vals {
        parsed := gjson.ParseBytes(ev.JSON)
        hero := internal.Hero{
            ID: ev.StateKey,
            Name: parsed.Get("content.displayname").Str,
            Avatar: parsed.Get("content.avatar_url").Str,
        }
        metadata.Heroes = append(metadata.Heroes, hero)
    }

    // For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID.
    // These shouldn't be changing during a room's lifetime in normal operation.
    return nil
}

// Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of
// room ID to events in that room.
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
@@ -336,15 +407,15 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
    return result, nil
}

func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (result AccumulateResult, err error) {
    if len(timeline) == 0 {
        return 0, nil, nil
        return AccumulateResult{}, nil
    }
    err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
        numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
        result, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
        return err
    })
    return
    return result, err
}

func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
@@ -818,7 +889,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
    defer rows.Close()
    joinedMembers = make(map[string][]string)
    inviteCounts := make(map[string]int)
    heroNIDs := make(map[string]*circularSlice)
    heroNIDs := make(map[string]*circularSlice[int64])
    var stateKey string
    var membership string
    var roomID string
@@ -829,7 +900,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
        }
        heroes := heroNIDs[roomID]
        if heroes == nil {
            heroes = &circularSlice{max: 6}
            heroes = &circularSlice[int64]{max: 6}
            heroNIDs[roomID] = heroes
        }
        switch membership {
@@ -982,13 +1053,13 @@ func (s *Storage) Teardown() {

// circularSlice is a slice which can be appended to which will wraparound at `max`.
// Mostly useful for lazily calculating heroes. The values returned aren't sorted.
type circularSlice struct {
type circularSlice[T any] struct {
    i int
    vals []int64
    vals []T
    max int
}

func (s *circularSlice) append(val int64) {
func (s *circularSlice[T]) append(val T) {
    if len(s.vals) < s.max {
        // populate up to max
        s.vals = append(s.vals, val)
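Note: making circularSlice generic lets ResetMetadataState buffer *Event heroes while AllJoinedMembers keeps buffering int64 NIDs. A rough usage sketch, assuming the append semantics shown above (fill up to max, then wrap around and overwrite earlier entries):

heroes := circularSlice[string]{max: 3}
for _, userID := range []string{"@a:hs", "@b:hs", "@c:hs", "@d:hs"} {
    heroes.append(userID)
}
// heroes.vals now holds three entries; the fourth append wrapped around and
// replaced one of the earlier values, and vals is not sorted by recency.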
@ -31,11 +31,11 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
|
||||
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
|
||||
}
|
||||
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
|
||||
accResult, err := store.Accumulate(userID, roomID, "", events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate returned error: %s", err)
|
||||
}
|
||||
latest := latestNIDs[len(latestNIDs)-1]
|
||||
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -158,14 +158,13 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var latestPos int64
|
||||
var latestNIDs []int64
|
||||
var err error
|
||||
for roomID, eventMap := range roomIDToEventMap {
|
||||
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
|
||||
accResult, err := store.Accumulate(userID, roomID, "", eventMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
|
||||
}
|
||||
latestPos = latestNIDs[len(latestNIDs)-1]
|
||||
latestPos = accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
|
||||
}
|
||||
aliceJoinTimingsByRoomID, err := store.JoinedRoomsAfterPosition(alice, latestPos)
|
||||
if err != nil {
|
||||
@ -351,11 +350,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
t.Logf("%s added %d new events", tl.RoomID, numNew)
|
||||
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
|
||||
}
|
||||
latestPos, err := store.LatestEventNID()
|
||||
if err != nil {
|
||||
@ -454,11 +453,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
t.Fatalf("LatestEventNID: %s", err)
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
t.Logf("%s added %d new events", tl.RoomID, numNew)
|
||||
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
|
||||
}
|
||||
latestPos, err = store.LatestEventNID()
|
||||
if err != nil {
|
||||
@ -534,7 +533,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
|
||||
}
|
||||
eventIDs := []string{}
|
||||
for _, timeline := range timelines {
|
||||
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
|
||||
_, err := store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to accumulate: %s", err)
|
||||
}
|
||||
@ -776,7 +775,7 @@ func TestAllJoinedMembers(t *testing.T) {
|
||||
}, serialise(tc.InitMemberships)...))
|
||||
assertNoError(t, err)
|
||||
|
||||
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
|
||||
_, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
|
||||
assertNoError(t, err)
|
||||
testCases[i].RoomID = roomID // remember this for later
|
||||
}
|
||||
@ -855,7 +854,7 @@ func TestCircularSlice(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
cs := &circularSlice{
|
||||
cs := &circularSlice[int64]{
|
||||
max: tc.max,
|
||||
}
|
||||
for _, val := range tc.appends {
|
||||
|
@@ -294,19 +294,26 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
    }

    // Insert new events
    numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
    accResult, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
    if err != nil {
        logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
        internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
        return err
    }

    // Consumers should reload state before processing new timeline events.
    if accResult.RequiresReload {
        h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{
            RoomID: roomID,
        })
    }

    // We've updated the database. Now tell any pubsub listeners what we learned.
    if numNew != 0 {
    if accResult.NumNew != 0 {
        h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
            RoomID: roomID,
            PrevBatch: prevBatch,
            EventNIDs: latestNIDs,
            EventNIDs: accResult.TimelineNIDs,
        })
    }
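Note: this hunk is the publishing half of the invalidation path. On the consuming side, the V2Sub switch earlier in this commit routes V2InvalidateRoom to a receiver's OnInvalidateRoom, which (per the Dispatcher and GlobalCache hunks below) ends in Storage.ResetMetadataState. A hypothetical receiver wiring it through, for illustration only; exampleReceiver and its dispatcher field are not part of this commit:

func (r *exampleReceiver) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
    // Fan the invalidation out to the global cache and to every user joined
    // to the room, so cached metadata is rebuilt from the latest snapshot.
    r.dispatcher.OnInvalidateRoom(context.Background(), p.RoomID)
}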
@ -1,455 +0,0 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MigrateDeviceIDs performs a one-off DB migration from the old device ids (hash of
|
||||
// access token) to the new device ids (actual device ids from the homeserver). This is
|
||||
// not backwards compatible. If the migration has already taken place, this function is
|
||||
// a no-op.
|
||||
//
|
||||
// This code will be removed in a future version of the proxy.
|
||||
func MigrateDeviceIDs(ctx context.Context, destHomeserver, postgresURI, secret string, commit bool) error {
|
||||
whoamiClient := &HTTPClient{
|
||||
Client: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: otelhttp.NewTransport(http.DefaultTransport),
|
||||
},
|
||||
DestinationServer: destHomeserver,
|
||||
}
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
|
||||
// Ensure the new table exists.
|
||||
NewTokensTable(db, secret)
|
||||
|
||||
return sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
migrated, err := isMigrated(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if migrated {
|
||||
logger.Debug().Msg("MigrateDeviceIDs: migration has already taken place")
|
||||
return nil
|
||||
}
|
||||
logger.Info().Msgf("MigrateDeviceIDs: starting (commit=%t)", commit)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
logger.Debug().Msgf("MigrateDeviceIDs: took %s", elapsed)
|
||||
}()
|
||||
err = alterTables(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = runMigration(ctx, txn, secret, whoamiClient)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = finish(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s := NewStore(postgresURI, secret)
|
||||
tokens, err := s.TokensTable.TokenForEachDevice(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
logger.Debug().Msgf("Got %d tokens after migration", len(tokens))
|
||||
|
||||
if !commit {
|
||||
err = fmt.Errorf("MigrateDeviceIDs: migration succeeded without errors, but commit is false - rolling back anyway")
|
||||
} else {
|
||||
logger.Info().Msg("MigrateDeviceIDs: migration succeeded - committing")
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
func isMigrated(txn *sqlx.Tx) (bool, error) {
|
||||
// Keep this dead simple for now. This is a one-off migration, before version 1.0.
|
||||
// In the future we'll rip this out and tell people that it's their job to ensure
|
||||
// this migration has run before they upgrade beyond the rip-out point.
|
||||
|
||||
// We're going to detect if the migration has run by testing for the existence of
|
||||
// a column added by the migration. First, check that the table exists.
|
||||
var tableExists bool
|
||||
err := txn.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'syncv3_txns'
|
||||
);
|
||||
`).Scan(&tableExists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("isMigrated: %s", err)
|
||||
}
|
||||
if !tableExists {
|
||||
// The proxy has never been run before and its tables have never been created.
|
||||
// We do not need to run the migration.
|
||||
logger.Debug().Msg("isMigrated: no syncv3_txns table, no migration needed")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var migrated bool
|
||||
err = txn.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'syncv3_txns' AND column_name = 'device_id'
|
||||
);
|
||||
`).Scan(&migrated)
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("isMigrated: %s", err)
|
||||
}
|
||||
return migrated, nil
|
||||
}
|
||||
|
||||
func alterTables(txn *sqlx.Tx) (err error) {
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_sync2_devices
|
||||
DROP CONSTRAINT syncv3_sync2_devices_pkey;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_messages
|
||||
ADD COLUMN user_id TEXT;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_ack_pos
|
||||
DROP CONSTRAINT syncv3_to_device_ack_pos_pkey,
|
||||
ADD COLUMN user_id TEXT;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_txns
|
||||
DROP CONSTRAINT syncv3_txns_user_id_event_id_key,
|
||||
ADD COLUMN device_id TEXT;
|
||||
`)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type oldDevice struct {
|
||||
AccessToken string // not a DB row, but it's convenient to write to here
|
||||
AccessTokenHash string `db:"device_id"`
|
||||
UserID string `db:"user_id"`
|
||||
AccessTokenEncrypted string `db:"v2_token_encrypted"`
|
||||
Since string `db:"since"`
|
||||
}
|
||||
|
||||
func runMigration(ctx context.Context, txn *sqlx.Tx, secret string, whoamiClient Client) error {
|
||||
logger.Info().Msg("Loading old-style devices into memory")
|
||||
var devices []oldDevice
|
||||
err := txn.Select(
|
||||
&devices,
|
||||
`SELECT device_id, user_id, v2_token_encrypted, since FROM syncv3_sync2_devices;`,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runMigration: failed to select devices: %s", err)
|
||||
}
|
||||
|
||||
logger.Info().Msgf("Got %d devices to migrate", len(devices))
|
||||
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(secret))
|
||||
key := hasher.Sum(nil)
|
||||
|
||||
// This migration runs sequentially, one device at a time. We have found this to be
|
||||
// quick enough in practice.
|
||||
numErrors := 0
|
||||
for i, device := range devices {
|
||||
device.AccessToken, err = decrypt(device.AccessTokenEncrypted, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runMigration: failed to decrypt device: %s", err)
|
||||
}
|
||||
userID := device.UserID
|
||||
if userID == "" {
|
||||
userID = "<unknown user>"
|
||||
}
|
||||
logger.Info().Msgf(
|
||||
"%4d/%4d migrating device %s %s",
|
||||
i+1, len(devices), userID, device.AccessTokenHash,
|
||||
)
|
||||
err = migrateDevice(ctx, txn, whoamiClient, &device)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("runMigration: failed to migrate device %s", device.AccessTokenHash)
|
||||
numErrors++
|
||||
}
|
||||
}
|
||||
if numErrors > 0 {
|
||||
return fmt.Errorf("runMigration: there were %d failures", numErrors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateDevice(ctx context.Context, txn *sqlx.Tx, whoamiClient Client, device *oldDevice) (err error) {
|
||||
gotUserID, gotDeviceID, err := whoamiClient.WhoAmI(ctx, device.AccessToken)
|
||||
if err == HTTP401 {
|
||||
userID := device.UserID
|
||||
if userID == "" {
|
||||
userID = "<unknown user>"
|
||||
}
|
||||
logger.Warn().Msgf(
|
||||
"migrateDevice: access token for %s %s has expired. Dropping device and metadata.",
|
||||
userID, device.AccessTokenHash,
|
||||
)
|
||||
return cleanupDevice(txn, device)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Sanity check the user ID from the HS matches our records
|
||||
if gotUserID != device.UserID {
|
||||
return fmt.Errorf(
|
||||
"/whoami response was for the wrong user. Queried for %s, but got response for %s",
|
||||
device.UserID, gotUserID,
|
||||
)
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)`,
|
||||
expectOneRowAffected,
|
||||
device.AccessTokenHash, device.AccessTokenEncrypted, gotUserID, gotDeviceID, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For these first four tables:
|
||||
// - use the actual device ID instead of the access token hash, and
|
||||
// - ensure a user ID is set.
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_sync2_devices SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_to_device_messages SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_to_device_ack_pos SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAtMostOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_device_data SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAtMostOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Confusingly, the txns table used to store access token hashes under the user_id
|
||||
// column. Write the actual user ID to the user_id column, and the actual device ID
|
||||
// to the device_id column.
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_txns SET user_id = $1, device_id = $2 WHERE user_id = $3`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func cleanupDevice(txn *sqlx.Tx, device *oldDevice) (err error) {
|
||||
// The homeserver does not recognise this access token. Because we have no
|
||||
// record of the device_id from the homeserver, it will never be possible to
|
||||
// spot that a future refreshed access token belongs to the device we're
|
||||
// handling here. Therefore this device is not useful to the proxy.
|
||||
//
|
||||
// If we leave this device's rows in situ, we may end up with rows in
|
||||
// syncv3_to_device_messages, syncv3_to_device_ack_pos and syncv3_txns which have
|
||||
// null values for the new fields, which will mean we fail to impose the uniqueness
|
||||
// constraints at the end of the migration. Instead, drop those rows.
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`,
|
||||
expectOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_to_device_ack_pos WHERE device_id = $1`,
|
||||
expectAtMostOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_device_data WHERE device_id = $1`,
|
||||
expectAtMostOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_txns WHERE user_id = $1`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func exec(txn *sqlx.Tx, query string, checkRowsAffected func(ra int64) bool, args ...any) error {
|
||||
res, err := txn.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ra, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !checkRowsAffected(ra) {
|
||||
return fmt.Errorf("query \"%s\" unexpectedly affected %d rows", query, ra)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func expectOneRowAffected(ra int64) bool { return ra == 1 }
|
||||
func expectAnyNumberOfRowsAffected(ra int64) bool { return true }
|
||||
func logRowsAffected(msg string) func(ra int64) bool {
|
||||
return func(ra int64) bool {
|
||||
logger.Info().Msgf(msg, ra)
|
||||
return true
|
||||
}
|
||||
}
|
||||
func expectAtMostOneRowAffected(ra int64) bool { return ra == 0 || ra == 1 }
|
||||
|
||||
func finish(txn *sqlx.Tx) (err error) {
|
||||
// OnExpiredToken used to delete from the devices and to-device tables, but not from
|
||||
// the to-device ack pos or the txn tables. Fix this up by deleting orphaned rows.
|
||||
err = exec(
|
||||
txn,
|
||||
`
|
||||
DELETE FROM syncv3_to_device_ack_pos
|
||||
WHERE device_id NOT IN (SELECT device_id FROM syncv3_sync2_devices)
|
||||
;`,
|
||||
logRowsAffected("Deleted %d stale rows from syncv3_to_device_ack_pos"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`
|
||||
DELETE FROM syncv3_txns WHERE device_id IS NULL;
|
||||
`,
|
||||
logRowsAffected("Deleted %d stale rows from syncv3_txns"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_sync2_devices
|
||||
DROP COLUMN v2_token_encrypted,
|
||||
ADD PRIMARY KEY (user_id, device_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_messages
|
||||
ALTER COLUMN user_id SET NOT NULL;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_ack_pos
|
||||
ALTER COLUMN user_id SET NOT NULL,
|
||||
ADD PRIMARY KEY (user_id, device_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_txns
|
||||
ALTER COLUMN device_id SET NOT NULL,
|
||||
ADD UNIQUE(user_id, device_id, event_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"github.com/matrix-org/sliding-sync/state"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConsideredMigratedOnFirstStartup(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
var migrated bool
|
||||
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
// Attempt to make this test independent of others by dropping the table whose
|
||||
// columns we probe.
|
||||
_, err = txn.Exec("DROP TABLE IF EXISTS syncv3_txns;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
migrated, err = isMigrated(txn)
|
||||
return
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error calling isMigrated: %s", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatalf("Expected a non-existent DB to be considered migrated, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSchemaIsConsideredMigrated(t *testing.T) {
|
||||
NewStore(postgresConnectionString, "my_secret")
|
||||
state.NewStorage(postgresConnectionString)
|
||||
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
var migrated bool
|
||||
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
migrated, err = isMigrated(txn)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error calling isMigrated: %s", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatalf("Expected a new DB to be considered migrated, but it was not")
|
||||
}
|
||||
}
|
@@ -385,3 +385,20 @@ func (c *GlobalCache) OnNewEvent(
    }
    c.roomIDToMetadata[ed.RoomID] = metadata
}

func (c *GlobalCache) OnInvalidateRoom(ctx context.Context, roomID string) {
    c.roomIDToMetadataMu.Lock()
    defer c.roomIDToMetadataMu.Unlock()

    metadata, ok := c.roomIDToMetadata[roomID]
    if !ok {
        logger.Warn().Str("room_id", roomID).Msg("OnInvalidateRoom: room not in global cache")
        return
    }

    err := c.store.ResetMetadataState(metadata)
    if err != nil {
        internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
        logger.Warn().Err(err).Msg("OnInvalidateRoom: failed to reset metadata")
    }
}
@@ -38,16 +38,16 @@ func TestGlobalCacheLoadState(t *testing.T) {
        testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
        testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
    }
    _, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
    _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
    if err != nil {
        t.Fatalf("Accumulate: %s", err)
    }

    _, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
    accResult, err := store.Accumulate(alice, roomID, "", events)
    if err != nil {
        t.Fatalf("Accumulate: %s", err)
    }
    latest := latestNIDs[len(latestNIDs)-1]
    latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
    globalCache := caches.NewGlobalCache(store)
    testCases := []struct {
        name string
@@ -760,3 +760,10 @@ func (u *UserCache) ShouldIgnore(userID string) bool {
    _, ignored := u.ignoredUsers[userID]
    return ignored
}

func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) {
    // Nothing for now. In UserRoomData the fields dependent on room state are
    // IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces.
    // Not clear to me if we need to reload these or if we will inherit any changes from
    // the global cache.
}
@@ -24,6 +24,7 @@ type Receiver interface {
    OnNewEvent(ctx context.Context, event *caches.EventData)
    OnReceipt(ctx context.Context, receipt internal.Receipt)
    OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage)
    OnInvalidateRoom(ctx context.Context, roomID string)
    // OnRegistered is called after a successful call to Dispatcher.Register
    OnRegistered(ctx context.Context) error
}
@@ -285,3 +286,23 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData,
        }
    }
}

func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) {
    // First dispatch to the global cache.
    receiver, ok := d.userToReceiver[DispatcherAllUsers]
    if !ok {
        logger.Error().Msgf("No receiver for global cache")
    }
    receiver.OnInvalidateRoom(ctx, roomID)

    // Then dispatch to any users who are joined to that room.
    joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil)
    d.userToReceiverMu.RLock()
    defer d.userToReceiverMu.RUnlock()
    for _, userID := range joinedUsers {
        receiver = d.userToReceiver[userID]
        if receiver != nil {
            receiver.OnInvalidateRoom(ctx, roomID)
        }
    }
}
@ -16,8 +16,8 @@ import (
|
||||
type pendingInfo struct {
|
||||
// done is set to true when the EnsurePolling request received a response.
|
||||
done bool
|
||||
// success is true when done is true and EnsurePolling succeeded; otherwise false.
|
||||
success bool
|
||||
// expired is true when the token is expired. Any 'done'ness should be ignored for expired tokens.
|
||||
expired bool
|
||||
// ch is a dummy channel which never receives any data. A call to
|
||||
// EnsurePoller.OnInitialSyncComplete will close the channel (unblocking any
|
||||
// EnsurePoller.EnsurePolling calls which are waiting on it) and then set the ch
|
||||
@ -58,19 +58,31 @@ func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePol
|
||||
}
|
||||
|
||||
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
|
||||
// the caller's responsibility to call OnInitialSyncComplete when new events arrive. Returns the Success field from the
|
||||
// V2InitialSyncComplete response, which is true iff there is an active poller.
|
||||
// the caller's responsibility to call OnInitialSyncComplete when new events arrive. Returns whether
|
||||
// or not the token is expired
|
||||
func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, tokenHash string) bool {
|
||||
ctx, region := internal.StartSpan(ctx, "EnsurePolling")
|
||||
defer region.End()
|
||||
p.mu.Lock()
|
||||
// do we need to wait?
|
||||
if p.pendingPolls[pid].done {
|
||||
// do we need to wait? Expired devices ALWAYS need a fresh poll
|
||||
expired := p.pendingPolls[pid].expired
|
||||
if !expired && p.pendingPolls[pid].done {
|
||||
internal.Logf(ctx, "EnsurePolling", "user %s device %s already done", pid.UserID, pid.DeviceID)
|
||||
success := p.pendingPolls[pid].success
|
||||
p.mu.Unlock()
|
||||
return success
|
||||
return expired // always false
|
||||
}
|
||||
|
||||
// either we have expired or we haven't expired and are still waiting on the initial sync.
|
||||
// If we have expired, nuke the pid from the map now so we will use the same code path as a fresh device sync
|
||||
if expired {
|
||||
// close any existing channel
|
||||
if p.pendingPolls[pid].ch != nil {
|
||||
close(p.pendingPolls[pid].ch)
|
||||
}
|
||||
delete(p.pendingPolls, pid)
|
||||
// at this point, ch == nil so we will do an initial sync
|
||||
}
|
||||
|
||||
// have we called EnsurePolling for this user/device before?
|
||||
ch := p.pendingPolls[pid].ch
|
||||
if ch != nil {
|
||||
@ -84,9 +96,9 @@ func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, to
|
||||
<-ch
|
||||
r2.End()
|
||||
p.mu.Lock()
|
||||
success := p.pendingPolls[pid].success
|
||||
expired := p.pendingPolls[pid].expired
|
||||
p.mu.Unlock()
|
||||
return success
|
||||
return expired
|
||||
}
|
||||
// Make a channel to wait until we have done an initial sync
|
||||
ch = make(chan struct{})
|
||||
@ -110,9 +122,9 @@ func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, to
|
||||
r2.End()
|
||||
|
||||
p.mu.Lock()
|
||||
success := p.pendingPolls[pid].success
|
||||
expired = p.pendingPolls[pid].expired
|
||||
p.mu.Unlock()
|
||||
return success
|
||||
return expired
|
||||
}
|
||||
|
||||
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
@ -129,7 +141,7 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
log.Trace().Msg("OnInitialSyncComplete: we weren't waiting for this")
p.pendingPolls[pid] = pendingInfo{
done: true,
success: payload.Success,
expired: !payload.Success,
}
return
}
@ -143,7 +155,11 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
ch := pending.ch
pending.done = true
pending.ch = nil
pending.success = payload.Success
// If for whatever reason we get OnExpiredToken prior to OnInitialSyncComplete, don't forget that
// we expired the token, i.e. expiry latches true.
if !pending.expired {
pending.expired = !payload.Success
}
p.pendingPolls[pid] = pending
p.calculateNumOutstanding() // decrement total
log.Trace().Msg("OnInitialSyncComplete: closing channel")
@ -159,10 +175,12 @@ func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) {
// We weren't tracking the state of this poller, so we have nothing to clean up.
return
}
if pending.ch != nil {
close(pending.ch)
}
delete(p.pendingPolls, pid)
pending.expired = true
p.pendingPolls[pid] = pending

// We used to delete the entry from the map at this point to force the next
// EnsurePolling call to do a fresh EnsurePolling request, but now we do that
// by signalling via the expired flag.
}

func (p *EnsurePoller) Teardown() {
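
For orientation, a minimal sketch of how the changed return value is meant to be consumed; the wrapper name ensureDevicePolling is illustrative and not part of this change, and it simply mirrors the setupConnection update further down:

package handler

import (
    "context"
    "fmt"

    "github.com/matrix-org/sliding-sync/sync2"
)

// ensureDevicePolling is an illustrative wrapper only: EnsurePolling now answers
// "is this token expired?", not "did the initial sync succeed?".
func ensureDevicePolling(ctx context.Context, ep *EnsurePoller, pid sync2.PollerID, tokenHash string) error {
    if expired := ep.EnsurePolling(ctx, pid, tokenHash); expired {
        // The HTTP layer surfaces this as a 401 so the client re-authenticates.
        return fmt.Errorf("access token expired for %s/%s", pid.UserID, pid.DeviceID)
    }
    // Otherwise a poller is running (or has just been started) for this device.
    return nil
}
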
sync3/handler/ensure_polling_test.go (Normal file, 202 lines added)
@ -0,0 +1,202 @@
package handler

import (
"context"
"reflect"
"sync/atomic"
"testing"
"time"

"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/sync2"
)

type mockNotifier struct {
ch chan pubsub.Payload
}

func (n *mockNotifier) Notify(chanName string, p pubsub.Payload) error {
n.ch <- p
return nil
}

func (n *mockNotifier) Close() error {
return nil
}

func (n *mockNotifier) MustHaveNoSentPayloads(t *testing.T) {
t.Helper()
if len(n.ch) == 0 {
return
}
t.Fatalf("MustHaveNoSentPayloads: %d in buffer", len(n.ch))
}

func (n *mockNotifier) WaitForNextPayload(t *testing.T, timeout time.Duration) pubsub.Payload {
t.Helper()
select {
case p := <-n.ch:
return p
case <-time.After(timeout):
t.Fatalf("WaitForNextPayload: timed out after %v", timeout)
}
panic("unreachable")
}

// check that the request/response works and unblocks things correctly
func TestEnsurePollerBasicWorks(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
tokHash := "tokenHash"
ep := NewEnsurePoller(n, false)

var expired atomic.Bool
finished := make(chan bool) // dummy
go func() {
exp := ep.EnsurePolling(ctx, pid, tokHash)
expired.Store(exp)
close(finished)
}()

p := n.WaitForNextPayload(t, time.Second)

// check it's a V3EnsurePolling payload
pp, ok := p.(*pubsub.V3EnsurePolling)
if !ok {
t.Fatalf("unexpected payload: %+v", p)
}
assertVal(t, pp.UserID, pid.UserID)
assertVal(t, pp.DeviceID, pid.DeviceID)
assertVal(t, pp.AccessTokenHash, tokHash)

// make sure we're still waiting
select {
case <-finished:
t.Fatalf("EnsurePolling unblocked before response was sent")
default:
}

// send back the response
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})

select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}

if expired.Load() {
t.Fatalf("response said token was expired when it wasn't")
}
}

func TestEnsurePollerCachesResponses(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
ep := NewEnsurePoller(n, false)

finished := make(chan bool) // dummy
go func() {
_ = ep.EnsurePolling(ctx, pid, "tokenHash")
close(finished)
}()

n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// send back the response
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})

select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}

// hitting EnsurePolling again should immediately return
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
if exp {
t.Fatalf("EnsurePolling said token was expired when it wasn't")
}
n.MustHaveNoSentPayloads(t)
}

// Regression test for when we did cache failures, causing no poller to start for the device
func TestEnsurePollerDoesntCacheFailures(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
ep := NewEnsurePoller(n, false)

finished := make(chan bool) // dummy
var expired atomic.Bool
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expired.Store(exp)
close(finished)
}()

n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// send back the response, which failed
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: false,
})

select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
if !expired.Load() {
t.Fatalf("EnsurePolling returned not expired, wanted expired due to Success=false")
}

// hitting EnsurePolling again should do a new request (i.e. not cache the failure)
var expiredAgain atomic.Bool
finished = make(chan bool) // dummy
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expiredAgain.Store(exp)
close(finished)
}()

p := n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// check it's a V3EnsurePolling payload
pp, ok := p.(*pubsub.V3EnsurePolling)
if !ok {
t.Fatalf("unexpected payload: %+v", p)
}
assertVal(t, pp.UserID, pid.UserID)
assertVal(t, pp.DeviceID, pid.DeviceID)
assertVal(t, pp.AccessTokenHash, "tokenHash")

// send back the response, which succeeded this time
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})

select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
}

func assertVal(t *testing.T, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Fatalf("assertVal: got %v want %v", got, want)
}
}
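
A minimal additional sketch of a scenario these helpers could also cover, using only the behaviour visible above (the test name is illustrative and not part of this change): OnExpiredToken closes the pending channel, which unblocks a waiting EnsurePolling call, and that call then reports the expiry to its caller.

func TestEnsurePollerExpiryUnblocksWaiter(t *testing.T) {
    n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
    ctx := context.Background()
    pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
    ep := NewEnsurePoller(n, false)

    var expired atomic.Bool
    finished := make(chan bool) // dummy
    go func() {
        expired.Store(ep.EnsurePolling(ctx, pid, "tokenHash"))
        close(finished)
    }()

    // wait for the V3EnsurePolling request so we know the waiter is registered
    n.WaitForNextPayload(t, time.Second)

    // expire the token instead of delivering an initial sync response
    ep.OnExpiredToken(&pubsub.V2ExpiredToken{UserID: pid.UserID, DeviceID: pid.DeviceID})

    select {
    case <-finished:
    case <-time.After(time.Second):
        t.Fatalf("EnsurePolling didn't unblock after the token expired")
    }
    if !expired.Load() {
        t.Fatalf("EnsurePolling should have reported an expired token")
    }
}
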
@ -396,8 +396,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ

pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
success := h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
if !success {
expiredToken := h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
if expiredToken {
log.Error().Msg("EnsurePolling failed, returning 401")
// Assumption: the only way that EnsurePolling fails is if the access token is invalid.
return req, nil, &internal.HandlerError{
@ -803,6 +803,13 @@ func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.ConnMap.CloseConnsForDevice(p.UserID, p.DeviceID)
}

func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
defer task.End()

h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID)
}

func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
queryPos := u.Query().Get(param)
if queryPos != "" {
@ -132,6 +132,7 @@ type SyncReq struct {
type CSAPI struct {
UserID string
Localpart string
Domain string
AccessToken string
DeviceID string
AvatarURL string
@ -160,6 +161,16 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
return GetJSONFieldStr(t, body, "content_uri")
}

// Use an empty string to remove a custom displayname.
func (c *CSAPI) SetDisplayname(t *testing.T, name string) {
t.Helper()
reqBody := map[string]any{}
if name != "" {
reqBody["displayname"] = name
}
c.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "profile", c.UserID, "displayname"}, WithJSONBody(t, reqBody))
}

// Use an empty string to remove your avatar.
func (c *CSAPI) SetAvatar(t *testing.T, avatarURL string) {
t.Helper()
@ -184,9 +195,9 @@ func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
}

// CreateRoom creates a room with an optional HTTP request body. Fails the test on error. Returns the room ID.
func (c *CSAPI) CreateRoom(t *testing.T, creationContent interface{}) string {
func (c *CSAPI) CreateRoom(t *testing.T, reqBody map[string]any) string {
t.Helper()
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, creationContent)
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, reqBody)
body := ParseJSON(t, res)
return GetJSONFieldStr(t, body, "room_id")
}
@ -220,7 +220,9 @@ func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI {
}

client.UserID, client.AccessToken, client.DeviceID = client.RegisterUser(t, localpart, "password")
client.Localpart = strings.Split(client.UserID, ":")[0][1:]
parts := strings.Split(client.UserID, ":")
client.Localpart = parts[0][1:]
client.Domain = strings.Split(client.UserID, ":")[1]
return client
}
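
A short sketch tying the new test-client pieces together (the test name and displayname are illustrative, not part of this change). For a registered user ID such as "@alice-123:synapse", parts[0][1:] yields the Localpart "alice-123" and parts[1] the Domain "synapse":

package syncv3_test

import "testing"

func TestClientHelpersSketch(t *testing.T) {
    alice := registerNamedUser(t, "alice")
    t.Logf("localpart=%s domain=%s", alice.Localpart, alice.Domain)

    // CreateRoom now takes the request body directly as a map.
    room := alice.CreateRoom(t, map[string]any{"preset": "public_chat"})

    alice.SetDisplayname(t, "alice the tester") // empty string removes a custom displayname
    alice.SetAvatar(t, "")                      // empty string removes the avatar
    t.Logf("created %s", room)
}
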
@ -1,7 +1,9 @@
package syncv3_test

import (
"fmt"
"testing"
"time"

"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
@ -75,3 +77,97 @@ func TestRedactionsAreRedactedWherePossible(t *testing.T) {
}

}

func TestRedactingRoomStateIsReflectedInNextSync(t *testing.T) {
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")

t.Log("Alice creates a room, then sets a room alias and name.")
room := alice.CreateRoom(t, map[string]any{
"preset": "public_chat",
})

alias := fmt.Sprintf("#%s-%d:%s", t.Name(), time.Now().Unix(), alice.Domain)
alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "directory", "room", alias},
WithJSONBody(t, map[string]any{"room_id": room}),
)
aliasID := alice.SetState(t, room, "m.room.canonical_alias", "", map[string]any{
"alias": alias,
})

const naughty = "naughty room for naughty people"
nameID := alice.SetState(t, room, "m.room.name", "", map[string]any{
"name": naughty,
})

t.Log("Alice sliding syncs, subscribing to that room explicitly.")
|
||||
res := alice.SlidingSync(t, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
room: {
|
||||
TimelineLimit: 20,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice should see her room appear with its name.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(naughty)))
|
||||
|
||||
t.Log("Alice redacts the room name.")
|
||||
redactionID := alice.RedactEvent(t, room, nameID)
|
||||
|
||||
t.Log("Alice syncs until she sees her redaction.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(
|
||||
room,
|
||||
MatchRoomTimelineMostRecent(1, []Event{{ID: redactionID}}),
|
||||
))
|
||||
|
||||
t.Log("The room name should have been redacted, falling back to the canonical alias.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(alias)))
|
||||
|
||||
t.Log("Alice sets a room avatar.")
|
||||
avatarURL := alice.UploadContent(t, smallPNG, "avatar.png", "image/png")
|
||||
avatarID := alice.SetState(t, room, "m.room.avatar", "", map[string]interface{}{
|
||||
"url": avatarURL,
|
||||
})
|
||||
|
||||
t.Log("Alice waits to see the avatar.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomAvatar(avatarURL)))
|
||||
|
||||
t.Log("Alice redacts the avatar.")
|
||||
redactionID = alice.RedactEvent(t, room, avatarID)
|
||||
|
||||
t.Log("Alice sees the avatar revert to blank.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomUnsetAvatar()))
|
||||
|
||||
t.Log("Bob joins the room, with a custom displayname.")
|
||||
const bobDisplayName = "bob mortimer"
|
||||
bob.SetDisplayname(t, bobDisplayName)
|
||||
bob.JoinRoom(t, room, nil)
|
||||
|
||||
t.Log("Alice sees Bob join.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room,
|
||||
MatchRoomTimelineMostRecent(1, []Event{{
|
||||
StateKey: ptr(bob.UserID),
|
||||
Type: "m.room.member",
|
||||
Content: map[string]any{
|
||||
"membership": "join",
|
||||
"displayname": bobDisplayName,
|
||||
},
|
||||
}}),
|
||||
))
|
||||
// Extract Bob's join ID because https://github.com/matrix-org/matrix-spec-proposals/pull/2943 doens't exist grrr
|
||||
timeline := res.Rooms[room].Timeline
|
||||
bobJoinID := gjson.GetBytes(timeline[len(timeline)-1], "event_id").Str
|
||||
|
||||
t.Log("Alice redacts the alias.")
|
||||
redactionID = alice.RedactEvent(t, room, aliasID)
|
||||
|
||||
t.Log("Alice sees the room name reset to Bob's display name.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bobDisplayName)))
|
||||
|
||||
t.Log("Bob redacts his membership")
|
||||
redactionID = bob.RedactEvent(t, room, bobJoinID)
|
||||
|
||||
t.Log("Alice sees the room name reset to Bob's username.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bob.UserID)))
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"os"
"sync/atomic"
"testing"
"time"

@ -462,3 +463,75 @@ func TestPollerExpiryEnsurePollingRace(t *testing.T) {
t.Fatalf("Should have got 401 http response; got %d\n%s", status, resBytes)
}
}

// Regression test for the bugfix for https://github.com/matrix-org/sliding-sync/issues/287#issuecomment-1706522718
// Specifically, we could cache the failure and never tell the poller about new tokens, wedging the client(!). This
// seems to have been due to the following:
// - client hits sync for the first time. We /whoami and remember the token->user mapping in TokensTable.
// - client syncing + poller syncing, everything happy.
// - token expires. OnExpiredToken is sent to EnsurePoller which removes the entry from EnsurePoller and nukes the conns.
// - client hits sync, gets 400 M_UNKNOWN_POS due to nuked conns.
// - client hits a fresh /sync: for whatever reason, the token is NOT 401d there and then by the /whoami lookup failing.
//   Maybe we failed to remove the token, but we don't see any logs to suggest this. Seems to be an OIDC thing.
// - EnsurePoller starts a poller, which immediately 401s as the token is expired.
// - OnExpiredToken is sent first, which removes the entry in EnsurePoller.
// - OnInitialSyncComplete[success=false] is sent after, which MAKES A NEW ENTRY with success=false.
// - proxy sends back 401 M_UNKNOWN_TOKEN.
// - At this point the proxy is wedged. Any token, no matter how valid it is, will not hit EnsurePoller because
//   we cached success=false for that (user,device).
//
// This is traceable in the logs, which show spam of this log line without "Poller: v2 poll loop started" interleaved:
//
// 12:45:33 ERR EnsurePolling failed, returning 401 conn=encryption device=xx user=@xx:xx.xx
//
// To test this failure mode we:
// - Create Alice and sync her poller.
// - Expire her token immediately, just like in TestPollerExpiryEnsurePollingRace.
// - Do another request with a valid new token; this should succeed.
func TestPollerExpiryEnsurePollingRaceDoesntWedge(t *testing.T) {
newToken := "NEW_ALICE_TOKEN"
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
defer v2.close()
v3 := runTestServer(t, v2, pqString)
defer v3.close()

v2.addAccount(t, alice, aliceToken)

// Arrange the following:
// 1. A request arrives from an unknown token.
// 2. The API makes a /whoami lookup for the new token. That returns without error.
// 3. The old token expires.
// 4. The poller tries to call /sync but finds that the token has expired.
// NEW 5. Using a "new token" works.

var gotNewToken atomic.Bool
v2.SetCheckRequest(func(token string, req *http.Request) {
if token == newToken {
t.Log("recv new token")
gotNewToken.Store(true)
return
}
if token != aliceToken {
t.Fatalf("unexpected poll from %s", token)
}
// Expire the token before we process the request.
t.Log("Alice's token expires.")
v2.invalidateTokenImmediately(token)
})

t.Log("Alice makes a sliding sync request with a token that's about to expire.")
_, resBytes, status := v3.doV3Request(t, context.Background(), aliceToken, "", sync3.Request{})
if status != http.StatusUnauthorized {
t.Fatalf("Should have got 401 http response; got %d\n%s", status, resBytes)
}
// make a new token and use it
v2.addAccount(t, alice, newToken)
_, resBytes, status = v3.doV3Request(t, context.Background(), newToken, "", sync3.Request{})
if status != http.StatusOK {
t.Fatalf("Should have got 200 http response; got %d\n%s", status, resBytes)
}
if !gotNewToken.Load() {
t.Fatalf("never saw a v2 poll with the new token")
}
}