Merge pull request #296 from matrix-org/dmr/cache-invalidation

This commit is contained in:
David Robertson 2023-09-14 11:08:44 +01:00 committed by GitHub
commit 5b32cc44c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 417 additions and 60 deletions

View File

@ -24,6 +24,7 @@ type V2Listener interface {
OnReceipt(p *V2Receipt)
OnDeviceMessages(p *V2DeviceMessages)
OnExpiredToken(p *V2ExpiredToken)
OnInvalidateRoom(p *V2InvalidateRoom)
}
type V2Initialise struct {
@ -129,6 +130,12 @@ type V2ExpiredToken struct {
func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" }
type V2InvalidateRoom struct {
RoomID string
}
func (*V2InvalidateRoom) Type() string { return "V2InvalidateRoom" }
type V2Sub struct {
listener Listener
receiver V2Listener
@ -173,6 +180,8 @@ func (v *V2Sub) onMessage(p Payload) {
v.receiver.OnDeviceMessages(pl)
case *V2ExpiredToken:
v.receiver.OnExpiredToken(pl)
case *V2InvalidateRoom:
v.receiver.OnInvalidateRoom(pl)
default:
logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type")
}

View File

@ -317,6 +317,19 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
return res, err
}
type AccumulateResult struct {
// NumNew is the number of events in timeline NIDs that were not previously known
// to the proyx.
NumNew int
// TimelineNIDs is the list of event nids seen in a sync v2 timeline. Some of these
// may already be known to the proxy.
TimelineNIDs []int64
// RequiresReload is set to true when we have accumulated a non-incremental state
// change (typically a redaction) that requires consumers to reload the room state
// from the latest snapshot.
RequiresReload bool
}
// Accumulate internal state from a user's sync response. The timeline order MUST be in the order
// received from the server. Returns the number of new events in the timeline, the new timeline event NIDs
// or an error.
@ -328,7 +341,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (AccumulateResult, error) {
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
@ -337,10 +350,10 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
if err != nil {
err = fmt.Errorf("filterTimelineEvents: %w", err)
return
return AccumulateResult{}, err
}
if len(dedupedEvents) == 0 {
return 0, nil, err // nothing to do
return AccumulateResult{}, nil // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
@ -354,7 +367,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
// The only situation where no prior snapshot should exist is if this timeline is
@ -389,19 +402,22 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
})
sentry.CaptureMessage(msg)
})
return 0, nil, nil
return AccumulateResult{}, nil
}
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
if len(eventIDToNID) == 0 {
// nothing to do, we already know about these events
return 0, nil, nil
return AccumulateResult{}, nil
}
result := AccumulateResult{
NumNew: len(eventIDToNID),
}
numNew = len(eventIDToNID)
var latestNID int64
newEvents := make([]Event, 0, len(eventIDToNID))
@ -433,7 +449,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
}
}
newEvents = append(newEvents, ev)
timelineNIDs = append(timelineNIDs, ev.NID)
result.TimelineNIDs = append(result.TimelineNIDs, ev.NID)
}
}
@ -443,7 +459,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if len(redactTheseEventIDs) > 0 {
createEventJSON, err := a.eventsTable.SelectCreateEvent(txn, roomID)
if err != nil {
return 0, nil, fmt.Errorf("SelectCreateEvent: %w", err)
return AccumulateResult{}, fmt.Errorf("SelectCreateEvent: %w", err)
}
roomVersion = gjson.GetBytes(createEventJSON, "content.room_version").Str
if roomVersion == "" {
@ -454,7 +470,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
)
}
if err = a.eventsTable.Redact(txn, roomVersion, redactTheseEventIDs); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
@ -470,12 +486,12 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
if snapID != 0 {
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
if err != nil {
return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
}
}
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
if err != nil {
return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
return AccumulateResult{}, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
replacesNID = replacedNID
memNIDs, otherNIDs := newStripped.NIDs()
@ -485,7 +501,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
OtherEvents: otherNIDs,
}
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
return AccumulateResult{}, fmt.Errorf("failed to insert new snapshot: %w", err)
}
if a.snapshotMemberCountVec != nil {
logger.Trace().Str("room_id", roomID).Int("members", len(memNIDs)).Msg("Inserted new snapshot")
@ -494,20 +510,40 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch
snapID = newSnapshot.SnapshotID
}
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
return 0, nil, err
return AccumulateResult{}, err
}
}
if len(redactTheseEventIDs) > 0 {
// We need to emit a cache invalidation if we have redacted some state in the
// current snapshot ID. Note that we run this _after_ persisting any new snapshots.
redactedEventIDs := make([]string, 0, len(redactTheseEventIDs))
for eventID := range redactTheseEventIDs {
redactedEventIDs = append(redactedEventIDs, eventID)
}
var currentStateRedactions int
err = txn.Get(&currentStateRedactions, `
SELECT COUNT(*)
FROM syncv3_events
JOIN syncv3_snapshots ON event_nid = ANY (ARRAY_CAT(events, membership_events))
WHERE snapshot_id = $1 AND event_id = ANY($2)
`, snapID, pq.StringArray(redactedEventIDs))
if err != nil {
return AccumulateResult{}, err
}
result.RequiresReload = currentStateRedactions > 0
}
if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
return AccumulateResult{}, fmt.Errorf("HandleSpaceUpdates: %s", err)
}
// the last fetched snapshot ID is the current one
info := a.roomInfoDelta(roomID, newEvents)
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
return AccumulateResult{}, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
}
return numNew, timelineNIDs, nil
return result, nil
}
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:

