Merge branch 'main' into kegan/poll-retry-loop-bad-create-event

kegsay 2023-09-19 13:48:40 +01:00 committed by GitHub
commit 94a4789287
21 changed files with 766 additions and 653 deletions

View File

@ -1,7 +1,6 @@
package main
import (
"context"
"flag"
"fmt"
"log"
@ -42,18 +41,19 @@ const (
EnvSecret = "SYNCV3_SECRET"
// Optional fields
EnvBindAddr = "SYNCV3_BINDADDR"
EnvTLSCert = "SYNCV3_TLS_CERT"
EnvTLSKey = "SYNCV3_TLS_KEY"
EnvPPROF = "SYNCV3_PPROF"
EnvPrometheus = "SYNCV3_PROM"
EnvDebug = "SYNCV3_DEBUG"
EnvOTLP = "SYNCV3_OTLP_URL"
EnvOTLPUsername = "SYNCV3_OTLP_USERNAME"
EnvOTLPPassword = "SYNCV3_OTLP_PASSWORD"
EnvSentryDsn = "SYNCV3_SENTRY_DSN"
EnvLogLevel = "SYNCV3_LOG_LEVEL"
EnvMaxConns = "SYNCV3_MAX_DB_CONN"
EnvBindAddr = "SYNCV3_BINDADDR"
EnvTLSCert = "SYNCV3_TLS_CERT"
EnvTLSKey = "SYNCV3_TLS_KEY"
EnvPPROF = "SYNCV3_PPROF"
EnvPrometheus = "SYNCV3_PROM"
EnvDebug = "SYNCV3_DEBUG"
EnvOTLP = "SYNCV3_OTLP_URL"
EnvOTLPUsername = "SYNCV3_OTLP_USERNAME"
EnvOTLPPassword = "SYNCV3_OTLP_PASSWORD"
EnvSentryDsn = "SYNCV3_SENTRY_DSN"
EnvLogLevel = "SYNCV3_LOG_LEVEL"
EnvMaxConns = "SYNCV3_MAX_DB_CONN"
EnvIdleTimeoutSecs = "SYNCV3_DB_IDLE_TIMEOUT_SECS"
)
var helpMsg = fmt.Sprintf(`
@ -72,8 +72,9 @@ Environment var
%s Default: unset. The Sentry DSN to report events to e.g. https://sliding-sync@sentry.example.com/123 - if unset, does not send Sentry events.
%s Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
%s Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
%s Default: 3600. The maximum amount of time a database connection may be idle, in seconds. 0 means no limit.
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvOTLP, EnvOTLPUsername, EnvOTLPPassword,
EnvSentryDsn, EnvLogLevel, EnvMaxConns)
EnvSentryDsn, EnvLogLevel, EnvMaxConns, EnvIdleTimeoutSecs)
func defaulting(in, dft string) string {
if in == "" {
@ -83,7 +84,6 @@ func defaulting(in, dft string) string {
}
func main() {
ctx := context.Background()
fmt.Printf("Sync v3 [%s] (%s)\n", version, GitCommit)
sync2.ProxyVersion = version
syncv3.Version = fmt.Sprintf("%s (%s)", version, GitCommit)
@ -94,21 +94,22 @@ func main() {
}
args := map[string]string{
EnvServer: os.Getenv(EnvServer),
EnvDB: os.Getenv(EnvDB),
EnvSecret: os.Getenv(EnvSecret),
EnvBindAddr: defaulting(os.Getenv(EnvBindAddr), "0.0.0.0:8008"),
EnvTLSCert: os.Getenv(EnvTLSCert),
EnvTLSKey: os.Getenv(EnvTLSKey),
EnvPPROF: os.Getenv(EnvPPROF),
EnvPrometheus: os.Getenv(EnvPrometheus),
EnvDebug: os.Getenv(EnvDebug),
EnvOTLP: os.Getenv(EnvOTLP),
EnvOTLPUsername: os.Getenv(EnvOTLPUsername),
EnvOTLPPassword: os.Getenv(EnvOTLPPassword),
EnvSentryDsn: os.Getenv(EnvSentryDsn),
EnvLogLevel: os.Getenv(EnvLogLevel),
EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
EnvServer: os.Getenv(EnvServer),
EnvDB: os.Getenv(EnvDB),
EnvSecret: os.Getenv(EnvSecret),
EnvBindAddr: defaulting(os.Getenv(EnvBindAddr), "0.0.0.0:8008"),
EnvTLSCert: os.Getenv(EnvTLSCert),
EnvTLSKey: os.Getenv(EnvTLSKey),
EnvPPROF: os.Getenv(EnvPPROF),
EnvPrometheus: os.Getenv(EnvPrometheus),
EnvDebug: os.Getenv(EnvDebug),
EnvOTLP: os.Getenv(EnvOTLP),
EnvOTLPUsername: os.Getenv(EnvOTLPUsername),
EnvOTLPPassword: os.Getenv(EnvOTLPPassword),
EnvSentryDsn: os.Getenv(EnvSentryDsn),
EnvLogLevel: os.Getenv(EnvLogLevel),
EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
EnvIdleTimeoutSecs: defaulting(os.Getenv(EnvIdleTimeoutSecs), "3600"),
}
requiredEnvVars := []string{EnvServer, EnvDB, EnvSecret, EnvBindAddr}
for _, requiredEnvVar := range requiredEnvVars {
@ -187,19 +188,18 @@ func main() {
}
}
err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], true)
if err != nil {
panic(err)
}
maxConnsInt, err := strconv.Atoi(args[EnvMaxConns])
if err != nil {
panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns])
}
idleTimeSecs, err := strconv.Atoi(args[EnvIdleTimeoutSecs])
if err != nil {
panic("invalid value for " + EnvIdleTimeoutSecs + ": " + args[EnvIdleTimeoutSecs])
}
h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{
AddPrometheusMetrics: args[EnvPrometheus] != "",
DBMaxConns: maxConnsInt,
DBConnMaxIdleTime: time.Hour,
DBConnMaxIdleTime: time.Duration(idleTimeSecs) * time.Second,
MaxTransactionIDDelay: time.Second,
})

View File

@ -1,32 +0,0 @@
package main
import (
"context"
"github.com/matrix-org/sliding-sync/sync2"
"os"
)
const (
// Required fields
EnvServer = "SYNCV3_SERVER"
EnvDB = "SYNCV3_DB"
EnvSecret = "SYNCV3_SECRET"
// Migration test only
EnvMigrationCommit = "SYNCV3_TEST_MIGRATION_COMMIT"
)
func main() {
ctx := context.Background()
args := map[string]string{
EnvServer: os.Getenv(EnvServer),
EnvDB: os.Getenv(EnvDB),
EnvSecret: os.Getenv(EnvSecret),
EnvMigrationCommit: os.Getenv(EnvMigrationCommit),
}
err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], args[EnvMigrationCommit] != "")
if err != nil {
panic(err)
}
}

View File

@ -24,6 +24,7 @@ type V2Listener interface {
OnReceipt(p *V2Receipt)
OnDeviceMessages(p *V2DeviceMessages)
OnExpiredToken(p *V2ExpiredToken)
OnInvalidateRoom(p *V2InvalidateRoom)
}
type V2Initialise struct {
@ -129,6 +130,12 @@ type V2ExpiredToken struct {
func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" }
type V2InvalidateRoom struct {
RoomID string
}
func (*V2InvalidateRoom) Type() string { return "V2InvalidateRoom" }
type V2Sub struct {
listener Listener
receiver V2Listener
@ -173,6 +180,8 @@ func (v *V2Sub) onMessage(p Payload) {
v.receiver.OnDeviceMessages(pl)
case *V2ExpiredToken:
v.receiver.OnExpiredToken(pl)
case *V2InvalidateRoom:
v.receiver.OnInvalidateRoom(pl)
default:
logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type")
}

View File

@ -318,6 +318,19 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
return res, err
}
type AccumulateResult struct {
// NumNew is the number of events in timeline NIDs that were not previously known
	// to the proxy.
NumNew int
// TimelineNIDs is the list of event nids seen in a sync v2 timeline. Some of these
// may already be known to the proxy.
TimelineNIDs []int64
// RequiresReload is set to true when we have accumulated a non-incremental state
// change (typically a redaction) that requires consumers to reload the room state
// from the latest snapshot.
RequiresReload bool
}
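
For orientation (not part of this commit): a minimal sketch of how a caller is expected to act on AccumulateResult, mirroring the Handler.Accumulate change further down in this diff. The package name and function are illustrative; Storage, pubsub.Notifier, V2InvalidateRoom and V2Accumulate are the types shown elsewhere in this commit.

package handler2 // illustrative placement

import (
	"encoding/json"

	"github.com/matrix-org/sliding-sync/pubsub"
	"github.com/matrix-org/sliding-sync/state"
)

// accumulateAndNotify sketches the intended consumer pattern: persist the
// timeline, publish an invalidation if current state was redacted, then
// publish the new event NIDs as usual.
func accumulateAndNotify(store *state.Storage, notifier pubsub.Notifier, userID, roomID, prevBatch string, timeline []json.RawMessage) error {
	accResult, err := store.Accumulate(userID, roomID, prevBatch, timeline)
	if err != nil {
		return err
	}
	if accResult.RequiresReload {
		// Consumers should reload room state from the latest snapshot before
		// processing new timeline events.
		notifier.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{RoomID: roomID})
	}
	if accResult.NumNew != 0 {
		notifier.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
			RoomID:    roomID,
			PrevBatch: prevBatch,
			EventNIDs: accResult.TimelineNIDs,
		})
	}
	return nil
}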
// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
// received from the server. Returns an AccumulateResult describing the new timeline events and their NIDs,
// or an error.
@ -329,7 +342,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (AccumulateResult, error) {
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
@ -338,10 +351,10 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
if err != nil {
err = fmt.Errorf("filterTimelineEvents: %w", err)
return
return AccumulateResult{}, err
}
if len(dedupedEvents) == 0 {
return 0, nil, err // nothing to do
return AccumulateResult{}, nil // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
@ -355,7 +368,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
// The only situation where no prior snapshot should exist is if this timeline is
@ -392,19 +405,22 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
})
// the HS gave us bad data so there's no point retrying
// by not returning an error, we are telling the poller it is fine to not retry this request.
return 0, nil, nil
return AccumulateResult{}, nil
}
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
if len(eventIDToNID) == 0 {
// nothing to do, we already know about these events
return 0, nil, nil
return AccumulateResult{}, nil
}
result := AccumulateResult{
NumNew: len(eventIDToNID),
}
numNew = len(eventIDToNID)
var latestNID int64
newEvents := make([]Event, 0, len(eventIDToNID))
@ -436,7 +452,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
}
}
newEvents = append(newEvents, ev)
timelineNIDs = append(timelineNIDs, ev.NID)
result.TimelineNIDs = append(result.TimelineNIDs, ev.NID)
}
}
@ -446,7 +462,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if len(redactTheseEventIDs) > 0 {
createEventJSON, err := a.eventsTable.SelectCreateEvent(txn, roomID)
if err != nil {
return 0, nil, fmt.Errorf("SelectCreateEvent: %w", err)
return AccumulateResult{}, fmt.Errorf("SelectCreateEvent: %w", err)
}
roomVersion = gjson.GetBytes(createEventJSON, "content.room_version").Str
if roomVersion == "" {
@ -457,7 +473,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
)
}
if err = a.eventsTable.Redact(txn, roomVersion, redactTheseEventIDs); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
@ -473,12 +489,12 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if snapID != 0 {
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
if err != nil {
return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
}
}
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
if err != nil {
return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
return AccumulateResult{}, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
replacesNID = replacedNID
memNIDs, otherNIDs := newStripped.NIDs()
@ -488,7 +504,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
OtherEvents: otherNIDs,
}
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
return AccumulateResult{}, fmt.Errorf("failed to insert new snapshot: %w", err)
}
if a.snapshotMemberCountVec != nil {
logger.Trace().Str("room_id", roomID).Int("members", len(memNIDs)).Msg("Inserted new snapshot")
@ -497,20 +513,40 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
snapID = newSnapshot.SnapshotID
}
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
if len(redactTheseEventIDs) > 0 {
// We need to emit a cache invalidation if we have redacted some state in the
// current snapshot ID. Note that we run this _after_ persisting any new snapshots.
redactedEventIDs := make([]string, 0, len(redactTheseEventIDs))
for eventID := range redactTheseEventIDs {
redactedEventIDs = append(redactedEventIDs, eventID)
}
var currentStateRedactions int
err = txn.Get(&currentStateRedactions, `
SELECT COUNT(*)
FROM syncv3_events
JOIN syncv3_snapshots ON event_nid = ANY (ARRAY_CAT(events, membership_events))
WHERE snapshot_id = $1 AND event_id = ANY($2)
`, snapID, pq.StringArray(redactedEventIDs))
if err != nil {
return AccumulateResult{}, err
}
result.RequiresReload = currentStateRedactions > 0
}
if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
return AccumulateResult{}, fmt.Errorf("HandleSpaceUpdates: %s", err)
}
// the last fetched snapshot ID is the current one
info := a.roomInfoDelta(roomID, newEvents)
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
}
return numNew, timelineNIDs, nil
return result, nil
}
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:

View File

@ -3,6 +3,7 @@ package state
import (
"context"
"encoding/json"
"fmt"
"github.com/matrix-org/sliding-sync/testutils"
"reflect"
"sort"
@ -135,25 +136,24 @@ func TestAccumulatorAccumulate(t *testing.T) {
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
var numNew int
var latestNIDs []int64
var result AccumulateResult
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
result, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
if numNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", numNew, len(newEvents))
if result.NumNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", result.NumNew, len(newEvents))
}
	// latest nid should match
wantLatestNID, err := accumulator.eventsTable.SelectHighestNID()
if err != nil {
t.Fatalf("failed to check latest NID from Accumulate: %s", err)
}
if latestNIDs[len(latestNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", latestNIDs[len(latestNIDs)-1], wantLatestNID)
if result.TimelineNIDs[len(result.TimelineNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", result.TimelineNIDs[len(result.TimelineNIDs)-1], wantLatestNID)
}
// Begin assertions
@ -212,7 +212,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -220,6 +220,80 @@ func TestAccumulatorAccumulate(t *testing.T) {
}
}
func TestAccumulatorPromptsCacheInvalidation(t *testing.T) {
db, close := connectToDB(t)
defer close()
accumulator := NewAccumulator(db)
t.Log("Initialise the room state, including a room name.")
roomID := fmt.Sprintf("!%s:localhost", t.Name())
stateBlock := []json.RawMessage{
[]byte(`{"event_id":"$a", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost", "room_version": "10"}}`),
[]byte(`{"event_id":"$b", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"$c", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"$d", "type":"m.room.name", "state_key":"", "content":{"name":"Barry Cryer Appreciation Society"}}`),
}
_, err := accumulator.Initialise(roomID, stateBlock)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
t.Log("Accumulate a second room name, a message, then a third room name.")
timeline := []json.RawMessage{
[]byte(`{"event_id":"$e", "type":"m.room.name", "state_key":"", "content":{"name":"Jeremy Hardy Appreciation Society"}}`),
[]byte(`{"event_id":"$f", "type":"m.room.message", "content": {"body":"Hello, world!", "msgtype":"m.text"}}`),
[]byte(`{"event_id":"$g", "type":"m.room.name", "state_key":"", "content":{"name":"Humphrey Lyttelton Appreciation Society"}}`),
}
var accResult AccumulateResult
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 3 new events and no reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 3)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 3)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
t.Log("Redact the old state event and the message.")
timeline = []json.RawMessage{
[]byte(`{"event_id":"$h", "type":"m.room.redaction", "content":{"redacts":"$e"}}`),
[]byte(`{"event_id":"$i", "type":"m.room.redaction", "content":{"redacts":"$f"}}`),
}
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch2", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 2 new events and no reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 2)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 2)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
t.Log("Redact the latest state event.")
timeline = []json.RawMessage{
[]byte(`{"event_id":"$j", "type":"m.room.redaction", "content":{"redacts":"$g"}}`),
}
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch3", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 1 new event and a reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 1)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 1)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, true)
}
func TestAccumulatorMembershipLogs(t *testing.T) {
roomID := "!TestAccumulatorMembershipLogs:localhost"
db, close := connectToDB(t)
@ -248,7 +322,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -384,7 +458,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
_, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -584,8 +658,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
result, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += result.NumNew
return err
})
if err != nil {

View File

@ -306,6 +306,77 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
return nil
}
// ResetMetadataState updates the given metadata in-place to reflect the current state
// of the room. This is only safe to call from the subscriber goroutine; it is not safe
// to call from the connection goroutines.
// TODO: could have this create a new RoomMetadata and get the caller to assign it.
func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
var events []Event
err := s.DB.Select(&events, `
WITH snapshot(events, membership_events) AS (
SELECT events, membership_events
FROM syncv3_snapshots
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
WHERE syncv3_rooms.room_id = $1
)
SELECT event_id, event_type, state_key, event, membership
FROM syncv3_events JOIN snapshot ON (
event_nid = ANY (ARRAY_CAT(events, membership_events))
)
WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias') AND state_key = '')
OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite'))
ORDER BY event_nid ASC
;`, metadata.RoomID)
if err != nil {
return fmt.Errorf("ResetMetadataState[%s]: %w", metadata.RoomID, err)
}
heroMemberships := circularSlice[*Event]{max: 6}
metadata.JoinCount = 0
metadata.InviteCount = 0
metadata.ChildSpaceRooms = make(map[string]struct{})
for i, ev := range events {
switch ev.Type {
case "m.room.name":
metadata.NameEvent = gjson.GetBytes(ev.JSON, "content.name").Str
case "m.room.avatar":
metadata.AvatarEvent = gjson.GetBytes(ev.JSON, "content.avatar_url").Str
case "m.room.canonical_alias":
metadata.CanonicalAlias = gjson.GetBytes(ev.JSON, "content.alias").Str
case "m.room.member":
heroMemberships.append(&events[i])
switch ev.Membership {
case "join":
fallthrough
case "_join":
metadata.JoinCount++
case "invite":
fallthrough
case "_invite":
metadata.InviteCount++
}
case "m.space.child":
metadata.ChildSpaceRooms[ev.StateKey] = struct{}{}
}
}
metadata.Heroes = make([]internal.Hero, 0, len(heroMemberships.vals))
for _, ev := range heroMemberships.vals {
parsed := gjson.ParseBytes(ev.JSON)
hero := internal.Hero{
ID: ev.StateKey,
Name: parsed.Get("content.displayname").Str,
Avatar: parsed.Get("content.avatar_url").Str,
}
metadata.Heroes = append(metadata.Heroes, hero)
}
// For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID.
// These shouldn't be changing during a room's lifetime in normal operation.
return nil
}
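
For context (not part of this file's diff): ResetMetadataState is consumed by the new invalidation path added later in this commit, where the global cache resets its metadata entry in place when a V2InvalidateRoom payload is dispatched. A condensed sketch of that consumer, based on the caches change below:

// Condensed from GlobalCache.OnInvalidateRoom later in this commit.
func (c *GlobalCache) OnInvalidateRoom(ctx context.Context, roomID string) {
	c.roomIDToMetadataMu.Lock()
	defer c.roomIDToMetadataMu.Unlock()
	metadata, ok := c.roomIDToMetadata[roomID]
	if !ok {
		return // room not tracked by the global cache
	}
	if err := c.store.ResetMetadataState(metadata); err != nil {
		logger.Warn().Err(err).Msg("OnInvalidateRoom: failed to reset metadata")
	}
}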
// Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of
// room ID to events in that room.
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
@ -336,15 +407,15 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (result AccumulateResult, err error) {
if len(timeline) == 0 {
return 0, nil, nil
return AccumulateResult{}, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
result, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return
return result, err
}
func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
@ -818,7 +889,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
defer rows.Close()
joinedMembers = make(map[string][]string)
inviteCounts := make(map[string]int)
heroNIDs := make(map[string]*circularSlice)
heroNIDs := make(map[string]*circularSlice[int64])
var stateKey string
var membership string
var roomID string
@ -829,7 +900,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
}
heroes := heroNIDs[roomID]
if heroes == nil {
heroes = &circularSlice{max: 6}
heroes = &circularSlice[int64]{max: 6}
heroNIDs[roomID] = heroes
}
switch membership {
@ -982,13 +1053,13 @@ func (s *Storage) Teardown() {
// circularSlice is a slice which can be appended to and which wraps around at `max`.
// Mostly useful for lazily calculating heroes. The values returned aren't sorted.
type circularSlice struct {
type circularSlice[T any] struct {
i int
vals []int64
vals []T
max int
}
func (s *circularSlice) append(val int64) {
func (s *circularSlice[T]) append(val T) {
if len(s.vals) < s.max {
// populate up to max
s.vals = append(s.vals, val)
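
A small usage fragment (not part of the diff) showing the two instantiations of the now-generic type as they appear in this commit; the wraparound branch of append falls outside the hunk shown above:

// In AllJoinedMembers, heroes are tracked by event NID:
heroNIDs := &circularSlice[int64]{max: 6}
heroNIDs.append(42)

// In ResetMetadataState, heroes are tracked by pointer to the member event:
heroMemberships := circularSlice[*Event]{max: 6}
heroMemberships.append(&events[0]) // assumes `events []Event` is in scope

// Once len(vals) reaches max, further appends wrap around and replace existing
// entries, per the comment on the type; the values are not kept sorted.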

View File

@ -31,11 +31,11 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
accResult, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
testCases := []struct {
name string
@ -158,14 +158,13 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
},
}
var latestPos int64
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
accResult, err := store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
latestPos = latestNIDs[len(latestNIDs)-1]
latestPos = accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
}
aliceJoinTimingsByRoomID, err := store.JoinedRoomsAfterPosition(alice, latestPos)
if err != nil {
@ -351,11 +350,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err := store.LatestEventNID()
if err != nil {
@ -454,11 +453,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err = store.LatestEventNID()
if err != nil {
@ -534,7 +533,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
_, err := store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -776,7 +775,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
_, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}
@ -855,7 +854,7 @@ func TestCircularSlice(t *testing.T) {
},
}
for _, tc := range testCases {
cs := &circularSlice{
cs := &circularSlice[int64]{
max: tc.max,
}
for _, val := range tc.appends {

View File

@ -294,19 +294,26 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
accResult, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return err
}
// Consumers should reload state before processing new timeline events.
if accResult.RequiresReload {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{
RoomID: roomID,
})
}
// We've updated the database. Now tell any pubsub listeners what we learned.
if numNew != 0 {
if accResult.NumNew != 0 {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
RoomID: roomID,
PrevBatch: prevBatch,
EventNIDs: latestNIDs,
EventNIDs: accResult.TimelineNIDs,
})
}

View File

@ -1,455 +0,0 @@
package sync2
import (
"context"
"crypto/sha256"
"fmt"
"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"net/http"
"time"
)
// MigrateDeviceIDs performs a one-off DB migration from the old device ids (hash of
// access token) to the new device ids (actual device ids from the homeserver). This is
// not backwards compatible. If the migration has already taken place, this function is
// a no-op.
//
// This code will be removed in a future version of the proxy.
func MigrateDeviceIDs(ctx context.Context, destHomeserver, postgresURI, secret string, commit bool) error {
whoamiClient := &HTTPClient{
Client: &http.Client{
Timeout: 5 * time.Minute,
Transport: otelhttp.NewTransport(http.DefaultTransport),
},
DestinationServer: destHomeserver,
}
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
sentry.CaptureException(err)
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
// Ensure the new table exists.
NewTokensTable(db, secret)
return sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
migrated, err := isMigrated(txn)
if err != nil {
return
}
if migrated {
logger.Debug().Msg("MigrateDeviceIDs: migration has already taken place")
return nil
}
logger.Info().Msgf("MigrateDeviceIDs: starting (commit=%t)", commit)
start := time.Now()
defer func() {
elapsed := time.Since(start)
logger.Debug().Msgf("MigrateDeviceIDs: took %s", elapsed)
}()
err = alterTables(txn)
if err != nil {
return
}
err = runMigration(ctx, txn, secret, whoamiClient)
if err != nil {
return
}
err = finish(txn)
if err != nil {
return
}
s := NewStore(postgresURI, secret)
tokens, err := s.TokensTable.TokenForEachDevice(txn)
if err != nil {
return
}
logger.Debug().Msgf("Got %d tokens after migration", len(tokens))
if !commit {
err = fmt.Errorf("MigrateDeviceIDs: migration succeeded without errors, but commit is false - rolling back anyway")
} else {
logger.Info().Msg("MigrateDeviceIDs: migration succeeded - committing")
}
return
})
}
func isMigrated(txn *sqlx.Tx) (bool, error) {
// Keep this dead simple for now. This is a one-off migration, before version 1.0.
// In the future we'll rip this out and tell people that it's their job to ensure
// this migration has run before they upgrade beyond the rip-out point.
// We're going to detect if the migration has run by testing for the existence of
// a column added by the migration. First, check that the table exists.
var tableExists bool
err := txn.QueryRow(`
SELECT EXISTS(
SELECT 1 FROM information_schema.columns
WHERE table_name = 'syncv3_txns'
);
`).Scan(&tableExists)
if err != nil {
return false, fmt.Errorf("isMigrated: %s", err)
}
if !tableExists {
// The proxy has never been run before and its tables have never been created.
// We do not need to run the migration.
logger.Debug().Msg("isMigrated: no syncv3_txns table, no migration needed")
return true, nil
}
var migrated bool
err = txn.QueryRow(`
SELECT EXISTS(
SELECT 1 FROM information_schema.columns
WHERE table_name = 'syncv3_txns' AND column_name = 'device_id'
);
`).Scan(&migrated)
if err != nil {
return false, fmt.Errorf("isMigrated: %s", err)
}
return migrated, nil
}
func alterTables(txn *sqlx.Tx) (err error) {
_, err = txn.Exec(`
ALTER TABLE syncv3_sync2_devices
DROP CONSTRAINT syncv3_sync2_devices_pkey;
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_to_device_messages
ADD COLUMN user_id TEXT;
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_to_device_ack_pos
DROP CONSTRAINT syncv3_to_device_ack_pos_pkey,
ADD COLUMN user_id TEXT;
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_txns
DROP CONSTRAINT syncv3_txns_user_id_event_id_key,
ADD COLUMN device_id TEXT;
`)
return
}
type oldDevice struct {
AccessToken string // not a DB row, but it's convenient to write to here
AccessTokenHash string `db:"device_id"`
UserID string `db:"user_id"`
AccessTokenEncrypted string `db:"v2_token_encrypted"`
Since string `db:"since"`
}
func runMigration(ctx context.Context, txn *sqlx.Tx, secret string, whoamiClient Client) error {
logger.Info().Msg("Loading old-style devices into memory")
var devices []oldDevice
err := txn.Select(
&devices,
`SELECT device_id, user_id, v2_token_encrypted, since FROM syncv3_sync2_devices;`,
)
if err != nil {
return fmt.Errorf("runMigration: failed to select devices: %s", err)
}
logger.Info().Msgf("Got %d devices to migrate", len(devices))
hasher := sha256.New()
hasher.Write([]byte(secret))
key := hasher.Sum(nil)
// This migration runs sequentially, one device at a time. We have found this to be
// quick enough in practice.
numErrors := 0
for i, device := range devices {
device.AccessToken, err = decrypt(device.AccessTokenEncrypted, key)
if err != nil {
return fmt.Errorf("runMigration: failed to decrypt device: %s", err)
}
userID := device.UserID
if userID == "" {
userID = "<unknown user>"
}
logger.Info().Msgf(
"%4d/%4d migrating device %s %s",
i+1, len(devices), userID, device.AccessTokenHash,
)
err = migrateDevice(ctx, txn, whoamiClient, &device)
if err != nil {
logger.Err(err).Msgf("runMigration: failed to migrate device %s", device.AccessTokenHash)
numErrors++
}
}
if numErrors > 0 {
return fmt.Errorf("runMigration: there were %d failures", numErrors)
}
return nil
}
func migrateDevice(ctx context.Context, txn *sqlx.Tx, whoamiClient Client, device *oldDevice) (err error) {
gotUserID, gotDeviceID, err := whoamiClient.WhoAmI(ctx, device.AccessToken)
if err == HTTP401 {
userID := device.UserID
if userID == "" {
userID = "<unknown user>"
}
logger.Warn().Msgf(
"migrateDevice: access token for %s %s has expired. Dropping device and metadata.",
userID, device.AccessTokenHash,
)
return cleanupDevice(txn, device)
}
if err != nil {
return err
}
// Sanity check the user ID from the HS matches our records
if gotUserID != device.UserID {
return fmt.Errorf(
"/whoami response was for the wrong user. Queried for %s, but got response for %s",
device.UserID, gotUserID,
)
}
err = exec(
txn,
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
VALUES ($1, $2, $3, $4, $5)`,
expectOneRowAffected,
device.AccessTokenHash, device.AccessTokenEncrypted, gotUserID, gotDeviceID, time.Now(),
)
if err != nil {
return
}
// For these first four tables:
// - use the actual device ID instead of the access token hash, and
// - ensure a user ID is set.
err = exec(
txn,
`UPDATE syncv3_sync2_devices SET user_id = $1, device_id = $2 WHERE device_id = $3`,
expectOneRowAffected,
gotUserID, gotDeviceID, device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`UPDATE syncv3_to_device_messages SET user_id = $1, device_id = $2 WHERE device_id = $3`,
expectAnyNumberOfRowsAffected,
gotUserID, gotDeviceID, device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`UPDATE syncv3_to_device_ack_pos SET user_id = $1, device_id = $2 WHERE device_id = $3`,
expectAtMostOneRowAffected,
gotUserID, gotDeviceID, device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`UPDATE syncv3_device_data SET user_id = $1, device_id = $2 WHERE device_id = $3`,
expectAtMostOneRowAffected,
gotUserID, gotDeviceID, device.AccessTokenHash,
)
if err != nil {
return
}
// Confusingly, the txns table used to store access token hashes under the user_id
// column. Write the actual user ID to the user_id column, and the actual device ID
// to the device_id column.
err = exec(
txn,
`UPDATE syncv3_txns SET user_id = $1, device_id = $2 WHERE user_id = $3`,
expectAnyNumberOfRowsAffected,
gotUserID, gotDeviceID, device.AccessTokenHash,
)
if err != nil {
return
}
return
}
func cleanupDevice(txn *sqlx.Tx, device *oldDevice) (err error) {
// The homeserver does not recognise this access token. Because we have no
// record of the device_id from the homeserver, it will never be possible to
// spot that a future refreshed access token belongs to the device we're
// handling here. Therefore this device is not useful to the proxy.
//
// If we leave this device's rows in situ, we may end up with rows in
// syncv3_to_device_messages, syncv3_to_device_ack_pos and syncv3_txns which have
// null values for the new fields, which will mean we fail to impose the uniqueness
// constraints at the end of the migration. Instead, drop those rows.
err = exec(
txn,
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`,
expectOneRowAffected,
device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`,
expectAnyNumberOfRowsAffected,
device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`DELETE FROM syncv3_to_device_ack_pos WHERE device_id = $1`,
expectAtMostOneRowAffected,
device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`DELETE FROM syncv3_device_data WHERE device_id = $1`,
expectAtMostOneRowAffected,
device.AccessTokenHash,
)
if err != nil {
return
}
err = exec(
txn,
`DELETE FROM syncv3_txns WHERE user_id = $1`,
expectAnyNumberOfRowsAffected,
device.AccessTokenHash,
)
if err != nil {
return
}
return
}
func exec(txn *sqlx.Tx, query string, checkRowsAffected func(ra int64) bool, args ...any) error {
res, err := txn.Exec(query, args...)
if err != nil {
return err
}
ra, err := res.RowsAffected()
if err != nil {
return err
}
if !checkRowsAffected(ra) {
return fmt.Errorf("query \"%s\" unexpectedly affected %d rows", query, ra)
}
return nil
}
func expectOneRowAffected(ra int64) bool { return ra == 1 }
func expectAnyNumberOfRowsAffected(ra int64) bool { return true }
func logRowsAffected(msg string) func(ra int64) bool {
return func(ra int64) bool {
logger.Info().Msgf(msg, ra)
return true
}
}
func expectAtMostOneRowAffected(ra int64) bool { return ra == 0 || ra == 1 }
func finish(txn *sqlx.Tx) (err error) {
// OnExpiredToken used to delete from the devices and to-device tables, but not from
// the to-device ack pos or the txn tables. Fix this up by deleting orphaned rows.
err = exec(
txn,
`
DELETE FROM syncv3_to_device_ack_pos
WHERE device_id NOT IN (SELECT device_id FROM syncv3_sync2_devices)
;`,
logRowsAffected("Deleted %d stale rows from syncv3_to_device_ack_pos"),
)
if err != nil {
return
}
err = exec(
txn,
`
DELETE FROM syncv3_txns WHERE device_id IS NULL;
`,
logRowsAffected("Deleted %d stale rows from syncv3_txns"),
)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_sync2_devices
DROP COLUMN v2_token_encrypted,
ADD PRIMARY KEY (user_id, device_id);
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_to_device_messages
ALTER COLUMN user_id SET NOT NULL;
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_to_device_ack_pos
ALTER COLUMN user_id SET NOT NULL,
ADD PRIMARY KEY (user_id, device_id);
`)
if err != nil {
return
}
_, err = txn.Exec(`
ALTER TABLE syncv3_txns
ALTER COLUMN device_id SET NOT NULL,
ADD UNIQUE(user_id, device_id, event_id);
`)
if err != nil {
return
}
return
}

View File

@ -1,50 +0,0 @@
package sync2
import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/state"
"testing"
)
func TestConsideredMigratedOnFirstStartup(t *testing.T) {
db, close := connectToDB(t)
defer close()
var migrated bool
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
// Attempt to make this test independent of others by dropping the table whose
// columns we probe.
_, err = txn.Exec("DROP TABLE IF EXISTS syncv3_txns;")
if err != nil {
return err
}
migrated, err = isMigrated(txn)
return
})
if err != nil {
t.Fatalf("Error calling isMigrated: %s", err)
}
if !migrated {
t.Fatalf("Expected a non-existent DB to be considered migrated, but it was not")
}
}
func TestNewSchemaIsConsideredMigrated(t *testing.T) {
NewStore(postgresConnectionString, "my_secret")
state.NewStorage(postgresConnectionString)
db, close := connectToDB(t)
defer close()
var migrated bool
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
migrated, err = isMigrated(txn)
return
})
if err != nil {
t.Fatalf("Error calling isMigrated: %s", err)
}
if !migrated {
t.Fatalf("Expected a new DB to be considered migrated, but it was not")
}
}

View File

@ -385,3 +385,20 @@ func (c *GlobalCache) OnNewEvent(
}
c.roomIDToMetadata[ed.RoomID] = metadata
}
func (c *GlobalCache) OnInvalidateRoom(ctx context.Context, roomID string) {
c.roomIDToMetadataMu.Lock()
defer c.roomIDToMetadataMu.Unlock()
metadata, ok := c.roomIDToMetadata[roomID]
if !ok {
logger.Warn().Str("room_id", roomID).Msg("OnInvalidateRoom: room not in global cache")
return
}
err := c.store.ResetMetadataState(metadata)
if err != nil {
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
logger.Warn().Err(err).Msg("OnInvalidateRoom: failed to reset metadata")
}
}

View File

@ -38,16 +38,16 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
_, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
accResult, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
globalCache := caches.NewGlobalCache(store)
testCases := []struct {
name string

View File

@ -760,3 +760,10 @@ func (u *UserCache) ShouldIgnore(userID string) bool {
_, ignored := u.ignoredUsers[userID]
return ignored
}
func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) {
	// Nothing for now. In UserRoomData the fields dependent on room state are
// IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces.
// Not clear to me if we need to reload these or if we will inherit any changes from
// the global cache.
}

View File

@ -24,6 +24,7 @@ type Receiver interface {
OnNewEvent(ctx context.Context, event *caches.EventData)
OnReceipt(ctx context.Context, receipt internal.Receipt)
OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage)
OnInvalidateRoom(ctx context.Context, roomID string)
// OnRegistered is called after a successful call to Dispatcher.Register
OnRegistered(ctx context.Context) error
}
@ -285,3 +286,23 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData,
}
}
}
func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) {
// First dispatch to the global cache.
receiver, ok := d.userToReceiver[DispatcherAllUsers]
if !ok {
logger.Error().Msgf("No receiver for global cache")
}
receiver.OnInvalidateRoom(ctx, roomID)
// Then dispatch to any users who are joined to that room.
joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil)
d.userToReceiverMu.RLock()
defer d.userToReceiverMu.RUnlock()
for _, userID := range joinedUsers {
receiver = d.userToReceiver[userID]
if receiver != nil {
receiver.OnInvalidateRoom(ctx, roomID)
}
}
}

View File

@ -16,8 +16,8 @@ import (
type pendingInfo struct {
// done is set to true when the EnsurePolling request received a response.
done bool
// success is true when done is true and EnsurePolling succeeded; otherwise false.
success bool
// expired is true when the token is expired. Any 'done'ness should be ignored for expired tokens.
expired bool
// ch is a dummy channel which never receives any data. A call to
// EnsurePoller.OnInitialSyncComplete will close the channel (unblocking any
// EnsurePoller.EnsurePolling calls which are waiting on it) and then set the ch
@ -58,19 +58,31 @@ func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePol
}
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
// the caller's responsibility to call OnInitialSyncComplete when new events arrive. Returns the Success field from the
// V2InitialSyncComplete response, which is true iff there is an active poller.
// the caller's responsibility to call OnInitialSyncComplete when new events arrive. Returns whether
// or not the token is expired.
func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, tokenHash string) bool {
ctx, region := internal.StartSpan(ctx, "EnsurePolling")
defer region.End()
p.mu.Lock()
// do we need to wait?
if p.pendingPolls[pid].done {
// do we need to wait? Expired devices ALWAYS need a fresh poll
expired := p.pendingPolls[pid].expired
if !expired && p.pendingPolls[pid].done {
internal.Logf(ctx, "EnsurePolling", "user %s device %s already done", pid.UserID, pid.DeviceID)
success := p.pendingPolls[pid].success
p.mu.Unlock()
return success
return expired // always false
}
// either we have expired or we haven't expired and are still waiting on the initial sync.
// If we have expired, nuke the pid from the map now so we will use the same code path as a fresh device sync
if expired {
// close any existing channel
if p.pendingPolls[pid].ch != nil {
close(p.pendingPolls[pid].ch)
}
delete(p.pendingPolls, pid)
// at this point, ch == nil so we will do an initial sync
}
// have we called EnsurePolling for this user/device before?
ch := p.pendingPolls[pid].ch
if ch != nil {
@ -84,9 +96,9 @@ func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, to
<-ch
r2.End()
p.mu.Lock()
success := p.pendingPolls[pid].success
expired := p.pendingPolls[pid].expired
p.mu.Unlock()
return success
return expired
}
// Make a channel to wait until we have done an initial sync
ch = make(chan struct{})
@ -110,9 +122,9 @@ func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, to
r2.End()
p.mu.Lock()
success := p.pendingPolls[pid].success
expired = p.pendingPolls[pid].expired
p.mu.Unlock()
return success
return expired
}
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
@ -129,7 +141,7 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
log.Trace().Msg("OnInitialSyncComplete: we weren't waiting for this")
p.pendingPolls[pid] = pendingInfo{
done: true,
success: payload.Success,
expired: !payload.Success,
}
return
}
@ -143,7 +155,11 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
ch := pending.ch
pending.done = true
pending.ch = nil
pending.success = payload.Success
// If for whatever reason we get OnExpiredToken prior to OnInitialSyncComplete, don't forget that
	// we expired the token, i.e. expiry latches true.
if !pending.expired {
pending.expired = !payload.Success
}
p.pendingPolls[pid] = pending
p.calculateNumOutstanding() // decrement total
log.Trace().Msg("OnInitialSyncComplete: closing channel")
@ -159,10 +175,12 @@ func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) {
// We weren't tracking the state of this poller, so we have nothing to clean up.
return
}
if pending.ch != nil {
close(pending.ch)
}
delete(p.pendingPolls, pid)
pending.expired = true
p.pendingPolls[pid] = pending
// We used to delete the entry from the map at this point to force the next
// EnsurePolling call to do a fresh EnsurePolling request, but now we do that
// by signalling via the expired flag.
}
func (p *EnsurePoller) Teardown() {

View File

@ -0,0 +1,202 @@
package handler
import (
"context"
"reflect"
"sync/atomic"
"testing"
"time"
"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/sync2"
)
type mockNotifier struct {
ch chan pubsub.Payload
}
func (n *mockNotifier) Notify(chanName string, p pubsub.Payload) error {
n.ch <- p
return nil
}
func (n *mockNotifier) Close() error {
return nil
}
func (n *mockNotifier) MustHaveNoSentPayloads(t *testing.T) {
t.Helper()
if len(n.ch) == 0 {
return
}
t.Fatalf("MustHaveNoSentPayloads: %d in buffer", len(n.ch))
}
func (n *mockNotifier) WaitForNextPayload(t *testing.T, timeout time.Duration) pubsub.Payload {
t.Helper()
select {
case p := <-n.ch:
return p
case <-time.After(timeout):
t.Fatalf("WaitForNextPayload: timed out after %v", timeout)
}
panic("unreachable")
}
// check that the request/response works and unblocks things correctly
func TestEnsurePollerBasicWorks(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
tokHash := "tokenHash"
ep := NewEnsurePoller(n, false)
var expired atomic.Bool
finished := make(chan bool) // dummy
go func() {
exp := ep.EnsurePolling(ctx, pid, tokHash)
expired.Store(exp)
close(finished)
}()
p := n.WaitForNextPayload(t, time.Second)
// check it's a V3EnsurePolling payload
pp, ok := p.(*pubsub.V3EnsurePolling)
if !ok {
t.Fatalf("unexpected payload: %+v", p)
}
assertVal(t, pp.UserID, pid.UserID)
assertVal(t, pp.DeviceID, pid.DeviceID)
assertVal(t, pp.AccessTokenHash, tokHash)
// make sure we're still waiting
select {
case <-finished:
t.Fatalf("EnsurePolling unblocked before response was sent")
default:
}
// send back the response
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
if expired.Load() {
t.Fatalf("response said token was expired when it wasn't")
}
}
func TestEnsurePollerCachesResponses(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
ep := NewEnsurePoller(n, false)
finished := make(chan bool) // dummy
go func() {
_ = ep.EnsurePolling(ctx, pid, "tokenHash")
close(finished)
}()
n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// send back the response
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
// hitting EnsurePolling again should immediately return
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
if exp {
t.Fatalf("EnsurePolling said token was expired when it wasn't")
}
n.MustHaveNoSentPayloads(t)
}
// Regression test for when we cached failures, causing no poller to start for the device
func TestEnsurePollerDoesntCacheFailures(t *testing.T) {
n := &mockNotifier{ch: make(chan pubsub.Payload, 100)}
ctx := context.Background()
pid := sync2.PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE"}
ep := NewEnsurePoller(n, false)
finished := make(chan bool) // dummy
var expired atomic.Bool
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expired.Store(exp)
close(finished)
}()
n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// send back the response, which failed
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: false,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
if !expired.Load() {
t.Fatalf("EnsurePolling returned not expired, wanted expired due to Success=false")
}
	// hitting EnsurePolling again should do a new request (i.e. not have cached the failure)
var expiredAgain atomic.Bool
finished = make(chan bool) // dummy
go func() {
exp := ep.EnsurePolling(ctx, pid, "tokenHash")
expiredAgain.Store(exp)
close(finished)
}()
p := n.WaitForNextPayload(t, time.Second) // wait for V3EnsurePolling
// check it's a V3EnsurePolling payload
pp, ok := p.(*pubsub.V3EnsurePolling)
if !ok {
t.Fatalf("unexpected payload: %+v", p)
}
assertVal(t, pp.UserID, pid.UserID)
assertVal(t, pp.DeviceID, pid.DeviceID)
assertVal(t, pp.AccessTokenHash, "tokenHash")
// send back the response, which succeeded this time
ep.OnInitialSyncComplete(&pubsub.V2InitialSyncComplete{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
Success: true,
})
select {
case <-finished:
case <-time.After(time.Second):
t.Fatalf("EnsurePolling didn't unblock after response was sent")
}
}
func assertVal(t *testing.T, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Fatalf("assertVal: got %v want %v", got, want)
}
}

View File

@ -396,8 +396,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
success := h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
if !success {
expiredToken := h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
if expiredToken {
log.Error().Msg("EnsurePolling failed, returning 401")
// Assumption: the only way that EnsurePolling fails is if the access token is invalid.
return req, nil, &internal.HandlerError{
@ -803,6 +803,13 @@ func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.ConnMap.CloseConnsForDevice(p.UserID, p.DeviceID)
}
func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
defer task.End()
h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID)
}
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
queryPos := u.Query().Get(param)
if queryPos != "" {

View File

@ -132,6 +132,7 @@ type SyncReq struct {
type CSAPI struct {
UserID string
Localpart string
Domain string
AccessToken string
DeviceID string
AvatarURL string
@ -160,6 +161,16 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
return GetJSONFieldStr(t, body, "content_uri")
}
// Use an empty string to remove a custom displayname.
func (c *CSAPI) SetDisplayname(t *testing.T, name string) {
t.Helper()
reqBody := map[string]any{}
if name != "" {
reqBody["displayname"] = name
}
c.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "profile", c.UserID, "displayname"}, WithJSONBody(t, reqBody))
}
// Use an empty string to remove your avatar.
func (c *CSAPI) SetAvatar(t *testing.T, avatarURL string) {
t.Helper()
@ -184,9 +195,9 @@ func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
}
// CreateRoom creates a room with an optional HTTP request body. Fails the test on error. Returns the room ID.
func (c *CSAPI) CreateRoom(t *testing.T, creationContent interface{}) string {
func (c *CSAPI) CreateRoom(t *testing.T, reqBody map[string]any) string {
t.Helper()
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, creationContent)
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, reqBody)
body := ParseJSON(t, res)
return GetJSONFieldStr(t, body, "room_id")
}

View File

@ -220,7 +220,9 @@ func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI {
}
client.UserID, client.AccessToken, client.DeviceID = client.RegisterUser(t, localpart, "password")
client.Localpart = strings.Split(client.UserID, ":")[0][1:]
parts := strings.Split(client.UserID, ":")
client.Localpart = parts[0][1:]
client.Domain = strings.Split(client.UserID, ":")[1]
return client
}

View File

@ -1,7 +1,9 @@
package syncv3_test
import (
"fmt"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
@ -75,3 +77,97 @@ func TestRedactionsAreRedactedWherePossible(t *testing.T) {
}
}
func TestRedactingRoomStateIsReflectedInNextSync(t *testing.T) {
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")
t.Log("Alice creates a room, then sets a room alias and name.")
room := alice.CreateRoom(t, map[string]any{
"preset": "public_chat",
})
alias := fmt.Sprintf("#%s-%d:%s", t.Name(), time.Now().Unix(), alice.Domain)
alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "directory", "room", alias},
WithJSONBody(t, map[string]any{"room_id": room}),
)
aliasID := alice.SetState(t, room, "m.room.canonical_alias", "", map[string]any{
"alias": alias,
})
const naughty = "naughty room for naughty people"
nameID := alice.SetState(t, room, "m.room.name", "", map[string]any{
"name": naughty,
})
t.Log("Alice sliding syncs, subscribing to that room explicitly.")
res := alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
room: {
TimelineLimit: 20,
},
},
})
t.Log("Alice should see her room appear with its name.")
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(naughty)))
t.Log("Alice redacts the room name.")
redactionID := alice.RedactEvent(t, room, nameID)
t.Log("Alice syncs until she sees her redaction.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(
room,
MatchRoomTimelineMostRecent(1, []Event{{ID: redactionID}}),
))
t.Log("The room name should have been redacted, falling back to the canonical alias.")
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(alias)))
t.Log("Alice sets a room avatar.")
avatarURL := alice.UploadContent(t, smallPNG, "avatar.png", "image/png")
avatarID := alice.SetState(t, room, "m.room.avatar", "", map[string]interface{}{
"url": avatarURL,
})
t.Log("Alice waits to see the avatar.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomAvatar(avatarURL)))
t.Log("Alice redacts the avatar.")
redactionID = alice.RedactEvent(t, room, avatarID)
t.Log("Alice sees the avatar revert to blank.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomUnsetAvatar()))
t.Log("Bob joins the room, with a custom displayname.")
const bobDisplayName = "bob mortimer"
bob.SetDisplayname(t, bobDisplayName)
bob.JoinRoom(t, room, nil)
t.Log("Alice sees Bob join.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room,
MatchRoomTimelineMostRecent(1, []Event{{
StateKey: ptr(bob.UserID),
Type: "m.room.member",
Content: map[string]any{
"membership": "join",
"displayname": bobDisplayName,
},
}}),
))
	// Extract Bob's join ID because https://github.com/matrix-org/matrix-spec-proposals/pull/2943 doesn't exist, grrr
timeline := res.Rooms[room].Timeline
bobJoinID := gjson.GetBytes(timeline[len(timeline)-1], "event_id").Str
t.Log("Alice redacts the alias.")
redactionID = alice.RedactEvent(t, room, aliasID)
t.Log("Alice sees the room name reset to Bob's display name.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bobDisplayName)))
t.Log("Bob redacts his membership")
redactionID = bob.RedactEvent(t, room, bobJoinID)
t.Log("Alice sees the room name reset to Bob's username.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bob.UserID)))
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"os"
"sync/atomic"
"testing"
"time"
@ -462,3 +463,75 @@ func TestPollerExpiryEnsurePollingRace(t *testing.T) {
t.Fatalf("Should have got 401 http response; got %d\n%s", status, resBytes)
}
}
// Regression test for the bugfix for https://github.com/matrix-org/sliding-sync/issues/287#issuecomment-1706522718
// Specifically, we could cache the failure and never tell the poller about new tokens, wedging the client(!). This
// seems to have been due to the following:
// - client hits sync for the first time. We /whoami and remember the token->user mapping in TokensTable.
// - client syncing + poller syncing, everything happy.
// - token expires. OnExpiredToken is sent to EnsurePoller which removes the entry from EnsurePoller and nukes the conns.
// - client hits sync, gets 400 M_UNKNOWN_POS due to nuked conns.
// - client hits a fresh /sync: for whatever reason, the token is NOT 401d there and then by the /whoami lookup failing.
// Maybe failed to remove the token, but don't see any logs to suggest this. Seems to be an OIDC thing.
// - EnsurePoller starts a poller, which immediately 401s as the token is expired.
// - OnExpiredToken is sent first, which removes the entry in EnsurePoller.
// - OnInitialSyncComplete[success=false] is sent after, which MAKES A NEW ENTRY with success=false.
// - proxy sends back 401 M_UNKNOWN_TOKEN.
// - At this point the proxy is wedged. Any token, no matter how valid they are, will not hit EnsurePoller because
// we cached success=false for that (user,device).
//
// Traceable in the logs which show spam of this log line without "Poller: v2 poll loop started" interleaved.
//
// 12:45:33 ERR EnsurePolling failed, returning 401 conn=encryption device=xx user=@xx:xx.xx
//
// To test this failure mode we:
// - Create Alice and sync her poller.
// - Expire her token immediately, just like the test TestPollerExpiryEnsurePollingRace
// - Do another request with a valid new token, this should succeed.
func TestPollerExpiryEnsurePollingRaceDoesntWedge(t *testing.T) {
newToken := "NEW_ALICE_TOKEN"
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
defer v2.close()
v3 := runTestServer(t, v2, pqString)
defer v3.close()
v2.addAccount(t, alice, aliceToken)
// Arrange the following:
// 1. A request arrives from an unknown token.
// 2. The API makes a /whoami lookup for the new token. That returns without error.
// 3. The old token expires.
// 4. The poller tries to call /sync but finds that the token has expired.
// NEW 5. Using a "new token" works.
var gotNewToken atomic.Bool
v2.SetCheckRequest(func(token string, req *http.Request) {
if token == newToken {
t.Log("recv new token")
gotNewToken.Store(true)
return
}
if token != aliceToken {
t.Fatalf("unexpected poll from %s", token)
}
// Expire the token before we process the request.
t.Log("Alice's token expires.")
v2.invalidateTokenImmediately(token)
})
t.Log("Alice makes a sliding sync request with a token that's about to expire.")
_, resBytes, status := v3.doV3Request(t, context.Background(), aliceToken, "", sync3.Request{})
if status != http.StatusUnauthorized {
t.Fatalf("Should have got 401 http response; got %d\n%s", status, resBytes)
}
// make a new token and use it
v2.addAccount(t, alice, newToken)
_, resBytes, status = v3.doV3Request(t, context.Background(), newToken, "", sync3.Request{})
if status != http.StatusOK {
t.Fatalf("Should have got 200 http response; got %d\n%s", status, resBytes)
}
if !gotNewToken.Load() {
t.Fatalf("never saw a v2 poll with the new token")
}
}