diff --git a/internal/event.go b/internal/event.go new file mode 100644 index 0000000..fdd0050 --- /dev/null +++ b/internal/event.go @@ -0,0 +1,19 @@ +package internal + +import "github.com/tidwall/gjson" + +func IsMembershipChange(eventJSON gjson.Result) bool { + // membership event possibly, make sure the membership has changed else + // things like display name changes will count as membership events :( + prevMembership := "leave" + pm := eventJSON.Get("unsigned.prev_content.membership") + if pm.Exists() && pm.Str != "" { + prevMembership = pm.Str + } + currMembership := "leave" + cm := eventJSON.Get("content.membership") + if cm.Exists() && cm.Str != "" { + currMembership = cm.Str + } + return prevMembership != currMembership // membership was changed +} diff --git a/room_names_test.go b/room_names_test.go index 5b79019..078dc03 100644 --- a/room_names_test.go +++ b/room_names_test.go @@ -42,7 +42,23 @@ func TestRoomNames(t *testing.T) { { roomID: "!TestRoomNames_empty:localhost", name: "Empty Room", - events: createRoomState(t, alice, latestTimestamp), + events: createRoomState(t, alice, latestTimestamp.Add(time.Second)), + }, + { + roomID: "!TestRoomNames_dm_name_set_after_join:localhost", + name: "Bob", + state: append(createRoomState(t, alice, latestTimestamp), []json.RawMessage{ + testutils.NewStateEvent(t, "m.room.member", bob, bob, map[string]interface{}{"membership": "join"}, testutils.WithTimestamp(latestTimestamp)), + }...), + events: []json.RawMessage{ + testutils.NewStateEvent( + t, "m.room.member", bob, bob, map[string]interface{}{"membership": "join", "displayname": "Bob"}, testutils.WithTimestamp(latestTimestamp), + testutils.WithUnsigned(map[string]interface{}{ + "prev_content": map[string]interface{}{ + "membership": "join", + }, + })), + }, }, } v2.addAccount(alice, aliceToken) @@ -53,6 +69,7 @@ func TestRoomNames(t *testing.T) { }) checkRoomNames := func(sessionID string) { + t.Helper() // do a sync, make sure room names are sensible res := v3.mustDoV3Request(t, aliceToken, sync3.Request{ Rooms: sync3.SliceRanges{ diff --git a/state/event_table.go b/state/event_table.go index 7150429..b43efde 100644 --- a/state/event_table.go +++ b/state/event_table.go @@ -6,6 +6,7 @@ import ( "math" "github.com/jmoiron/sqlx" + "github.com/matrix-org/sync-v3/internal" "github.com/matrix-org/sync-v3/sqlutil" "github.com/tidwall/gjson" ) @@ -124,7 +125,7 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) { return 0, fmt.Errorf("membership event missing membership key") } // genuine changes mark the membership event - if isMembershipChange(evJSON) { + if internal.IsMembershipChange(evJSON) { ev.Membership = membershipResult.Str } else { // profile changes have _ prefix. @@ -333,19 +334,3 @@ func (c EventChunker) Len() int { func (c EventChunker) Subslice(i, j int) sqlutil.Chunker { return c[i:j] } - -func isMembershipChange(eventJSON gjson.Result) bool { - // membership event possibly, make sure the membership has changed else - // things like display name changes will count as membership events :( - prevMembership := "leave" - pm := eventJSON.Get("unsigned.prev_content.membership") - if pm.Exists() && pm.Str != "" { - prevMembership = pm.Str - } - currMembership := "leave" - cm := eventJSON.Get("content.membership") - if cm.Exists() && cm.Str != "" { - currMembership = cm.Str - } - return prevMembership != currMembership // membership was changed -} diff --git a/state/storage.go b/state/storage.go index 9daa824..e517982 100644 --- a/state/storage.go +++ b/state/storage.go @@ -133,8 +133,12 @@ func (s *Storage) MetadataForAllRooms() (map[string]internal.RoomMetadata, error SELECT room_id, event, rank() OVER ( PARTITION BY room_id ORDER BY event_nid DESC ) FROM syncv3_events WHERE ( - membership='join' OR membership='invite' OR (membership='_join' AND before_state_snapshot_id=0) - ) AND event_type='m.room.member' + membership='join' OR membership='invite' OR membership='_join' + ) AND event_type='m.room.member' AND event_nid IN ( + SELECT unnest(events) FROM syncv3_snapshots WHERE syncv3_snapshots.snapshot_id IN ( + SELECT current_snapshot_id FROM syncv3_rooms + ) + ) ) rf WHERE rank <= 6`) if err != nil { return nil, fmt.Errorf("failed to query heroes: %s", err) diff --git a/sync3/globalcache.go b/sync3/globalcache.go index ddbcdf3..40c69ba 100644 --- a/sync3/globalcache.go +++ b/sync3/globalcache.go @@ -166,23 +166,38 @@ func (c *GlobalCache) OnNewEvent( case "m.room.member": if ed.stateKey != nil { membership := ed.content.Get("membership").Str - if membership == "invite" { - metadata.InviteCount += 1 - } else if membership == "join" { - metadata.JoinCount += 1 - } else if membership == "leave" || membership == "ban" { - metadata.JoinCount -= 1 - // remove this user as a hero - metadata.RemoveHero(*ed.stateKey) - } - if gjson.ParseBytes(ed.event).Get("unsigned.prev_content.membership").Str == "invite" { - metadata.InviteCount -= 1 + eventJSON := gjson.ParseBytes(ed.event) + if internal.IsMembershipChange(eventJSON) { + if membership == "invite" { + metadata.InviteCount += 1 + } else if membership == "join" { + metadata.JoinCount += 1 + } else if membership == "leave" || membership == "ban" { + metadata.JoinCount -= 1 + // remove this user as a hero + metadata.RemoveHero(*ed.stateKey) + } + + if eventJSON.Get("unsigned.prev_content.membership").Str == "invite" { + metadata.InviteCount -= 1 + } } if len(metadata.Heroes) < 6 && (membership == "join" || membership == "invite") { - metadata.Heroes = append(metadata.Heroes, internal.Hero{ - ID: *ed.stateKey, - Name: ed.content.Get("displayname").Str, - }) + // try to find the existing hero e.g they changed their display name + found := false + for i := range metadata.Heroes { + if metadata.Heroes[i].ID == *ed.stateKey { + metadata.Heroes[i].Name = ed.content.Get("displayname").Str + found = true + break + } + } + if !found { + metadata.Heroes = append(metadata.Heroes, internal.Hero{ + ID: *ed.stateKey, + Name: ed.content.Get("displayname").Str, + }) + } } } } diff --git a/v3_test.go b/v3_test.go index 5063e5a..82d7703 100644 --- a/v3_test.go +++ b/v3_test.go @@ -323,6 +323,7 @@ func MatchRoomNotificationCount(count int64) roomMatcher { type roomEvents struct { roomID string name string + state []json.RawMessage events []json.RawMessage } @@ -345,6 +346,11 @@ func v2JoinTimeline(joinEvents ...roomEvents) map[string]sync2.SyncV2JoinRespons data.Timeline = sync2.TimelineResponse{ Events: re.events, } + if re.state != nil { + data.State = sync2.EventsResponse{ + Events: re.state, + } + } result[re.roomID] = data } return result