View File

@ -3,6 +3,7 @@ package state
import (
"context"
"encoding/json"
"fmt"
"github.com/matrix-org/sliding-sync/testutils"
"reflect"
"sort"
@ -135,25 +136,24 @@ func TestAccumulatorAccumulate(t *testing.T) {
// new state event should be added to the snapshot
[]byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
var numNew int
var latestNIDs []int64
var result AccumulateResult
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
result, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
if numNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", numNew, len(newEvents))
if result.NumNew != len(newEvents) {
t.Fatalf("got %d new events, want %d", result.NumNew, len(newEvents))
}
// latest nid shoould match
wantLatestNID, err := accumulator.eventsTable.SelectHighestNID()
if err != nil {
t.Fatalf("failed to check latest NID from Accumulate: %s", err)
}
if latestNIDs[len(latestNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", latestNIDs[len(latestNIDs)-1], wantLatestNID)
if result.TimelineNIDs[len(result.TimelineNIDs)-1] != wantLatestNID {
t.Errorf("Accumulator.Accumulate returned latest nid %d, want %d", result.TimelineNIDs[len(result.TimelineNIDs)-1], wantLatestNID)
}
// Begin assertions
@ -212,7 +212,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -220,6 +220,80 @@ func TestAccumulatorAccumulate(t *testing.T) {
}
}
func TestAccumulatorPromptsCacheInvalidation(t *testing.T) {
db, close := connectToDB(t)
defer close()
accumulator := NewAccumulator(db)
t.Log("Initialise the room state, including a room name.")
roomID := fmt.Sprintf("!%s:localhost", t.Name())
stateBlock := []json.RawMessage{
[]byte(`{"event_id":"$a", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost", "room_version": "10"}}`),
[]byte(`{"event_id":"$b", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"$c", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"$d", "type":"m.room.name", "state_key":"", "content":{"name":"Barry Cryer Appreciation Society"}}`),
}
_, err := accumulator.Initialise(roomID, stateBlock)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
t.Log("Accumulate a second room name, a message, then a third room name.")
timeline := []json.RawMessage{
[]byte(`{"event_id":"$e", "type":"m.room.name", "state_key":"", "content":{"name":"Jeremy Hardy Appreciation Society"}}`),
[]byte(`{"event_id":"$f", "type":"m.room.message", "content": {"body":"Hello, world!", "msgtype":"m.text"}}`),
[]byte(`{"event_id":"$g", "type":"m.room.name", "state_key":"", "content":{"name":"Humphrey Lyttelton Appreciation Society"}}`),
}
var accResult AccumulateResult
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 3 new events and no reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 3)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 3)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
t.Log("Redact the old state event and the message.")
timeline = []json.RawMessage{
[]byte(`{"event_id":"$h", "type":"m.room.redaction", "content":{"redacts":"$e"}}`),
[]byte(`{"event_id":"$i", "type":"m.room.redaction", "content":{"redacts":"$f"}}`),
}
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch2", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 2 new events and no reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 2)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 2)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false)
t.Log("Redact the latest state event.")
timeline = []json.RawMessage{
[]byte(`{"event_id":"$j", "type":"m.room.redaction", "content":{"redacts":"$g"}}`),
}
err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
accResult, err = accumulator.Accumulate(txn, "@dummy:localhost", roomID, "prevBatch3", timeline)
return err
})
if err != nil {
t.Fatalf("Failed to Accumulate: %s", err)
}
t.Log("We expect 1 new event and a reload required.")
assertValue(t, "accResult.NumNew", accResult.NumNew, 1)
assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 1)
assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, true)
}
func TestAccumulatorMembershipLogs(t *testing.T) {
roomID := "!TestAccumulatorMembershipLogs:localhost"
db, close := connectToDB(t)
@ -248,7 +322,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
_, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -384,7 +458,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
_, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -584,8 +658,8 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
result, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += result.NumNew
return err
})
if err != nil {

View File

@ -306,6 +306,77 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
return nil
}
// ResetMetadataState updates the given metadata in-place to reflect the current state
// of the room. This is only safe to call from the subscriber goroutine; it is not safe
// to call from the connection goroutines.
// TODO: could have this create a new RoomMetadata and get the caller to assign it.
func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
var events []Event
err := s.DB.Select(&events, `
WITH snapshot(events, membership_events) AS (
SELECT events, membership_events
FROM syncv3_snapshots
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
WHERE syncv3_rooms.room_id = $1
)
SELECT event_id, event_type, state_key, event, membership
FROM syncv3_events JOIN snapshot ON (
event_nid = ANY (ARRAY_CAT(events, membership_events))
)
WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias') AND state_key = '')
OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite'))
ORDER BY event_nid ASC
;`, metadata.RoomID)
if err != nil {
return fmt.Errorf("ResetMetadataState[%s]: %w", metadata.RoomID, err)
}
heroMemberships := circularSlice[*Event]{max: 6}
metadata.JoinCount = 0
metadata.InviteCount = 0
metadata.ChildSpaceRooms = make(map[string]struct{})
for i, ev := range events {
switch ev.Type {
case "m.room.name":
metadata.NameEvent = gjson.GetBytes(ev.JSON, "content.name").Str
case "m.room.avatar":
metadata.AvatarEvent = gjson.GetBytes(ev.JSON, "content.avatar_url").Str
case "m.room.canonical_alias":
metadata.CanonicalAlias = gjson.GetBytes(ev.JSON, "content.alias").Str
case "m.room.member":
heroMemberships.append(&events[i])
switch ev.Membership {
case "join":
fallthrough
case "_join":
metadata.JoinCount++
case "invite":
fallthrough
case "_invite":
metadata.InviteCount++
}
case "m.space.child":
metadata.ChildSpaceRooms[ev.StateKey] = struct{}{}
}
}
metadata.Heroes = make([]internal.Hero, 0, len(heroMemberships.vals))
for _, ev := range heroMemberships.vals {
parsed := gjson.ParseBytes(ev.JSON)
hero := internal.Hero{
ID: ev.StateKey,
Name: parsed.Get("content.displayname").Str,
Avatar: parsed.Get("content.avatar_url").Str,
}
metadata.Heroes = append(metadata.Heroes, hero)
}
// For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID.
// These shouldn't be changing during a room's lifetime in normal operation.
return nil
}
// Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of
// room ID to events in that room.
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
@ -336,15 +407,15 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (result AccumulateResult, err error) {
if len(timeline) == 0 {
return 0, nil, nil
return AccumulateResult{}, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
result, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return
return result, err
}
func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
@ -818,7 +889,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
defer rows.Close()
joinedMembers = make(map[string][]string)
inviteCounts := make(map[string]int)
heroNIDs := make(map[string]*circularSlice)
heroNIDs := make(map[string]*circularSlice[int64])
var stateKey string
var membership string
var roomID string
@ -829,7 +900,7 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
}
heroes := heroNIDs[roomID]
if heroes == nil {
heroes = &circularSlice{max: 6}
heroes = &circularSlice[int64]{max: 6}
heroNIDs[roomID] = heroes
}
switch membership {
@ -982,13 +1053,13 @@ func (s *Storage) Teardown() {
// circularSlice is a slice which can be appended to which will wraparound at `max`.
// Mostly useful for lazily calculating heroes. The values returned aren't sorted.
type circularSlice struct {
type circularSlice[T any] struct {
i int
vals []int64
vals []T
max int
}
func (s *circularSlice) append(val int64) {
func (s *circularSlice[T]) append(val T) {
if len(s.vals) < s.max {
// populate up to max
s.vals = append(s.vals, val)

View File

@ -31,11 +31,11 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
accResult, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
testCases := []struct {
name string
@ -158,14 +158,13 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
},
}
var latestPos int64
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
accResult, err := store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
latestPos = latestNIDs[len(latestNIDs)-1]
latestPos = accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
}
aliceJoinTimingsByRoomID, err := store.JoinedRoomsAfterPosition(alice, latestPos)
if err != nil {
@ -351,11 +350,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err := store.LatestEventNID()
if err != nil {
@ -454,11 +453,11 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
accResult, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
t.Logf("%s added %d new events", tl.RoomID, numNew)
t.Logf("%s added %d new events", tl.RoomID, accResult.NumNew)
}
latestPos, err = store.LatestEventNID()
if err != nil {
@ -534,7 +533,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
_, err := store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -776,7 +775,7 @@ func TestAllJoinedMembers(t *testing.T) {
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
_, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}
@ -855,7 +854,7 @@ func TestCircularSlice(t *testing.T) {
},
}
for _, tc := range testCases {
cs := &circularSlice{
cs := &circularSlice[int64]{
max: tc.max,
}
for _, val := range tc.appends {

View File

@ -294,19 +294,26 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
accResult, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return err
}
// Consumers should reload state before processing new timeline events.
if accResult.RequiresReload {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{
RoomID: roomID,
})
}
// We've updated the database. Now tell any pubsub listeners what we learned.
if numNew != 0 {
if accResult.NumNew != 0 {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
RoomID: roomID,
PrevBatch: prevBatch,
EventNIDs: latestNIDs,
EventNIDs: accResult.TimelineNIDs,
})
}

View File

@ -385,3 +385,20 @@ func (c *GlobalCache) OnNewEvent(
}
c.roomIDToMetadata[ed.RoomID] = metadata
}
func (c *GlobalCache) OnInvalidateRoom(ctx context.Context, roomID string) {
c.roomIDToMetadataMu.Lock()
defer c.roomIDToMetadataMu.Unlock()
metadata, ok := c.roomIDToMetadata[roomID]
if !ok {
logger.Warn().Str("room_id", roomID).Msg("OnInvalidateRoom: room not in global cache")
return
}
err := c.store.ResetMetadataState(metadata)
if err != nil {
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
logger.Warn().Err(err).Msg("OnInvalidateRoom: failed to reset metadata")
}
}

View File

@ -38,16 +38,16 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
_, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
accResult, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
latest := latestNIDs[len(latestNIDs)-1]
latest := accResult.TimelineNIDs[len(accResult.TimelineNIDs)-1]
globalCache := caches.NewGlobalCache(store)
testCases := []struct {
name string

View File

@ -760,3 +760,10 @@ func (u *UserCache) ShouldIgnore(userID string) bool {
_, ignored := u.ignoredUsers[userID]
return ignored
}
func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) {
// Nothing for now. In UserRoomData the fields dependant on room state are
// IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces.
// Not clear to me if we need to reload these or if we will inherit any changes from
// the global cache.
}

View File

@ -24,6 +24,7 @@ type Receiver interface {
OnNewEvent(ctx context.Context, event *caches.EventData)
OnReceipt(ctx context.Context, receipt internal.Receipt)
OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage)
OnInvalidateRoom(ctx context.Context, roomID string)
// OnRegistered is called after a successful call to Dispatcher.Register
OnRegistered(ctx context.Context) error
}
@ -285,3 +286,23 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData,
}
}
}
func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) {
// First dispatch to the global cache.
receiver, ok := d.userToReceiver[DispatcherAllUsers]
if !ok {
logger.Error().Msgf("No receiver for global cache")
}
receiver.OnInvalidateRoom(ctx, roomID)
// Then dispatch to any users who are joined to that room.
joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil)
d.userToReceiverMu.RLock()
defer d.userToReceiverMu.RUnlock()
for _, userID := range joinedUsers {
receiver = d.userToReceiver[userID]
if receiver != nil {
receiver.OnInvalidateRoom(ctx, roomID)
}
}
}

