sliding-sync/sync3/tracker.go
David Robertson 4897acf3d6
Add the ability to reset the JoinedRoomsTracker
Not used yet; pulled out of #329.
2023-11-02 15:14:42 +00:00

233 lines
7.0 KiB
Go

package sync3
import (
"sync"
)
// set is a set of strings (user IDs or room IDs in this file). Membership is key
// presence; the struct{} values are zero-width placeholders.
type set map[string]struct{}
// JoinedRoomsTracker tracks who is joined to which rooms. This is critical from a security
// perspective: it ensures that only the users joined to a room receive that room's events.
// Consider the situation where Alice and Bob are joined to room X. If Alice is kicked from X,
// the proxy server will still receive messages for room X because Bob is still joined. We
// therefore need to decide which active connections should be pushed those events, which is
// what this tracker does.
//
// Construct it with NewJoinedRoomsTracker: the zero value has nil maps and a nil mutex
// pointer, so it is not usable.
type JoinedRoomsTracker struct {
// roomIDToJoinedUsers maps room ID -> set of user IDs joined to that room.
roomIDToJoinedUsers map[string]set
// userIDToJoinedRooms is the inverse index of roomIDToJoinedUsers: user ID -> set of room
// IDs that user is joined to. The methods that change joined state update both together.
userIDToJoinedRooms map[string]set
// roomIDToInvitedUsers maps room ID -> set of user IDs with a pending invite. It is not used
// for security, only to keep invite counts correct: Synapse can send duplicate invite->join
// events, so incrementing/decrementing a counter by 1 would drift.
roomIDToInvitedUsers map[string]set
// mu guards all three maps: query methods take the read lock, mutating methods the write lock.
mu *sync.RWMutex
}
// NewJoinedRoomsTracker returns an empty tracker whose three index maps and mutex
// are all allocated and ready for use.
func NewJoinedRoomsTracker() *JoinedRoomsTracker {
	tracker := new(JoinedRoomsTracker)
	tracker.mu = new(sync.RWMutex)
	tracker.roomIDToJoinedUsers = map[string]set{}
	tracker.userIDToJoinedRooms = map[string]set{}
	tracker.roomIDToInvitedUsers = map[string]set{}
	return tracker
}
// Startup bulk-loads the joined-user indexes from roomToJoinedUsers (room ID -> IDs of the
// users joined to it). It is meant for process start only: it is not safe to call with live
// traffic, and it panics if the tracker already holds any joined-room data. Invite state is
// neither checked nor modified.
func (t *JoinedRoomsTracker) Startup(roomToJoinedUsers map[string][]string) {
	t.mu.Lock()
	defer t.mu.Unlock()

	alreadyPopulated := len(t.roomIDToJoinedUsers) != 0 || len(t.userIDToJoinedRooms) != 0
	if alreadyPopulated {
		panic("programming error: cannot call JoinedRoomsTracker.Startup with existing data already set!")
	}

	for roomID, members := range roomToJoinedUsers {
		memberSet := make(set, len(members))
		t.roomIDToJoinedUsers[roomID] = memberSet
		for _, userID := range members {
			memberSet[userID] = struct{}{}
			roomsOfUser := t.userIDToJoinedRooms[userID]
			if roomsOfUser == nil {
				roomsOfUser = make(set)
				t.userIDToJoinedRooms[userID] = roomsOfUser
			}
			roomsOfUser[roomID] = struct{}{}
		}
	}
}
// IsUserJoined reports whether userID is currently joined to roomID according to the
// tracker. Unknown rooms and unknown users both report false.
func (t *JoinedRoomsTracker) IsUserJoined(userID, roomID string) bool {
	t.mu.RLock()
	defer t.mu.RUnlock()
	// Direct O(1) set lookup rather than scanning every member of the room. Indexing a
	// missing room yields a nil set, and indexing a nil map safely reports "not present".
	_, joined := t.roomIDToJoinedUsers[roomID][userID]
	return joined
}
// UserJoinedRoom records that userID has joined roomID. It is the single-user form of
// UsersJoinedRoom and returns true only if the user was not already joined to the room
// before this call.
func (t *JoinedRoomsTracker) UserJoinedRoom(userID, roomID string) bool {
	return t.UsersJoinedRoom([]string{userID}, roomID)
}
// UsersJoinedRoom marks every user in userIDs as joined to roomID and clears any pending
// invite each of them had for that room. It returns true if at least one of the users was
// not joined to the room before the call, and false otherwise (including when userIDs is
// empty).
func (t *JoinedRoomsTracker) UsersJoinedRoom(userIDs []string, roomID string) bool {
	t.mu.Lock()
	defer t.mu.Unlock()

	joinedUsers := t.roomIDToJoinedUsers[roomID]
	if joinedUsers == nil {
		joinedUsers = make(set, len(userIDs))
		t.roomIDToJoinedUsers[roomID] = joinedUsers
	}
	// A room with no invites yields a nil set; delete on a nil map is a no-op, so there is
	// no need to create one, nor to store it back (which would only add a nil entry).
	invitedUsers := t.roomIDToInvitedUsers[roomID]

	anyNewlyJoined := false
	for _, userID := range userIDs {
		// Checking before inserting gives the same answer as checking the pre-call state:
		// a duplicate of an earlier user can only be "already joined" if the earlier
		// occurrence was too.
		if _, alreadyJoined := joinedUsers[userID]; !alreadyJoined {
			anyNewlyJoined = true
		}
		joinedUsers[userID] = struct{}{}
		delete(invitedUsers, userID)

		joinedRooms := t.userIDToJoinedRooms[userID]
		if joinedRooms == nil {
			joinedRooms = make(set)
			t.userIDToJoinedRooms[userID] = joinedRooms
		}
		joinedRooms[roomID] = struct{}{}
	}
	return anyNewlyJoined
}
// UserLeftRoom removes userID from both the joined and the invited state of roomID.
// It returns true if the user was joined or invited to the room before this call,
// and false otherwise.
func (t *JoinedRoomsTracker) UserLeftRoom(userID, roomID string) bool {
	t.mu.Lock()
	defer t.mu.Unlock()

	joinedUsers := t.roomIDToJoinedUsers[roomID]
	invitedUsers := t.roomIDToInvitedUsers[roomID]
	_, wasJoined := joinedUsers[userID]
	_, wasInvited := invitedUsers[userID]

	// Sets are maps, so delete mutates the stored sets in place and nothing needs writing
	// back. Deleting from a nil set (unknown user or room) is a no-op; writing such a nil
	// set back would only insert phantom entries for keys that were never tracked.
	delete(t.userIDToJoinedRooms[userID], roomID)
	delete(joinedUsers, userID)
	delete(invitedUsers, userID)
	return wasJoined || wasInvited
}
// JoinedRoomsForUser returns the IDs of the rooms userID is currently joined to, in no
// particular order (map iteration order). It returns nil if the user is not joined to any
// room. The returned slice is freshly allocated and owned by the caller.
func (t *JoinedRoomsTracker) JoinedRoomsForUser(userID string) []string {
	t.mu.RLock()
	defer t.mu.RUnlock()
	rooms := t.userIDToJoinedRooms[userID]
	// len of a nil map is 0, so this single check also covers unknown users.
	if len(rooms) == 0 {
		return nil
	}
	result := make([]string, 0, len(rooms))
	for roomID := range rooms {
		result = append(result, roomID)
	}
	return result
}
// JoinedUsersForRoom returns the users joined to roomID, filtered by filter if it is non-nil
// (a nil filter matches every user), in no particular order. joinCount is the total number of
// joined users at the time of the call, regardless of the filter. matchedUserIDs is nil when
// no user matches or the room is unknown.
//
// filter is invoked while the tracker's read lock is held, so it must not call back into the
// tracker: sync.RWMutex forbids recursive read locking and a write lock would deadlock.
func (t *JoinedRoomsTracker) JoinedUsersForRoom(roomID string, filter func(userID string) bool) (matchedUserIDs []string, joinCount int) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	users := t.roomIDToJoinedUsers[roomID]
	// len of a nil map is 0, so this single check also covers unknown rooms.
	joinCount = len(users)
	if joinCount == 0 {
		return nil, 0
	}
	if filter == nil {
		// Every member matches, so the final size is known up front.
		matchedUserIDs = make([]string, 0, joinCount)
		for userID := range users {
			matchedUserIDs = append(matchedUserIDs, userID)
		}
		return matchedUserIDs, joinCount
	}
	for userID := range users {
		if filter(userID) {
			matchedUserIDs = append(matchedUserIDs, userID)
		}
	}
	return matchedUserIDs, joinCount
}
// UsersInvitedToRoom records a pending invite to roomID for each user in userIDs. Invites are
// kept as a set, so inviting the same user more than once counts them only once. The room
// always ends up with a (possibly empty) invite set.
func (t *JoinedRoomsTracker) UsersInvitedToRoom(userIDs []string, roomID string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	invited := t.roomIDToInvitedUsers[roomID]
	if invited == nil {
		invited = make(set, len(userIDs))
		t.roomIDToInvitedUsers[roomID] = invited
	}
	for i := range userIDs {
		invited[userIDs[i]] = struct{}{}
	}
}
// NumInvitedUsersForRoom returns how many users currently hold a pending invite to roomID.
// Unknown rooms report 0.
func (t *JoinedRoomsTracker) NumInvitedUsersForRoom(roomID string) int {
	t.mu.RLock()
	// len must be taken under the lock: the set may be mutated concurrently by writers.
	count := len(t.roomIDToInvitedUsers[roomID])
	t.mu.RUnlock()
	return count
}
// ReloadMembershipsForRoom replaces everything the tracker knows about roomID's membership
// with the given joined and invited user lists. It also updates the per-user joined-room
// index to match:
//   - every user in joined gains roomID;
//   - every previously joined user absent from joined loses it.
func (t *JoinedRoomsTracker) ReloadMembershipsForRoom(roomID string, joined, invited []string) {
	toSet := func(userIDs []string) set {
		members := make(set, len(userIDs))
		for _, userID := range userIDs {
			members[userID] = struct{}{}
		}
		return members
	}
	// Build the replacement sets before taking the lock to keep the critical section short.
	replacementJoined := toSet(joined)
	replacementInvited := toSet(invited)

	t.mu.Lock()
	defer t.mu.Unlock()

	previouslyJoined := t.roomIDToJoinedUsers[roomID]
	t.roomIDToJoinedUsers[roomID] = replacementJoined
	t.roomIDToInvitedUsers[roomID] = replacementInvited

	for userID := range replacementJoined {
		roomsOfUser := t.userIDToJoinedRooms[userID]
		if roomsOfUser == nil {
			roomsOfUser = make(set)
			t.userIDToJoinedRooms[userID] = roomsOfUser
		}
		roomsOfUser[roomID] = struct{}{}
	}

	for userID := range previouslyJoined {
		if _, stillJoined := replacementJoined[userID]; stillJoined {
			continue
		}
		delete(t.userIDToJoinedRooms[userID], roomID)
	}
}