mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Add UserCache and move unread count tracking to it
Keep it pure (not dependent on `state.Storage`) to make testing easier. The responsibility for fanning out user cache updates is with the Handler as it generally deals with glue code.
This commit is contained in:
parent
ab359c7ff3
commit
48613956d1
@ -23,9 +23,10 @@ func NewUnreadTable(db *sqlx.DB) *UnreadTable {
|
|||||||
return &UnreadTable{db}
|
return &UnreadTable{db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string, highlightCount, notificationCount int)) error {
|
func (t *UnreadTable) SelectAllNonZeroCountsForUser(userID string, callback func(roomID string, highlightCount, notificationCount int)) error {
|
||||||
rows, err := t.db.Query(
|
rows, err := t.db.Query(
|
||||||
`SELECT user_id, room_id, notification_count, highlight_count FROM syncv3_unread WHERE notification_count > 0 OR highlight_count > 0`,
|
`SELECT room_id, notification_count, highlight_count FROM syncv3_unread WHERE user_id=$1 AND (notification_count > 0 OR highlight_count > 0)`,
|
||||||
|
userID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -33,13 +34,12 @@ func (t *UnreadTable) SelectAllNonZeroCounts(callback func(roomID, userID string
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var roomID string
|
var roomID string
|
||||||
var userID string
|
|
||||||
var highlightCount int
|
var highlightCount int
|
||||||
var notifCount int
|
var notifCount int
|
||||||
if err := rows.Scan(&userID, &roomID, ¬ifCount, &highlightCount); err != nil {
|
if err := rows.Scan(&roomID, ¬ifCount, &highlightCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
callback(roomID, userID, highlightCount, notifCount)
|
callback(roomID, highlightCount, notifCount)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -44,26 +44,23 @@ func TestUnreadTable(t *testing.T) {
|
|||||||
roomA: 1,
|
roomA: 1,
|
||||||
roomB: 2,
|
roomB: 2,
|
||||||
}
|
}
|
||||||
assertNoError(t, table.SelectAllNonZeroCounts(func(gotRoomID string, gotUserID string, gotHighlight int, gotNotif int) {
|
assertNoError(t, table.SelectAllNonZeroCountsForUser(userID, func(gotRoomID string, gotHighlight int, gotNotif int) {
|
||||||
if userID != gotUserID {
|
|
||||||
t.Errorf("SelectAllNonZeroCounts: got user %v want %v", gotUserID, userID)
|
|
||||||
}
|
|
||||||
wantHighlight := wantHighlights[gotRoomID]
|
wantHighlight := wantHighlights[gotRoomID]
|
||||||
if wantHighlight != gotHighlight {
|
if wantHighlight != gotHighlight {
|
||||||
t.Errorf("SelectAllNonZeroCounts for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
|
t.Errorf("SelectAllNonZeroCountsForUser for %v got %d highlights, want %d", gotRoomID, gotHighlight, wantHighlight)
|
||||||
}
|
}
|
||||||
wantNotif := wantNotifs[gotRoomID]
|
wantNotif := wantNotifs[gotRoomID]
|
||||||
if wantNotif != gotNotif {
|
if wantNotif != gotNotif {
|
||||||
t.Errorf("SelectAllNonZeroCounts for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
|
t.Errorf("SelectAllNonZeroCountsForUser for %v got %d notifs, want %d", gotRoomID, gotNotif, wantNotif)
|
||||||
}
|
}
|
||||||
delete(wantHighlights, gotRoomID)
|
delete(wantHighlights, gotRoomID)
|
||||||
delete(wantNotifs, gotRoomID)
|
delete(wantNotifs, gotRoomID)
|
||||||
}))
|
}))
|
||||||
if len(wantHighlights) != 0 {
|
if len(wantHighlights) != 0 {
|
||||||
t.Errorf("SelectAllNonZeroCounts missed highlight rooms: %+v", wantHighlights)
|
t.Errorf("SelectAllNonZeroCountsForUser missed highlight rooms: %+v", wantHighlights)
|
||||||
}
|
}
|
||||||
if len(wantNotifs) != 0 {
|
if len(wantNotifs) != 0 {
|
||||||
t.Errorf("SelectAllNonZeroCounts missed notif rooms: %+v", wantNotifs)
|
t.Errorf("SelectAllNonZeroCountsForUser missed notif rooms: %+v", wantNotifs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,14 +55,6 @@ func (c *Conn) PushNewEvent(eventData *EventData) {
|
|||||||
c.connState.PushNewEvent(eventData)
|
c.connState.PushNewEvent(eventData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) PushUserRoomData(userID, roomID string, data userRoomData, timestamp int64) {
|
|
||||||
c.connState.PushNewEvent(&EventData{
|
|
||||||
roomID: roomID,
|
|
||||||
userRoomData: &data,
|
|
||||||
timestamp: timestamp,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnIncomingRequest advances the clients position in the stream, returning the response position and data.
|
// OnIncomingRequest advances the clients position in the stream, returning the response position and data.
|
||||||
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
|
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
@ -18,11 +18,12 @@ type EventData struct {
|
|||||||
stateKey *string
|
stateKey *string
|
||||||
content gjson.Result
|
content gjson.Result
|
||||||
timestamp int64
|
timestamp int64
|
||||||
|
|
||||||
|
// TODO: remove or factor out
|
||||||
|
userRoomData *UserRoomData
|
||||||
// the absolute latest position for this event data. The NID for this event is guaranteed to
|
// the absolute latest position for this event data. The NID for this event is guaranteed to
|
||||||
// be <= this value.
|
// be <= this value.
|
||||||
latestPos int64
|
latestPos int64
|
||||||
|
|
||||||
userRoomData *userRoomData
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnMap stores a collection of Conns along with other global server-wide state e.g the in-memory
|
// ConnMap stores a collection of Conns along with other global server-wide state e.g the in-memory
|
||||||
@ -46,24 +47,18 @@ type ConnMap struct {
|
|||||||
globalRoomInfo map[string]*SortableRoom
|
globalRoomInfo map[string]*SortableRoom
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
|
|
||||||
// inserts are done by v2 poll loops, selects are done by v3 request threads
|
|
||||||
// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
|
|
||||||
// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
|
|
||||||
perUserPerRoomData *sync.Map // map[string]userRoomData
|
|
||||||
|
|
||||||
store *state.Storage
|
store *state.Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnMap(store *state.Storage) *ConnMap {
|
func NewConnMap(store *state.Storage) *ConnMap {
|
||||||
cm := &ConnMap{
|
cm := &ConnMap{
|
||||||
userIDToConn: make(map[string][]*Conn),
|
userIDToConn: make(map[string][]*Conn),
|
||||||
connIDToConn: make(map[string]*Conn),
|
connIDToConn: make(map[string]*Conn),
|
||||||
cache: ttlcache.NewCache(),
|
cache: ttlcache.NewCache(),
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
jrt: NewJoinedRoomsTracker(),
|
jrt: NewJoinedRoomsTracker(),
|
||||||
store: store,
|
store: store,
|
||||||
globalRoomInfo: make(map[string]*SortableRoom),
|
globalRoomInfo: make(map[string]*SortableRoom),
|
||||||
perUserPerRoomData: &sync.Map{},
|
|
||||||
}
|
}
|
||||||
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
|
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
|
||||||
cm.cache.SetExpirationCallback(cm.closeConn)
|
cm.cache.SetExpirationCallback(cm.closeConn)
|
||||||
@ -80,7 +75,7 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Atomically gets or creates a connection with this connection ID.
|
// Atomically gets or creates a connection with this connection ID.
|
||||||
func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
|
func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string, userCache *UserCache) (*Conn, bool) {
|
||||||
// atomically check if a conn exists already and return that if so
|
// atomically check if a conn exists already and return that if so
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@ -88,7 +83,7 @@ func (m *ConnMap) GetOrCreateConn(cid ConnID, userID string) (*Conn, bool) {
|
|||||||
if conn != nil {
|
if conn != nil {
|
||||||
return conn, false
|
return conn, false
|
||||||
}
|
}
|
||||||
state := NewConnState(userID, m)
|
state := NewConnState(userID, userCache, m)
|
||||||
conn = NewConn(cid, state, state.HandleIncomingRequest)
|
conn = NewConn(cid, state, state.HandleIncomingRequest)
|
||||||
m.cache.Set(cid.String(), conn)
|
m.cache.Set(cid.String(), conn)
|
||||||
m.connIDToConn[cid.String()] = conn
|
m.connIDToConn[cid.String()] = conn
|
||||||
@ -147,13 +142,6 @@ func (m *ConnMap) LoadBaseline(roomIDToUserIDs map[string][]string) error {
|
|||||||
m.jrt.UserJoinedRoom(userID, roomID)
|
m.jrt.UserJoinedRoom(userID, roomID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
|
|
||||||
err = m.store.UnreadTable.SelectAllNonZeroCounts(func(roomID, userID string, highlightCount, notificationCount int) {
|
|
||||||
m.OnUnreadCounts(roomID, userID, &highlightCount, ¬ificationCount)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load unread counts: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,48 +225,6 @@ func (m *ConnMap) closeConn(connID string, value interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ConnMap) LoadUserRoomData(roomID, userID string) userRoomData {
|
|
||||||
key := userID + " " + roomID
|
|
||||||
data, ok := m.perUserPerRoomData.Load(key)
|
|
||||||
if !ok {
|
|
||||||
return userRoomData{}
|
|
||||||
}
|
|
||||||
return data.(userRoomData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Move to cache struct
|
|
||||||
func (m *ConnMap) OnUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
|
|
||||||
data := m.LoadUserRoomData(roomID, userID)
|
|
||||||
hasCountDecreased := false
|
|
||||||
if highlightCount != nil {
|
|
||||||
hasCountDecreased = *highlightCount < data.highlightCount
|
|
||||||
data.highlightCount = *highlightCount
|
|
||||||
}
|
|
||||||
if notifCount != nil {
|
|
||||||
if !hasCountDecreased {
|
|
||||||
hasCountDecreased = *notifCount < data.notificationCount
|
|
||||||
}
|
|
||||||
data.notificationCount = *notifCount
|
|
||||||
}
|
|
||||||
key := userID + " " + roomID
|
|
||||||
m.perUserPerRoomData.Store(key, data)
|
|
||||||
if hasCountDecreased {
|
|
||||||
// we will notify the connection for count decreases so the client can update their badge counter.
|
|
||||||
// we don't do this on increases as this should always be associated with an actual event which
|
|
||||||
// we will notify the connection for (and unread counts are processed prior to this). By doing
|
|
||||||
// this we ensure atomic style updates of badge counts and events, rather than getting the badge
|
|
||||||
// count update without a message.
|
|
||||||
m.mu.Lock()
|
|
||||||
conns := m.userIDToConn[userID]
|
|
||||||
m.mu.Unlock()
|
|
||||||
room := m.LoadRoom(roomID)
|
|
||||||
// TODO: don't indirect via conn :S this is dumb
|
|
||||||
for _, conn := range conns {
|
|
||||||
conn.PushUserRoomData(userID, roomID, data, room.LastMessageTimestamp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Move to cache struct
|
// TODO: Move to cache struct
|
||||||
// Call this when there is a new event received on a v2 stream.
|
// Call this when there is a new event received on a v2 stream.
|
||||||
// This event must be globally unique, i.e indicated so by the state store.
|
// This event must be globally unique, i.e indicated so by the state store.
|
||||||
|
@ -17,7 +17,6 @@ var (
|
|||||||
|
|
||||||
type ConnStateStore interface {
|
type ConnStateStore interface {
|
||||||
LoadRoom(roomID string) *SortableRoom
|
LoadRoom(roomID string) *SortableRoom
|
||||||
LoadUserRoomData(roomID, userID string) userRoomData
|
|
||||||
LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage
|
LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage
|
||||||
Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error)
|
Load(userID string) (joinedRoomIDs []string, initialLoadPosition int64, err error)
|
||||||
}
|
}
|
||||||
@ -36,11 +35,14 @@ type ConnState struct {
|
|||||||
// Consumed when the conn is read. There is a limit to how many updates we will store before
|
// Consumed when the conn is read. There is a limit to how many updates we will store before
|
||||||
// saying the client is ded and cleaning up the conn.
|
// saying the client is ded and cleaning up the conn.
|
||||||
updateEvents chan *EventData
|
updateEvents chan *EventData
|
||||||
|
|
||||||
|
userCache *UserCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnState(userID string, store ConnStateStore) *ConnState {
|
func NewConnState(userID string, userCache *UserCache, store ConnStateStore) *ConnState {
|
||||||
return &ConnState{
|
return &ConnState{
|
||||||
store: store,
|
store: store,
|
||||||
|
userCache: userCache,
|
||||||
userID: userID,
|
userID: userID,
|
||||||
roomSubscriptions: make(map[string]RoomSubscription),
|
roomSubscriptions: make(map[string]RoomSubscription),
|
||||||
sortedJoinedRoomsPositions: make(map[string]int),
|
sortedJoinedRoomsPositions: make(map[string]int),
|
||||||
@ -60,6 +62,7 @@ func NewConnState(userID string, store ConnStateStore) *ConnState {
|
|||||||
// - load() bases its current state based on the latest position, which includes processing of these N events.
|
// - load() bases its current state based on the latest position, which includes processing of these N events.
|
||||||
// - post load() we read N events, processing them a 2nd time.
|
// - post load() we read N events, processing them a 2nd time.
|
||||||
func (s *ConnState) load(req *Request) error {
|
func (s *ConnState) load(req *Request) error {
|
||||||
|
s.userCache.Subsribe(s)
|
||||||
joinedRoomIDs, initialLoadPosition, err := s.store.Load(s.userID)
|
joinedRoomIDs, initialLoadPosition, err := s.store.Load(s.userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -311,11 +314,11 @@ func (s *ConnState) updateRoomSubscriptions(subs, unsubs []string) map[string]Ro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
|
func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
|
||||||
userRoomData := s.store.LoadUserRoomData(updateEvent.roomID, s.userID)
|
userRoomData := s.userCache.loadRoomData(updateEvent.roomID)
|
||||||
room := &Room{
|
room := &Room{
|
||||||
RoomID: updateEvent.roomID,
|
RoomID: updateEvent.roomID,
|
||||||
NotificationCount: int64(userRoomData.notificationCount),
|
NotificationCount: int64(userRoomData.NotificationCount),
|
||||||
HighlightCount: int64(userRoomData.highlightCount),
|
HighlightCount: int64(userRoomData.HighlightCount),
|
||||||
}
|
}
|
||||||
if updateEvent.event != nil {
|
if updateEvent.event != nil {
|
||||||
room.Timeline = []json.RawMessage{
|
room.Timeline = []json.RawMessage{
|
||||||
@ -327,12 +330,12 @@ func (s *ConnState) getDeltaRoomData(updateEvent *EventData) *Room {
|
|||||||
|
|
||||||
func (s *ConnState) getInitialRoomData(roomID string) *Room {
|
func (s *ConnState) getInitialRoomData(roomID string) *Room {
|
||||||
r := s.store.LoadRoom(roomID)
|
r := s.store.LoadRoom(roomID)
|
||||||
userRoomData := s.store.LoadUserRoomData(roomID, s.userID)
|
userRoomData := s.userCache.loadRoomData(roomID)
|
||||||
return &Room{
|
return &Room{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
Name: r.Name,
|
Name: r.Name,
|
||||||
NotificationCount: int64(userRoomData.notificationCount),
|
NotificationCount: int64(userRoomData.NotificationCount),
|
||||||
HighlightCount: int64(userRoomData.highlightCount),
|
HighlightCount: int64(userRoomData.HighlightCount),
|
||||||
// TODO: timeline limits
|
// TODO: timeline limits
|
||||||
Timeline: []json.RawMessage{
|
Timeline: []json.RawMessage{
|
||||||
r.LastEventJSON,
|
r.LastEventJSON,
|
||||||
@ -395,3 +398,15 @@ func (s *ConnState) moveRoom(updateEvent *EventData, fromIndex, toIndex int, ran
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ConnState) OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool) {
|
||||||
|
if !hasCountDecreased {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
room := s.store.LoadRoom(roomID)
|
||||||
|
s.PushNewEvent(&EventData{
|
||||||
|
roomID: roomID,
|
||||||
|
userRoomData: &urd,
|
||||||
|
timestamp: room.LastMessageTimestamp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -42,9 +42,6 @@ func (s *connStateStoreMock) Load(userID string) (joinedRoomIDs []string, initia
|
|||||||
func (s *connStateStoreMock) LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage {
|
func (s *connStateStoreMock) LoadState(roomID string, loadPosition int64, requiredState [][2]string) []json.RawMessage {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (s *connStateStoreMock) LoadUserRoomData(roomID, userID string) userRoomData {
|
|
||||||
return userRoomData{}
|
|
||||||
}
|
|
||||||
func (s *connStateStoreMock) PushNewEvent(cs *ConnState, ed *EventData) {
|
func (s *connStateStoreMock) PushNewEvent(cs *ConnState, ed *EventData) {
|
||||||
room := s.roomIDToRoom[ed.roomID]
|
room := s.roomIDToRoom[ed.roomID]
|
||||||
room.LastEventJSON = ed.event
|
room.LastEventJSON = ed.event
|
||||||
@ -80,7 +77,7 @@ func TestConnStateInitial(t *testing.T) {
|
|||||||
roomC.RoomID: roomC,
|
roomC.RoomID: roomC,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cs := NewConnState(userID, csm)
|
cs := NewConnState(userID, NewUserCache(userID), csm)
|
||||||
if userID != cs.UserID() {
|
if userID != cs.UserID() {
|
||||||
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
|
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
|
||||||
}
|
}
|
||||||
@ -217,7 +214,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
|||||||
},
|
},
|
||||||
roomIDToRoom: roomIDToRoom,
|
roomIDToRoom: roomIDToRoom,
|
||||||
}
|
}
|
||||||
cs := NewConnState(userID, csm)
|
cs := NewConnState(userID, NewUserCache(userID), csm)
|
||||||
|
|
||||||
// request first page
|
// request first page
|
||||||
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
||||||
@ -384,7 +381,7 @@ func TestBumpToOutsideRange(t *testing.T) {
|
|||||||
roomD.RoomID: roomD,
|
roomD.RoomID: roomD,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cs := NewConnState(userID, csm)
|
cs := NewConnState(userID, NewUserCache(userID), csm)
|
||||||
// Ask for A,B
|
// Ask for A,B
|
||||||
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
||||||
Sort: []string{SortByRecency},
|
Sort: []string{SortByRecency},
|
||||||
@ -464,7 +461,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
|
|||||||
roomD.RoomID: roomD,
|
roomD.RoomID: roomD,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cs := NewConnState(userID, csm)
|
cs := NewConnState(userID, NewUserCache(userID), csm)
|
||||||
// subscribe to room D
|
// subscribe to room D
|
||||||
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
res, err := cs.HandleIncomingRequest(context.Background(), connID, &Request{
|
||||||
Sort: []string{SortByRecency},
|
Sort: []string{SortByRecency},
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/sync-v3/internal"
|
"github.com/matrix-org/sync-v3/internal"
|
||||||
@ -31,13 +32,19 @@ type SyncLiveHandler struct {
|
|||||||
V2Store *sync2.Storage
|
V2Store *sync2.Storage
|
||||||
PollerMap *sync2.PollerMap
|
PollerMap *sync2.PollerMap
|
||||||
ConnMap *ConnMap
|
ConnMap *ConnMap
|
||||||
|
|
||||||
|
// inserts are done by v2 poll loops, selects are done by v3 request threads
|
||||||
|
// but the v3 requests touch non-overlapping keys, which is a good use case for sync.Map
|
||||||
|
// > (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys.
|
||||||
|
userCaches *sync.Map // map[user_id]*UserCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSync3Handler(v2Client sync2.Client, postgresDBURI string) (*SyncLiveHandler, error) {
|
func NewSync3Handler(v2Client sync2.Client, postgresDBURI string) (*SyncLiveHandler, error) {
|
||||||
sh := &SyncLiveHandler{
|
sh := &SyncLiveHandler{
|
||||||
V2: v2Client,
|
V2: v2Client,
|
||||||
Storage: state.NewStorage(postgresDBURI),
|
Storage: state.NewStorage(postgresDBURI),
|
||||||
V2Store: sync2.NewStore(postgresDBURI),
|
V2Store: sync2.NewStore(postgresDBURI),
|
||||||
|
userCaches: &sync.Map{},
|
||||||
}
|
}
|
||||||
sh.PollerMap = sync2.NewPollerMap(v2Client, sh)
|
sh.PollerMap = sync2.NewPollerMap(v2Client, sh)
|
||||||
sh.ConnMap = NewConnMap(sh.Storage)
|
sh.ConnMap = NewConnMap(sh.Storage)
|
||||||
@ -198,6 +205,15 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
|
|||||||
hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
|
hlog.FromRequest(req).With().Str("user_id", v2device.UserID).Logger(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
userCache, err := h.userCache(v2device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
|
||||||
|
return nil, &internal.HandlerError{
|
||||||
|
StatusCode: 500,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Now the v2 side of things are running, we can make a v3 live sync conn
|
// Now the v2 side of things are running, we can make a v3 live sync conn
|
||||||
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
|
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
|
||||||
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
||||||
@ -206,7 +222,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
|
|||||||
conn, created := h.ConnMap.GetOrCreateConn(ConnID{
|
conn, created := h.ConnMap.GetOrCreateConn(ConnID{
|
||||||
SessionID: syncReq.SessionID,
|
SessionID: syncReq.SessionID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
}, v2device.UserID)
|
}, v2device.UserID, userCache)
|
||||||
if created {
|
if created {
|
||||||
log.Info().Str("conn_id", conn.ConnID.String()).Msg("created new connection")
|
log.Info().Str("conn_id", conn.ConnID.String()).Msg("created new connection")
|
||||||
} else {
|
} else {
|
||||||
@ -215,6 +231,23 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *Request, c
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *SyncLiveHandler) userCache(userID string) (*UserCache, error) {
|
||||||
|
c, ok := h.userCaches.Load(userID)
|
||||||
|
if ok {
|
||||||
|
return c.(*UserCache), nil
|
||||||
|
}
|
||||||
|
uc := NewUserCache(userID)
|
||||||
|
// select all non-zero highlight or notif counts and set them, as this is less costly than looping every room/user pair
|
||||||
|
err := h.Storage.UnreadTable.SelectAllNonZeroCountsForUser(userID, func(roomID string, highlightCount, notificationCount int) {
|
||||||
|
uc.OnUnreadCounts(roomID, &highlightCount, ¬ificationCount)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load unread counts: %s", err)
|
||||||
|
}
|
||||||
|
h.userCaches.Store(userID, uc)
|
||||||
|
return uc, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Called from the v2 poller, implements V2DataReceiver
|
// Called from the v2 poller, implements V2DataReceiver
|
||||||
func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) error {
|
func (h *SyncLiveHandler) UpdateDeviceSince(deviceID, since string) error {
|
||||||
return h.V2Store.UpdateDeviceSince(deviceID, since)
|
return h.V2Store.UpdateDeviceSince(deviceID, since)
|
||||||
@ -270,5 +303,9 @@ func (h *SyncLiveHandler) UpdateUnreadCounts(roomID, userID string, highlightCou
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update unread counters")
|
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update unread counters")
|
||||||
}
|
}
|
||||||
h.ConnMap.OnUnreadCounts(roomID, userID, highlightCount, notifCount)
|
userCache, ok := h.userCaches.Load(userID)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userCache.(*UserCache).OnUnreadCounts(roomID, highlightCount, notifCount)
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,3 @@ func (s SortableRooms) Len() int64 {
|
|||||||
func (s SortableRooms) Subslice(i, j int64) Subslicer {
|
func (s SortableRooms) Subslice(i, j int64) Subslicer {
|
||||||
return s[i:j]
|
return s[i:j]
|
||||||
}
|
}
|
||||||
|
|
||||||
type userRoomData struct {
|
|
||||||
notificationCount int
|
|
||||||
highlightCount int
|
|
||||||
}
|
|
||||||
|
83
sync3/usercache.go
Normal file
83
sync3/usercache.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package sync3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserRoomData struct {
|
||||||
|
NotificationCount int
|
||||||
|
HighlightCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserCacheListener interface {
|
||||||
|
OnUnreadCountsChanged(userID, roomID string, urd UserRoomData, hasCountDecreased bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserCache struct {
|
||||||
|
userID string
|
||||||
|
roomToData map[string]UserRoomData
|
||||||
|
roomToDataMu *sync.RWMutex
|
||||||
|
listeners map[int]UserCacheListener
|
||||||
|
listenersMu *sync.Mutex
|
||||||
|
id int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserCache(userID string) *UserCache {
|
||||||
|
return &UserCache{
|
||||||
|
userID: userID,
|
||||||
|
roomToDataMu: &sync.RWMutex{},
|
||||||
|
roomToData: make(map[string]UserRoomData),
|
||||||
|
listeners: make(map[int]UserCacheListener),
|
||||||
|
listenersMu: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserCache) Subsribe(ucl UserCacheListener) (id int) {
|
||||||
|
c.listenersMu.Lock()
|
||||||
|
defer c.listenersMu.Unlock()
|
||||||
|
id = c.id
|
||||||
|
c.id += 1
|
||||||
|
c.listeners[id] = ucl
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserCache) Unsubscribe(id int) {
|
||||||
|
c.listenersMu.Lock()
|
||||||
|
defer c.listenersMu.Unlock()
|
||||||
|
delete(c.listeners, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserCache) loadRoomData(roomID string) UserRoomData {
|
||||||
|
c.roomToDataMu.RLock()
|
||||||
|
defer c.roomToDataMu.RUnlock()
|
||||||
|
data, ok := c.roomToData[roomID]
|
||||||
|
if !ok {
|
||||||
|
return UserRoomData{}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// =================================================
|
||||||
|
// Listener functions called by v2 pollers are below
|
||||||
|
// =================================================
|
||||||
|
|
||||||
|
func (c *UserCache) OnUnreadCounts(roomID string, highlightCount, notifCount *int) {
|
||||||
|
data := c.loadRoomData(roomID)
|
||||||
|
hasCountDecreased := false
|
||||||
|
if highlightCount != nil {
|
||||||
|
hasCountDecreased = *highlightCount < data.HighlightCount
|
||||||
|
data.HighlightCount = *highlightCount
|
||||||
|
}
|
||||||
|
if notifCount != nil {
|
||||||
|
if !hasCountDecreased {
|
||||||
|
hasCountDecreased = *notifCount < data.NotificationCount
|
||||||
|
}
|
||||||
|
data.NotificationCount = *notifCount
|
||||||
|
}
|
||||||
|
c.roomToDataMu.Lock()
|
||||||
|
c.roomToData[roomID] = data
|
||||||
|
c.roomToDataMu.Unlock()
|
||||||
|
for _, l := range c.listeners {
|
||||||
|
l.OnUnreadCountsChanged(c.userID, roomID, data, hasCountDecreased)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user