View File

@ -803,6 +803,13 @@ func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.ConnMap.CloseConnsForDevice(p.UserID, p.DeviceID)
}
func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
defer task.End()
h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID)
}
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
queryPos := u.Query().Get(param)
if queryPos != "" {

View File

@ -132,6 +132,7 @@ type SyncReq struct {
type CSAPI struct {
UserID string
Localpart string
Domain string
AccessToken string
DeviceID string
AvatarURL string
@ -160,6 +161,16 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
return GetJSONFieldStr(t, body, "content_uri")
}
// Use an empty string to remove a custom displayname.
func (c *CSAPI) SetDisplayname(t *testing.T, name string) {
t.Helper()
reqBody := map[string]any{}
if name != "" {
reqBody["displayname"] = name
}
c.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "profile", c.UserID, "displayname"}, WithJSONBody(t, reqBody))
}
// Use an empty string to remove your avatar.
func (c *CSAPI) SetAvatar(t *testing.T, avatarURL string) {
t.Helper()
@ -184,9 +195,9 @@ func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
}
// CreateRoom creates a room with an optional HTTP request body. Fails the test on error. Returns the room ID.
func (c *CSAPI) CreateRoom(t *testing.T, creationContent interface{}) string {
func (c *CSAPI) CreateRoom(t *testing.T, reqBody map[string]any) string {
t.Helper()
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, creationContent)
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, reqBody)
body := ParseJSON(t, res)
return GetJSONFieldStr(t, body, "room_id")
}

