mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Merge pull request #362 from matrix-org/dmr/fetch-memberships
This commit is contained in:
commit
319da85789
@ -368,6 +368,57 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchMemberships looks up the latest snapshot for the given room and determines the
|
||||
// latest membership events in the room. Returns
|
||||
// - the list of joined members,
|
||||
// - the list of invited members, and then
|
||||
// - the list of all other memberships. (This is called "leaves", but includes bans. It
|
||||
// also includes knocks, but the proxy doesn't support those.)
|
||||
//
|
||||
// Each lists' members are arranged in no particular order.
|
||||
//
|
||||
// TODO: there is a very similar query in ResetMetadataState which also selects events
|
||||
// events row for memberships. It is a shame to have to do this twice---can we query
|
||||
// once and pass the data around?
|
||||
func (s *Storage) FetchMemberships(roomID string) (joins, invites, leaves []string, err error) {
|
||||
var events []Event
|
||||
err = s.DB.Select(&events, `
|
||||
WITH snapshot(membership_nids) AS (
|
||||
SELECT membership_events
|
||||
FROM syncv3_snapshots
|
||||
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
|
||||
WHERE syncv3_rooms.room_id = $1
|
||||
)
|
||||
SELECT state_key, membership
|
||||
FROM syncv3_events JOIN snapshot ON (
|
||||
event_nid = ANY( membership_nids )
|
||||
)
|
||||
`, roomID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
joins = make([]string, 0, len(events))
|
||||
invites = make([]string, 0, len(events))
|
||||
leaves = make([]string, 0, len(events))
|
||||
|
||||
for _, e := range events {
|
||||
switch e.Membership {
|
||||
case "_join":
|
||||
fallthrough
|
||||
case "join":
|
||||
joins = append(joins, e.StateKey)
|
||||
case "_invite":
|
||||
fallthrough
|
||||
case "invite":
|
||||
invites = append(invites, e.StateKey)
|
||||
default:
|
||||
leaves = append(leaves, e.StateKey)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of
|
||||
// room ID to events in that room.
|
||||
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
|
||||
|
@ -875,6 +875,49 @@ func TestCircularSlice(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestStorage_FetchMemberships(t *testing.T) {
|
||||
assertNoError(t, cleanDB(t))
|
||||
store := NewStorage(postgresConnectionString)
|
||||
defer store.Teardown()
|
||||
|
||||
events := []json.RawMessage{
|
||||
testutils.NewStateEvent(t, "m.room.create", "", "@alice:test", map[string]any{}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@alice:test", "@alice:test", map[string]any{"membership": "join"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@brian:test", "@alice:test", map[string]any{"membership": "invite"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@chris:test", "@chris:test", map[string]any{"membership": "leave"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@david:test", "@alice:test", map[string]any{"membership": "ban"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@erika:test", "@erika:test", map[string]any{"membership": "join"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@frank:test", "@erika:test", map[string]any{"membership": "invite"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@glory:test", "@glory:test", map[string]any{"membership": "leave"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", "@helen:test", "@alice:test", map[string]any{"membership": "ban"}),
|
||||
}
|
||||
|
||||
const roomID = "!unimportant"
|
||||
err := sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) {
|
||||
_, err = store.Accumulator.Initialise(roomID, events)
|
||||
return err
|
||||
})
|
||||
assertNoError(t, err)
|
||||
|
||||
joins, invites, leaves, err := store.FetchMemberships(roomID)
|
||||
assertNoError(t, err)
|
||||
|
||||
// Do not assume an order from the DB.
|
||||
sort.Slice(joins, func(i, j int) bool {
|
||||
return joins[i] < joins[j]
|
||||
})
|
||||
sort.Slice(invites, func(i, j int) bool {
|
||||
return invites[i] < invites[j]
|
||||
})
|
||||
sort.Slice(leaves, func(i, j int) bool {
|
||||
return leaves[i] < leaves[j]
|
||||
})
|
||||
|
||||
assertValue(t, "joins", joins, []string{"@alice:test", "@erika:test"})
|
||||
assertValue(t, "invites", invites, []string{"@brian:test", "@frank:test"})
|
||||
assertValue(t, "joins", leaves, []string{"@chris:test", "@david:test", "@glory:test", "@helen:test"})
|
||||
}
|
||||
|
||||
func cleanDB(t *testing.T) error {
|
||||
// make a fresh DB which is unpolluted from other tests
|
||||
db, close := connectToDB(t)
|
||||
|
Loading…
x
Reference in New Issue
Block a user