Add UserCache and move unread count tracking to it

Keep it pure (not dependent on `state.Storage`) to make testing
easier. The responsibility for fanning out user cache updates
lies with the Handler, as it generally deals with glue code.
Kegan Dougal 2021-10-11 16:22:41 +01:00
parent ab359c7ff3
commit 48613956d1
9 changed files with 174 additions and 112 deletions
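For orientation before the per-file diffs: the commit routes unread-count updates through a new per-user cache instead of through ConnMap. A condensed sketch of the resulting flow, pieced together from the changes below (not a literal excerpt):

// v3 request thread (setupConnection):
//   h.userCache(userID)            lazily creates a UserCache, primed once from
//                                  UnreadTable.SelectAllNonZeroCountsForUser
//   ConnMap.GetOrCreateConn(...)   hands that cache to the new ConnState, which
//                                  subscribes to it in load()
//
// v2 poll loop (UpdateUnreadCounts):
//   persist counts, then h.userCaches.Load(userID)
//     -> UserCache.OnUnreadCounts(roomID, highlightCount, notifCount)
//          -> every subscribed UserCacheListener.OnUnreadCountsChanged(...)
//               -> ConnState pushes an update only when a count decreased;
//                  increases ride along with the event that caused them.

type UserCacheListener interface {
    OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool)
}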

View File

@@ -23,9 +23,10 @@ func NewUnreadTable(db *sqlx.DB) *UnreadTable {
return &UnreadTable{db}
}
func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string, highlightCount, notificationCount int)) error {
func (t *UnreadTable) SelectAllNonZeroCountsForUser(userID string, callback func(roomID string, highlightCount, notificationCount int)) error {
rows, err := t.db.Query(
`SELECT user_id, room_id, notification_count, highlight_count FROM syncv3_unread WHERE notification_count > 0 OR highlight_count > 0`,
`SELECT room_id, notification_count, highlight_count FROM syncv3_unread WHERE user_id=$1 AND (notification_count > 0 OR highlight_count > 0)`,
userID,
)
if err != nil {
return err
@@ -33,13 +34,12 @@ func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string
defer rows.Close()
for rows.Next() {
var roomID string
var userID string
var highlightCount int
var notifCount int
if err := rows.Scan(&userID, &roomID, &notifCount, &highlightCount); err != nil {
if err := rows.Scan(&roomID, &notifCount, &highlightCount); err != nil {
return err
}
callback(roomID, userID, highlightCount, notifCount)
callback(roomID, highlightCount, notifCount)
}
return nil
}

View File

@@ -44,26 +44,23 @@ func TestUnreadTable(t *testing.T) {
roomA: 1,
roomB: 2,
}
assertNoError(t, table.SelectAllNonZeroCounts(func(gotRoomID string, gotUserID string, gotHighlight int, gotNotif int) {
if userID != gotUserID {
t.Errorf("SelectAllNonZeroCounts: got user %v want %v", gotUserID, userID)
}
assertNoError(t, table.SelectAllNonZeroCountsForUser(userID, func(gotRoomID string, gotHighlight int, gotNotif int) {
wantHighlight := wantHighlights[gotRoomID]
if wantHighlight != gotHighlight {
t.Errorf("SelectAllNonZeroCounts for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
t.Errorf("SelectAllNonZeroCountsForUser for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
}
wantNotif := wantNotifs[gotRoomID]
if wantNotif != gotNotif {
t.Errorf("SelectAllNonZeroCounts for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
t.Errorf("SelectAllNonZeroCountsForUser for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
}
delete(wantHighlights, gotRoomID)
delete(wantNotifs, gotRoomID)
}))
if len(wantHighlights) != 0 {
t.Errorf("SelectAllNonZeroCounts missed highlight rooms: %+v", wantHighlights)
t.Errorf("SelectAllNonZeroCountsForUser missed highlight rooms: %+v", wantHighlights)
}
if len(wantNotifs) != 0 {
t.Errorf("SelectAllNonZeroCounts missed notif rooms: %+v", wantNotifs)
t.Errorf("SelectAllNonZeroCountsForUser missed notif rooms: %+v", wantNotifs)
}
}

View File

@@ -55,14 +55,6 @@ func (c *Conn) PushNewEvent(eventData *EventData) {
c.connState.PushNewEvent(eventData)
}
func (c *Conn) PushUserRoomData(userID, roomID string, data userRoomData, timestamp int64) {
c.connState.PushNewEvent(&EventData{
roomID: roomID,
userRoomData: &data,
timestamp: timestamp,
})
}
// OnIncomingRequest advances the clients position in the stream, returning the response position and data.
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
c.mu.Lock()

View File

@@ -18,11 +18,12 @@ type EventData struct {
stateKey *string
content gjson.Result
timestamp int64
// TODO: remove or factor out
userRoomData *UserRoomData
// the absolute latest position for this event data. The NID for this event is guaranteed to
// be <= this value.
latestPos int64
userRoomData *userRoomData
}
// ConnMap stores a collection of Conns along with other global server-wide state e.g the in-memory
@@ -46,24 +47,18 @@ type ConnMap struct {
globalRoomInfo map[string]*SortableRoom
mu *sync.Mutex
// inserts are done by v2 poll loops, selects are done by v3 request threads
// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
perUserPerRoomData *sync.Map // map[string]userRoomData
store *state.Storage
}
func NewConnMap(store *state.Storage) *ConnMap {
cm := &ConnMap{
userIDToConn: make(map[string][]*Conn),
connIDToConn: make(map[string]*Conn),
cache: ttlcache.NewCache(),
mu: &sync.Mutex{},
jrt: NewJoinedRoomsTracker(),
store: store,
globalRoomInfo: make(map[string]*SortableRoom),
perUserPerRoomData: &sync.Map{},
userIDToConn: make(map[string][]*Conn),
connIDToConn: make(map[string]*Conn),
cache: ttlcache.NewCache(),
mu: &sync.Mutex{},
jrt: NewJoinedRoomsTracker(),
store: store,
globalRoomInfo: make(map[string]*SortableRoom),
}
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
cm.cache.SetExpirationCallback(cm.closeConn)
@@ -80,7 +75,7 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
}
// Atomically gets or creates a connection with this connection ID.
func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string, userCache *UserCache) (*Conn, bool) {
// atomically check if a conn exists already and return that if so
m.mu.Lock()
defer m.mu.Unlock()
@@ -88,7 +83,7 @@ func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
if conn != nil {
return conn, false
}
state := NewConnState(userID, m)
state := NewConnState(userID, userCache, m)
conn = NewConn(cid, state, state.HandleIncomingRequest)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
@@ -147,13 +142,6 @@ func (m *ConnMap) LoadBaseline(roomIDToUserIDs map[string][]string) error {
m.jrt.UserJoinedRoom(userID, roomID)
}
}
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
err = m.store.UnreadTable.SelectAllNonZeroCounts(func(roomID, userID string, highlightCount, notificationCount int) {
m.OnUnreadCounts(roomID, userID, &highlightCount, &notificationCount)
})
if err != nil {
return fmt.Errorf("failed to load unread counts: %s", err)
}
return nil
}
@@ -237,48 +225,6 @@ func (m *ConnMap) closeConn(connID string, value interface{}) {
}
}
func (m *ConnMap) LoadUserRoomData(roomID, userID string) userRoomData {
key := userID + " " + roomID
data, ok := m.perUserPerRoomData.Load(key)
if !ok {
return userRoomData{}
}
return data.(userRoomData)
}
// TODO: Move to cache struct
func (m *ConnMap) OnUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
data := m.LoadUserRoomData(roomID, userID)
hasCountDecreased := false
if highlightCount != nil {
hasCountDecreased = *highlightCount < data.highlightCount
data.highlightCount = *highlightCount
}
if notifCount != nil {
if !hasCountDecreased {
hasCountDecreased = *notifCount < data.notificationCount
}
data.notificationCount = *notifCount
}
key := userID + " " + roomID
m.perUserPerRoomData.Store(key, data)
if hasCountDecreased {
// we will notify the connection for count decreases so the client can update their badge counter.
// we don't do this on increases as this should always be associated with an actual event which
// we will notify the connection for (and unread counts are processed prior to this). By doing
// this we ensure atomic style updates of badge counts and events, rather than getting the badge
// count update without a message.
m.mu.Lock()
conns := m.userIDToConn[userID]
m.mu.Unlock()
room := m.LoadRoom(roomID)
// TODO: don't indirect via conn :S this is dumb
for _, conn := range conns {
conn.PushUserRoomData(userID, roomID, data, room.LastMessageTimestamp)
}
}
}
// TODO: Move to cache struct
// Call this when there is a new event received on a v2 stream.
// This event must be globally unique, i.e indicated so by the state store.

View File

@@ -17,7 +17,6 @@ var (
type ConnStateStore interface {
LoadRoom(roomID string) *SortableRoom
LoadUserRoomData(roomID, userID string) userRoomData
LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage
Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error)
}
@@ -36,11 +35,14 @@ type ConnState struct {
// Consumed when the conn is read. There is a limit to how many updates we will store before
// saying the client is ded and cleaning up the conn.
updateEvents chan *EventData
userCache *UserCache
}
func NewConnState(userID string, store ConnStateStore) *ConnState {
func NewConnState(userID string, userCache *UserCache, store ConnStateStore) *ConnState {
return &ConnState{
store: store,
userCache: userCache,
userID: userID,
roomSubscriptions: make(map[string]RoomSubscription),
sortedJoinedRoomsPositions: make(map[string]int),
@@ -60,6 +62,7 @@ func NewConnState(userID string, store ConnStateStore) *ConnState {
// - load() bases its current state based on the latest position, which includes processing of these N events.
// - post load() we read N events, processing them a 2nd time.
func (s *ConnState) load(req *Request) error {
s.userCache.Subscribe(s)
joinedRoomIDs, initialLoadPosition, err := s.store.Load(s.userID)
if err != nil {
return err
@@ -311,11 +314,11 @@ func (s *ConnState) updateRoomSubscriptions(subs, unsubs []string) map[string]Ro
}
func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
userRoomData := s.store.LoadUserRoomData(updateEvent.roomID, s.userID)
userRoomData := s.userCache.loadRoomData(updateEvent.roomID)
room := &Room{
RoomID: updateEvent.roomID,
NotificationCount: int64(userRoomData.notificationCount),
HighlightCount: int64(userRoomData.highlightCount),
NotificationCount: int64(userRoomData.NotificationCount),
HighlightCount: int64(userRoomData.HighlightCount),
}
if updateEvent.event != nil {
room.Timeline = []json.RawMessage{
@@ -327,12 +330,12 @@ func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
func (s *ConnState) getInitialRoomData(roomID string) *Room {
r := s.store.LoadRoom(roomID)
userRoomData := s.store.LoadUserRoomData(roomID, s.userID)
userRoomData := s.userCache.loadRoomData(roomID)
return &Room{
RoomID: roomID,
Name: r.Name,
NotificationCount: int64(userRoomData.notificationCount),
HighlightCount: int64(userRoomData.highlightCount),
NotificationCount: int64(userRoomData.NotificationCount),
HighlightCount: int64(userRoomData.HighlightCount),
// TODO: timeline limits
Timeline: []json.RawMessage{
r.LastEventJSON,
@@ -395,3 +398,15 @@ func (s *ConnState) moveRoom(updateEvent *EventData, fromIndex, toIndex int, ran
}
}
func (s *ConnState) OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool) {
if !hasCountDecreased {
return
}
room := s.store.LoadRoom(roomID)
s.PushNewEvent(&EventData{
roomID: roomID,
userRoomData: &urd,
timestamp: room.LastMessageTimestamp,
})
}
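With OnUnreadCountsChanged above and the Subscribe call added to load(), ConnState now satisfies the UserCacheListener interface from sync3/usercache.go. The commit does not add a compile-time assertion for this, but one would look like:

var _ UserCacheListener = (*ConnState)(nil) // ConnState listens for user cache updates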

View File

@@ -42,9 +42,6 @@ func (s *connStateStoreMock) Load(userID string) (joinedRoomIDs []string, initia
func (s *connStateStoreMock) LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage {
return nil
}
func (s *connStateStoreMock) LoadUserRoomData(roomID, userID string) userRoomData {
return userRoomData{}
}
func (s *connStateStoreMock) PushNewEvent(cs *ConnState, ed *EventData) {
room := s.roomIDToRoom[ed.roomID]
room.LastEventJSON = ed.event
@@ -80,7 +77,7 @@ func TestConnStateInitial(t *testing.T) {
roomC.RoomID: roomC,
},
}
cs := NewConnState(userID, csm)
cs := NewConnState(userID, NewUserCache(userID), csm)
if userID != cs.UserID() {
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
}
@@ -217,7 +214,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
},
roomIDToRoom: roomIDToRoom,
}
cs := NewConnState(userID, csm)
cs := NewConnState(userID, NewUserCache(userID), csm)
// request first page
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
@@ -384,7 +381,7 @@ func TestBumpToOutsideRange(t *testing.T) {
roomD.RoomID: roomD,
},
}
cs := NewConnState(userID, csm)
cs := NewConnState(userID, NewUserCache(userID), csm)
// Ask for A,B
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
Sort: []string{SortByRecency},
@@ -464,7 +461,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
roomD.RoomID: roomD,
},
}
cs := NewConnState(userID, csm)
cs := NewConnState(userID, NewUserCache(userID), csm)
// subscribe to room D
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
Sort: []string{SortByRecency},

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"os"
"strconv"
"sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/sync-v3/internal"
@@ -31,13 +32,19 @@ type SyncLiveHandler struct {
V2Store *sync2.Storage
PollerMap *sync2.PollerMap
ConnMap *ConnMap
// inserts are done by v2 poll loops, selects are done by v3 request threads
// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
userCaches *sync.Map // map[user_id]*UserCache
}
func NewSync3Handler(v2Client sync2.Client, postgresDBURI string) (*SyncLiveHandler, error) {
sh := &SyncLiveHandler{
V2: v2Client,
Storage: state.NewStorage(postgresDBURI),
V2Store: sync2.NewStore(postgresDBURI),
V2: v2Client,
Storage: state.NewStorage(postgresDBURI),
V2Store: sync2.NewStore(postgresDBURI),
userCaches: &sync.Map{},
}
sh.PollerMap = sync2.NewPollerMap(v2Client, sh)
sh.ConnMap = NewConnMap(sh.Storage)
@@ -198,6 +205,15 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
)
userCache, err := h.userCache(v2device.UserID)
if err != nil {
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
// Now the v2 side of things are running, we can make a v3 live sync conn
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
@@ -206,7 +222,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
conn, created := h.ConnMap.GetOrCreateConn(ConnID{
SessionID: syncReq.SessionID,
DeviceID: deviceID,
}, v2device.UserID)
}, v2device.UserID, userCache)
if created {
log.Info().Str("conn_id", conn.ConnID.String()).Msg("created new connection")
} else {
@@ -215,6 +231,23 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
return conn, nil
}
func (h *SyncLiveHandler) userCache(userID string) (*UserCache, error) {
c, ok := h.userCaches.Load(userID)
if ok {
return c.(*UserCache), nil
}
uc := NewUserCache(userID)
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
uc.OnUnreadCounts(roomID, &highlightCount, &notificationCount)
})
if err != nil {
return nil, fmt.Errorf("failed to load unread counts: %s", err)
}
h.userCaches.Store(userID, uc)
return uc, nil
}
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) error {
return h.V2Store.UpdateDeviceSince(deviceID, since)
@@ -270,5 +303,9 @@ func (h *SyncLiveHandler) UpdateUnreadCounts(roomID, userID string, highlightCou
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update unread counters")
}
h.ConnMap.OnUnreadCounts(roomID, userID, highlightCount, notifCount)
userCache, ok := h.userCaches.Load(userID)
if !ok {
return
}
userCache.(*UserCache).OnUnreadCounts(roomID, highlightCount, notifCount)
}
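Seen from a v3 request, the handler wiring above amounts to the following fragment (a condensed sketch of setupConnection and userCache, not a literal excerpt; error handling elided):

userCache, err := h.userCache(v2device.UserID) // lazily built; the first call primes counts
                                               // via UnreadTable.SelectAllNonZeroCountsForUser
if err != nil {
    return nil, &internal.HandlerError{StatusCode: 500, Err: err}
}
conn, created := h.ConnMap.GetOrCreateConn(ConnID{
    SessionID: syncReq.SessionID,
    DeviceID:  deviceID,
}, v2device.UserID, userCache)
// The ConnState behind conn subscribes to userCache in load(), so later
// UserCache.OnUnreadCounts calls from the v2 poller reach this connection.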

View File

@@ -30,8 +30,3 @@ func (s SortableRooms) Len() int64 {
func (s SortableRooms) Subslice(i, j int64) Subslicer {
return s[i:j]
}
type userRoomData struct {
notificationCount int
highlightCount int
}

sync3/usercache.go (new file, +83 lines)
View File

@@ -0,0 +1,83 @@
package sync3
import (
"sync"
)
type UserRoomData struct {
NotificationCount int
HighlightCount int
}
type UserCacheListener interface {
OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool)
}
type UserCache struct {
userID string
roomToData map[string]UserRoomData
roomToDataMu *sync.RWMutex
listeners map[int]UserCacheListener
listenersMu *sync.Mutex
id int
}
func NewUserCache(userID string) *UserCache {
return &UserCache{
userID: userID,
roomToDataMu: &sync.RWMutex{},
roomToData: make(map[string]UserRoomData),
listeners: make(map[int]UserCacheListener),
listenersMu: &sync.Mutex{},
}
}
func (c *UserCache) Subscribe(ucl UserCacheListener) (id int) {
c.listenersMu.Lock()
defer c.listenersMu.Unlock()
id = c.id
c.id += 1
c.listeners[id] = ucl
return
}
func (c *UserCache) Unsubscribe(id int) {
c.listenersMu.Lock()
defer c.listenersMu.Unlock()
delete(c.listeners, id)
}
func (c *UserCache) loadRoomData(roomID string) UserRoomData {
c.roomToDataMu.RLock()
defer c.roomToDataMu.RUnlock()
data, ok := c.roomToData[roomID]
if !ok {
return UserRoomData{}
}
return data
}
// =================================================
// Listener functions called by v2 pollers are below
// =================================================
func (c *UserCache) OnUnreadCounts(roomID string, highlightCount, notifCount *int) {
data := c.loadRoomData(roomID)
hasCountDecreased := false
if highlightCount != nil {
hasCountDecreased = *highlightCount < data.HighlightCount
data.HighlightCount = *highlightCount
}
if notifCount != nil {
if !hasCountDecreased {
hasCountDecreased = *notifCount < data.NotificationCount
}
data.NotificationCount = *notifCount
}
c.roomToDataMu.Lock()
c.roomToData[roomID] = data
c.roomToDataMu.Unlock()
for _, l := range c.listeners {
l.OnUnreadCountsChanged(c.userID, roomID, data, hasCountDecreased)
}
}
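Because UserCache has no storage dependency, it can be exercised directly; a minimal test-style sketch (the listener type and values here are hypothetical and not part of the commit, assuming the code sits alongside the sync3 package):

type recordingListener struct {
    decreases []bool
}

func (r *recordingListener) OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool) {
    r.decreases = append(r.decreases, hasCountDecreased)
}

func demoUserCache() {
    uc := NewUserCache("@alice:localhost")
    listener := &recordingListener{}
    id := uc.Subscribe(listener)
    defer uc.Unsubscribe(id)

    two, zero := 2, 0
    uc.OnUnreadCounts("!room:localhost", nil, &two)  // 0 -> 2 notifications: not a decrease
    uc.OnUnreadCounts("!room:localhost", nil, &zero) // 2 -> 0 notifications: hasCountDecreased=true
    // listener.decreases is now [false, true]; only the second update would be
    // pushed to clients by ConnState.OnUnreadCountsChanged.
}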