View File

@ -220,7 +220,9 @@ func registerNamedUser(t *testing.T, localpartPrefix string) *CSAPI {
}
client.UserID, client.AccessToken, client.DeviceID = client.RegisterUser(t, localpart, "password")
client.Localpart = strings.Split(client.UserID, ":")[0][1:]
parts := strings.Split(client.UserID, ":")
client.Localpart = parts[0][1:]
client.Domain = strings.Split(client.UserID, ":")[1]
return client
}

View File

@ -1,7 +1,9 @@
package syncv3_test
import (
"fmt"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
@ -75,3 +77,97 @@ func TestRedactionsAreRedactedWherePossible(t *testing.T) {
}
}
func TestRedactingRoomStateIsReflectedInNextSync(t *testing.T) {
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")
t.Log("Alice creates a room, then sets a room alias and name.")
room := alice.CreateRoom(t, map[string]any{
"preset": "public_chat",
})
alias := fmt.Sprintf("#%s-%d:%s", t.Name(), time.Now().Unix(), alice.Domain)
alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "directory", "room", alias},
WithJSONBody(t, map[string]any{"room_id": room}),
)
aliasID := alice.SetState(t, room, "m.room.canonical_alias", "", map[string]any{
"alias": alias,
})
const naughty = "naughty room for naughty people"
nameID := alice.SetState(t, room, "m.room.name", "", map[string]any{
"name": naughty,
})
t.Log("Alice sliding syncs, subscribing to that room explicitly.")
res := alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
room: {
TimelineLimit: 20,
},
},
})
t.Log("Alice should see her room appear with its name.")
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(naughty)))
t.Log("Alice redacts the room name.")
redactionID := alice.RedactEvent(t, room, nameID)
t.Log("Alice syncs until she sees her redaction.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(
room,
MatchRoomTimelineMostRecent(1, []Event{{ID: redactionID}}),
))
t.Log("The room name should have been redacted, falling back to the canonical alias.")
m.MatchResponse(t, res, m.MatchRoomSubscription(room, m.MatchRoomName(alias)))
t.Log("Alice sets a room avatar.")
avatarURL := alice.UploadContent(t, smallPNG, "avatar.png", "image/png")
avatarID := alice.SetState(t, room, "m.room.avatar", "", map[string]interface{}{
"url": avatarURL,
})
t.Log("Alice waits to see the avatar.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomAvatar(avatarURL)))
t.Log("Alice redacts the avatar.")
redactionID = alice.RedactEvent(t, room, avatarID)
t.Log("Alice sees the avatar revert to blank.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomUnsetAvatar()))
t.Log("Bob joins the room, with a custom displayname.")
const bobDisplayName = "bob mortimer"
bob.SetDisplayname(t, bobDisplayName)
bob.JoinRoom(t, room, nil)
t.Log("Alice sees Bob join.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room,
MatchRoomTimelineMostRecent(1, []Event{{
StateKey: ptr(bob.UserID),
Type: "m.room.member",
Content: map[string]any{
"membership": "join",
"displayname": bobDisplayName,
},
}}),
))
// Extract Bob's join ID because https://github.com/matrix-org/matrix-spec-proposals/pull/2943 doens't exist grrr
timeline := res.Rooms[room].Timeline
bobJoinID := gjson.GetBytes(timeline[len(timeline)-1], "event_id").Str
t.Log("Alice redacts the alias.")
redactionID = alice.RedactEvent(t, room, aliasID)
t.Log("Alice sees the room name reset to Bob's display name.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bobDisplayName)))
t.Log("Bob redacts his membership")
redactionID = bob.RedactEvent(t, room, bobJoinID)
t.Log("Alice sees the room name reset to Bob's username.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(room, m.MatchRoomName(bob.UserID)))
}