diff --git a/state/unread_table.go b/state/unread_table.go
index a389184..8dcd2d1 100644
--- a/state/unread_table.go
+++ b/state/unread_table.go
@@ -23,9 +23,10 @@ func NewUnreadTable(db *sqlx.DB) *UnreadTable {
 	return &UnreadTable{db}
 }
 
-func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string, highlightCount, notificationCount int)) error {
+func (t *UnreadTable) SelectAllNonZeroCountsForUser(userID string, callback func(roomID string, highlightCount, notificationCount int)) error {
 	rows, err := t.db.Query(
-		`SELECT user_id, room_id, notification_count, highlight_count FROM syncv3_unread WHERE notification_count > 0 OR highlight_count > 0`,
+		`SELECT room_id, notification_count, highlight_count FROM syncv3_unread WHERE user_id=$1 AND (notification_count > 0 OR highlight_count > 0)`,
+		userID,
 	)
 	if err != nil {
 		return err
@@ -33,13 +34,12 @@ func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string
 	defer rows.Close()
 	for rows.Next() {
 		var roomID string
-		var userID string
 		var highlightCount int
 		var notifCount int
-		if err := rows.Scan(&userID, &roomID, &notifCount, &highlightCount); err != nil {
+		if err := rows.Scan(&roomID, &notifCount, &highlightCount); err != nil {
 			return err
 		}
-		callback(roomID, userID, highlightCount, notifCount)
+		callback(roomID, highlightCount, notifCount)
 	}
 	return nil
 }
diff --git a/state/unread_table_test.go b/state/unread_table_test.go
index a11118c..ff3a068 100644
--- a/state/unread_table_test.go
+++ b/state/unread_table_test.go
@@ -44,26 +44,23 @@ func TestUnreadTable(t *testing.T) {
 		roomA: 1,
 		roomB: 2,
 	}
-	assertNoError(t, table.SelectAllNonZeroCounts(func(gotRoomID string, gotUserID string, gotHighlight int, gotNotif int) {
-		if userID != gotUserID {
-			t.Errorf("SelectAllNonZeroCounts: got user %v want %v", gotUserID, userID)
-		}
+	assertNoError(t, table.SelectAllNonZeroCountsForUser(userID, func(gotRoomID string, gotHighlight int, gotNotif int) {
 		wantHighlight := wantHighlights[gotRoomID]
 		if wantHighlight != gotHighlight {
-			t.Errorf("SelectAllNonZeroCounts for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
+			t.Errorf("SelectAllNonZeroCountsForUser for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
 		}
 		wantNotif := wantNotifs[gotRoomID]
 		if wantNotif != gotNotif {
-			t.Errorf("SelectAllNonZeroCounts for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
+			t.Errorf("SelectAllNonZeroCountsForUser for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
 		}
 		delete(wantHighlights, gotRoomID)
 		delete(wantNotifs, gotRoomID)
 	}))
 
 	if len(wantHighlights) != 0 {
-		t.Errorf("SelectAllNonZeroCounts missed highlight rooms: %+v", wantHighlights)
+		t.Errorf("SelectAllNonZeroCountsForUser missed highlight rooms: %+v", wantHighlights)
 	}
 	if len(wantNotifs) != 0 {
-		t.Errorf("SelectAllNonZeroCounts missed notif rooms: %+v", wantNotifs)
+		t.Errorf("SelectAllNonZeroCountsForUser missed notif rooms: %+v", wantNotifs)
 	}
 }
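The table change above narrows the unread-count scan from every user to a single user, which is what allows counts to be loaded lazily per connection instead of all at startup. A minimal sketch of driving the new method directly (the connection string and user ID are made up; zero rows is not an error):

	package main

	import (
		"fmt"
		"log"

		"github.com/matrix-org/sync-v3/state"
	)

	func main() {
		// hypothetical DSN; use whatever you pass to NewSync3Handler
		store := state.NewStorage("user=postgres dbname=syncv3 sslmode=disable")
		err := store.UnreadTable.SelectAllNonZeroCountsForUser("@alice:localhost",
			func(roomID string, highlightCount, notificationCount int) {
				fmt.Printf("room %s: %d highlights, %d notifications\n", roomID, highlightCount, notificationCount)
			})
		if err != nil {
			log.Fatal(err)
		}
	}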
diff --git a/sync3/conn.go b/sync3/conn.go
index d5793cf..5752e80 100644
--- a/sync3/conn.go
+++ b/sync3/conn.go
@@ -55,14 +55,6 @@ func (c *Conn) PushNewEvent(eventData *EventData) {
 	c.connState.PushNewEvent(eventData)
 }
 
-func (c *Conn) PushUserRoomData(userID, roomID string, data userRoomData, timestamp int64) {
-	c.connState.PushNewEvent(&EventData{
-		roomID:       roomID,
-		userRoomData: &data,
-		timestamp:    timestamp,
-	})
-}
-
 // OnIncomingRequest advances the clients position in the stream, returning the response position and data.
 func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
 	c.mu.Lock()
diff --git a/sync3/connmap.go b/sync3/connmap.go
index f49b70f..39b7989 100644
--- a/sync3/connmap.go
+++ b/sync3/connmap.go
@@ -18,11 +18,12 @@ type EventData struct {
 	stateKey  *string
 	content   gjson.Result
 	timestamp int64
+
+	// TODO: remove or factor out
+	userRoomData *UserRoomData
 	// the absolute latest position for this event data. The NID for this event is guaranteed to
 	// be <= this value.
 	latestPos int64
-
-	userRoomData *userRoomData
 }
 
 // ConnMap stores a collection of Conns along with other global server-wide state e.g the in-memory
@@ -46,24 +47,18 @@ type ConnMap struct {
 	globalRoomInfo map[string]*SortableRoom
 	mu             *sync.Mutex
 
-	// inserts are done by v2 poll loops, selects are done by v3 request threads
-	// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
-	// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
-	perUserPerRoomData *sync.Map // map[string]userRoomData
-
 	store *state.Storage
 }
 
 func NewConnMap(store *state.Storage) *ConnMap {
 	cm := &ConnMap{
-		userIDToConn:       make(map[string][]*Conn),
-		connIDToConn:       make(map[string]*Conn),
-		cache:              ttlcache.NewCache(),
-		mu:                 &sync.Mutex{},
-		jrt:                NewJoinedRoomsTracker(),
-		store:              store,
-		globalRoomInfo:     make(map[string]*SortableRoom),
-		perUserPerRoomData: &sync.Map{},
+		userIDToConn:   make(map[string][]*Conn),
+		connIDToConn:   make(map[string]*Conn),
+		cache:          ttlcache.NewCache(),
+		mu:             &sync.Mutex{},
+		jrt:            NewJoinedRoomsTracker(),
+		store:          store,
+		globalRoomInfo: make(map[string]*SortableRoom),
 	}
 	cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
 	cm.cache.SetExpirationCallback(cm.closeConn)
@@ -80,7 +75,7 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
 }
 
 // Atomically gets or creates a connection with this connection ID.
-func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
+func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string, userCache *UserCache) (*Conn, bool) {
 	// atomically check if a conn exists already and return that if so
 	m.mu.Lock()
 	defer m.mu.Unlock()
@@ -88,7 +83,7 @@ func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
 	if conn != nil {
 		return conn, false
 	}
-	state := NewConnState(userID, m)
+	state := NewConnState(userID, userCache, m)
 	conn = NewConn(cid, state, state.HandleIncomingRequest)
 	m.cache.Set(cid.String(), conn)
 	m.connIDToConn[cid.String()] = conn
@@ -147,13 +142,6 @@ func (m *ConnMap) LoadBaseline(roomIDToUserIDs map[string][]string) error {
 			m.jrt.UserJoinedRoom(userID, roomID)
 		}
 	}
-	// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
-	err = m.store.UnreadTable.SelectAllNonZeroCounts(func(roomID, userID string, highlightCount, notificationCount int) {
-		m.OnUnreadCounts(roomID, userID, &highlightCount, &notificationCount)
-	})
-	if err != nil {
-		return fmt.Errorf("failed to load unread counts: %s", err)
-	}
 	return nil
 }
 
@@ -237,48 +225,6 @@ func (m *ConnMap) closeConn(connID string, value interface{}) {
 	}
 }
 
-func (m *ConnMap) LoadUserRoomData(roomID, userID string) userRoomData {
-	key := userID + " " + roomID
-	data, ok := m.perUserPerRoomData.Load(key)
-	if !ok {
-		return userRoomData{}
-	}
-	return data.(userRoomData)
-}
-
-// TODO: Move to cache struct
-func (m *ConnMap) OnUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
-	data := m.LoadUserRoomData(roomID, userID)
-	hasCountDecreased := false
-	if highlightCount != nil {
-		hasCountDecreased = *highlightCount < data.highlightCount
-		data.highlightCount = *highlightCount
-	}
-	if notifCount != nil {
-		if !hasCountDecreased {
-			hasCountDecreased = *notifCount < data.notificationCount
-		}
-		data.notificationCount = *notifCount
-	}
-	key := userID + " " + roomID
-	m.perUserPerRoomData.Store(key, data)
-	if hasCountDecreased {
-		// we will notify the connection for count decreases so the client can update their badge counter.
-		// we don't do this on increases as this should always be associated with an actual event which
-		// we will notify the connection for (and unread counts are processed prior to this). By doing
-		// this we ensure atomic style updates of badge counts and events, rather than getting the badge
-		// count update without a message.
-		m.mu.Lock()
-		conns := m.userIDToConn[userID]
-		m.mu.Unlock()
-		room := m.LoadRoom(roomID)
-		// TODO: don't indirect via conn :S this is dumb
-		for _, conn := range conns {
-			conn.PushUserRoomData(userID, roomID, data, room.LastMessageTimestamp)
-		}
-	}
-}
-
 // TODO: Move to cache struct
 // Call this when there is a new event received on a v2 stream.
 // This event must be globally unique, i.e indicated so by the state store.
diff --git a/sync3/connstate.go b/sync3/connstate.go
index 6aff0d0..3e912e4 100644
--- a/sync3/connstate.go
+++ b/sync3/connstate.go
@@ -17,7 +17,6 @@ var (
 
 type ConnStateStore interface {
 	LoadRoom(roomID string) *SortableRoom
-	LoadUserRoomData(roomID, userID string) userRoomData
 	LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage
 	Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error)
 }
@@ -36,11 +35,14 @@ type ConnState struct {
 	// Consumed when the conn is read. There is a limit to how many updates we will store before
 	// saying the client is ded and cleaning up the conn.
 	updateEvents chan *EventData
+
+	userCache *UserCache
 }
 
-func NewConnState(userID string, store ConnStateStore) *ConnState {
+func NewConnState(userID string, userCache *UserCache, store ConnStateStore) *ConnState {
 	return &ConnState{
 		store:                      store,
+		userCache:                  userCache,
 		userID:                     userID,
 		roomSubscriptions:          make(map[string]RoomSubscription),
 		sortedJoinedRoomsPositions: make(map[string]int),
@@ -60,6 +62,7 @@ func NewConnState(userID string, userCache *UserCache, store ConnStateStore) *C
 // - load() bases its current state based on the latest position, which includes processing of these N events.
 // - post load() we read N events, processing them a 2nd time.
 func (s *ConnState) load(req *Request) error {
+	s.userCache.Subscribe(s)
 	joinedRoomIDs, initialLoadPosition, err := s.store.Load(s.userID)
 	if err != nil {
 		return err
@@ -311,11 +314,11 @@ func (s *ConnState) updateRoomSubscriptions(subs, unsubs []string) map[string]Ro
 }
 
 func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
-	userRoomData := s.store.LoadUserRoomData(updateEvent.roomID, s.userID)
+	userRoomData := s.userCache.loadRoomData(updateEvent.roomID)
 	room := &Room{
 		RoomID:            updateEvent.roomID,
-		NotificationCount: int64(userRoomData.notificationCount),
-		HighlightCount:    int64(userRoomData.highlightCount),
+		NotificationCount: int64(userRoomData.NotificationCount),
+		HighlightCount:    int64(userRoomData.HighlightCount),
 	}
 	if updateEvent.event != nil {
 		room.Timeline = []json.RawMessage{
@@ -327,12 +330,12 @@ func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
 
 func (s *ConnState) getInitialRoomData(roomID string) *Room {
 	r := s.store.LoadRoom(roomID)
-	userRoomData := s.store.LoadUserRoomData(roomID, s.userID)
+	userRoomData := s.userCache.loadRoomData(roomID)
 	return &Room{
 		RoomID:            roomID,
 		Name:              r.Name,
-		NotificationCount: int64(userRoomData.notificationCount),
-		HighlightCount:    int64(userRoomData.highlightCount),
+		NotificationCount: int64(userRoomData.NotificationCount),
+		HighlightCount:    int64(userRoomData.HighlightCount),
 		// TODO: timeline limits
 		Timeline: []json.RawMessage{
 			r.LastEventJSON,
@@ -395,3 +398,15 @@ func (s *ConnState) moveRoom(updateEvent *EventData, fromIndex, toIndex int, ran
 	}
 
 }
+
+func (s *ConnState) OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool) {
+	if !hasCountDecreased {
+		return
+	}
+	room := s.store.LoadRoom(roomID)
+	s.PushNewEvent(&EventData{
+		roomID:       roomID,
+		userRoomData: &urd,
+		timestamp:    room.LastMessageTimestamp,
+	})
+}
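One behavioural note on OnUnreadCountsChanged: it keeps the decrease-only rule from the block deleted in connmap.go. Count increases always accompany an event the connection is woken for anyway (unread counts are processed before the event), so pushing only on decreases keeps badge updates atomic with their messages while still clearing badges when the user reads the room on another device. A usage sketch showing both cases follows the usercache.go diff at the end.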
diff --git a/sync3/connstate_test.go b/sync3/connstate_test.go
index 15a240e..baa34c1 100644
--- a/sync3/connstate_test.go
+++ b/sync3/connstate_test.go
@@ -42,9 +42,6 @@ func (s *connStateStoreMock) Load(userID string) (joinedRoomIDs []string, initia
 func (s *connStateStoreMock) LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage {
 	return nil
 }
-func (s *connStateStoreMock) LoadUserRoomData(roomID, userID string) userRoomData {
-	return userRoomData{}
-}
 func (s *connStateStoreMock) PushNewEvent(cs *ConnState, ed *EventData) {
 	room := s.roomIDToRoom[ed.roomID]
 	room.LastEventJSON = ed.event
@@ -80,7 +77,7 @@ func TestConnStateInitial(t *testing.T) {
 			roomC.RoomID: roomC,
 		},
 	}
-	cs := NewConnState(userID, csm)
+	cs := NewConnState(userID, NewUserCache(userID), csm)
 	if userID != cs.UserID() {
 		t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
 	}
@@ -217,7 +214,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
 		},
 		roomIDToRoom: roomIDToRoom,
 	}
-	cs := NewConnState(userID, csm)
+	cs := NewConnState(userID, NewUserCache(userID), csm)
 
 	// request first page
 	res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
@@ -384,7 +381,7 @@ func TestBumpToOutsideRange(t *testing.T) {
 			roomD.RoomID: roomD,
 		},
 	}
-	cs := NewConnState(userID, csm)
+	cs := NewConnState(userID, NewUserCache(userID), csm)
 	// Ask for A,B
 	res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
 		Sort: []string{SortByRecency},
@@ -464,7 +461,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
 			roomD.RoomID: roomD,
 		},
 	}
-	cs := NewConnState(userID, csm)
+	cs := NewConnState(userID, NewUserCache(userID), csm)
 	// subscribe to room D
 	res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
 		Sort: []string{SortByRecency},
diff --git a/sync3/handler.go b/sync3/handler.go
index bfc59e4..2bd58a5 100644
--- a/sync3/handler.go
+++ b/sync3/handler.go
@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"os"
 	"strconv"
+	"sync"
 
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/sync-v3/internal"
@@ -31,13 +32,19 @@ type SyncLiveHandler struct {
 	V2Store   *sync2.Storage
 	PollerMap *sync2.PollerMap
 	ConnMap   *ConnMap
+
+	// inserts are done by v2 poll loops, selects are done by v3 request threads
+	// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
+	// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
+	userCaches *sync.Map // map[user_id]*UserCache
 }
 
 func NewSync3Handler(v2Client sync2.Client, postgresDBURI string) (*SyncLiveHandler, error) {
 	sh := &SyncLiveHandler{
-		V2:      v2Client,
-		Storage: state.NewStorage(postgresDBURI),
-		V2Store: sync2.NewStore(postgresDBURI),
+		V2:         v2Client,
+		Storage:    state.NewStorage(postgresDBURI),
+		V2Store:    sync2.NewStore(postgresDBURI),
+		userCaches: &sync.Map{},
 	}
 	sh.PollerMap = sync2.NewPollerMap(v2Client, sh)
 	sh.ConnMap = NewConnMap(sh.Storage)
@@ -198,6 +205,15 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
 		hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
 	)
 
+	userCache, err := h.userCache(v2device.UserID)
+	if err != nil {
+		log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
+		return nil, &internal.HandlerError{
+			StatusCode: 500,
+			Err:        err,
+		}
+	}
+
 	// Now the v2 side of things are running, we can make a v3 live sync conn
 	// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
 	// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
@@ -206,7 +222,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
 	conn, created := h.ConnMap.GetOrCreateConn(ConnID{
 		SessionID: syncReq.SessionID,
 		DeviceID:  deviceID,
-	}, v2device.UserID)
+	}, v2device.UserID, userCache)
 	if created {
 		log.Info().Str("conn_id", conn.ConnID.String()).Msg("created new connection")
 	} else {
@@ -215,6 +231,23 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
 	return conn, nil
 }
 
+func (h *SyncLiveHandler) userCache(userID string) (*UserCache, error) {
+	c, ok := h.userCaches.Load(userID)
+	if ok {
+		return c.(*UserCache), nil
+	}
+	uc := NewUserCache(userID)
+	// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
+	err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
+		uc.OnUnreadCounts(roomID, &highlightCount, &notificationCount)
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to load unread counts: %s", err)
+	}
+	h.userCaches.Store(userID, uc)
+	return uc, nil
+}
+
 // Called from the v2 poller, implements V2DataReceiver
 func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) error {
 	return h.V2Store.UpdateDeviceSince(deviceID, since)
@@ -270,5 +303,9 @@ func (h *SyncLiveHandler) UpdateUnreadCounts(roomID, userID string, highlightCou
 	if err != nil {
 		logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update unread counters")
 	}
-	h.ConnMap.OnUnreadCounts(roomID, userID, highlightCount, notifCount)
+	userCache, ok := h.userCaches.Load(userID)
+	if !ok {
+		return
+	}
+	userCache.(*UserCache).OnUnreadCounts(roomID, highlightCount, notifCount)
 }
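userCache above is check-then-act on the sync.Map: two concurrent first requests for the same user can each build and populate a cache, with the later Store silently replacing the earlier one. If that window ever matters, sync.Map's LoadOrStore closes it at the cost of an occasionally redundant DB read. A sketch, not part of this diff:

	func (h *SyncLiveHandler) userCache(userID string) (*UserCache, error) {
		if c, ok := h.userCaches.Load(userID); ok {
			return c.(*UserCache), nil
		}
		uc := NewUserCache(userID)
		err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
			uc.OnUnreadCounts(roomID, &highlightCount, &notificationCount)
		})
		if err != nil {
			return nil, fmt.Errorf("failed to load unread counts: %s", err)
		}
		// If another goroutine won the race, use its cache so every caller
		// shares the same *UserCache instance.
		actual, _ := h.userCaches.LoadOrStore(userID, uc)
		return actual.(*UserCache), nil
	}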
diff --git a/sync3/room.go b/sync3/room.go
index 7d4e6e9..acd51fe 100644
--- a/sync3/room.go
+++ b/sync3/room.go
@@ -30,8 +30,3 @@ func (s SortableRooms) Len() int64 {
 func (s SortableRooms) Subslice(i, j int64) Subslicer {
 	return s[i:j]
 }
-
-type userRoomData struct {
-	notificationCount int
-	highlightCount    int
-}
diff --git a/sync3/usercache.go b/sync3/usercache.go
new file mode 100644
index 0000000..cd2aa61
--- /dev/null
+++ b/sync3/usercache.go
@@ -0,0 +1,83 @@
+package sync3
+
+import (
+	"sync"
+)
+
+type UserRoomData struct {
+	NotificationCount int
+	HighlightCount    int
+}
+
+type UserCacheListener interface {
+	OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool)
+}
+
+type UserCache struct {
+	userID       string
+	roomToData   map[string]UserRoomData
+	roomToDataMu *sync.RWMutex
+	listeners    map[int]UserCacheListener
+	listenersMu  *sync.Mutex
+	id           int
+}
+
+func NewUserCache(userID string) *UserCache {
+	return &UserCache{
+		userID:       userID,
+		roomToDataMu: &sync.RWMutex{},
+		roomToData:   make(map[string]UserRoomData),
+		listeners:    make(map[int]UserCacheListener),
+		listenersMu:  &sync.Mutex{},
+	}
+}
+
+func (c *UserCache) Subscribe(ucl UserCacheListener) (id int) {
+	c.listenersMu.Lock()
+	defer c.listenersMu.Unlock()
+	id = c.id
+	c.id += 1
+	c.listeners[id] = ucl
+	return
+}
+
+func (c *UserCache) Unsubscribe(id int) {
+	c.listenersMu.Lock()
+	defer c.listenersMu.Unlock()
+	delete(c.listeners, id)
+}
+
+func (c *UserCache) loadRoomData(roomID string) UserRoomData {
+	c.roomToDataMu.RLock()
+	defer c.roomToDataMu.RUnlock()
+	data, ok := c.roomToData[roomID]
+	if !ok {
+		return UserRoomData{}
+	}
+	return data
+}
+
+// =================================================
+// Listener functions called by v2 pollers are below
+// =================================================
+
+func (c *UserCache) OnUnreadCounts(roomID string, highlightCount, notifCount *int) {
+	data := c.loadRoomData(roomID)
+	hasCountDecreased := false
+	if highlightCount != nil {
+		hasCountDecreased = *highlightCount < data.HighlightCount
+		data.HighlightCount = *highlightCount
+	}
+	if notifCount != nil {
+		if !hasCountDecreased {
+			hasCountDecreased = *notifCount < data.NotificationCount
+		}
+		data.NotificationCount = *notifCount
+	}
+	c.roomToDataMu.Lock()
+	c.roomToData[roomID] = data
+	c.roomToDataMu.Unlock()
+	for _, l := range c.listeners {
+		l.OnUnreadCountsChanged(c.userID, roomID, data, hasCountDecreased)
+	}
+}
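Putting the pieces together: v2 pollers write into UserCache via OnUnreadCounts, each ConnState subscribes itself during load(), and only decreases produce a push. A self-contained sketch of driving the cache directly (a hypothetical example test in package sync3; countLogger and the IDs are made up, with countLogger standing in for ConnState):

	package sync3

	import "fmt"

	// countLogger is a stand-in listener; in production ConnState fills this role.
	type countLogger struct{}

	func (countLogger) OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool) {
		fmt.Printf("%s in %s: %d notifs, %d highlights, decreased=%v\n",
			userID, roomID, urd.NotificationCount, urd.HighlightCount, hasCountDecreased)
	}

	func ExampleUserCache() {
		uc := NewUserCache("@alice:localhost")
		id := uc.Subscribe(countLogger{})
		defer uc.Unsubscribe(id)

		two, zero := 2, 0
		uc.OnUnreadCounts("!room:localhost", &zero, &two) // 0 -> 2 notifs: not a decrease
		uc.OnUnreadCounts("!room:localhost", nil, &zero)  // 2 -> 0 notifs: a decrease
		// Output:
		// @alice:localhost in !room:localhost: 2 notifs, 0 highlights, decreased=false
		// @alice:localhost in !room:localhost: 0 notifs, 0 highlights, decreased=true
	}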