mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Make ConnID hold a UserID
This commit is contained in:
parent
adb3ba318a
commit
ca8a2d72c4
@ -78,6 +78,7 @@ type V2InitialSyncComplete struct {
|
||||
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
|
||||
|
||||
type V2DeviceData struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
Pos int64
|
||||
}
|
||||
@ -106,6 +107,7 @@ type V2DeviceMessages struct {
|
||||
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
|
||||
|
||||
type V2ExpiredToken struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
|
@ -160,6 +160,7 @@ func (h *Handler) OnExpiredToken(userID, deviceID string) {
|
||||
h.Store.DeviceDataTable.DeleteDevice(userID, deviceID)
|
||||
// also notify v3 side so it can remove the connection from ConnMap
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
})
|
||||
}
|
||||
@ -201,6 +202,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
|
||||
return
|
||||
}
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
Pos: nextPos,
|
||||
})
|
||||
|
@ -11,11 +11,12 @@ import (
|
||||
)
|
||||
|
||||
type ConnID struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
func (c *ConnID) String() string {
|
||||
return c.DeviceID
|
||||
return c.UserID + "|" + c.DeviceID
|
||||
}
|
||||
|
||||
type ConnHandler interface {
|
||||
@ -24,7 +25,6 @@ type ConnHandler interface {
|
||||
// status code to send back.
|
||||
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
|
||||
OnUpdate(ctx context.Context, update caches.Update)
|
||||
UserID() string
|
||||
Destroy()
|
||||
Alive() bool
|
||||
}
|
||||
@ -33,7 +33,7 @@ type ConnHandler interface {
|
||||
// of the /sync request, including sending cached data in the event of retries. It does not handle
|
||||
// the contents of the data at all.
|
||||
type Conn struct {
|
||||
ConnID ConnID
|
||||
ConnID
|
||||
|
||||
handler ConnHandler
|
||||
|
||||
@ -65,10 +65,6 @@ func NewConn(connID ConnID, h ConnHandler) *Conn {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) UserID() string {
|
||||
return c.handler.UserID()
|
||||
}
|
||||
|
||||
func (c *Conn) Alive() bool {
|
||||
return c.handler.Alive()
|
||||
}
|
||||
@ -105,7 +101,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
|
||||
}
|
||||
ctx, task := internal.StartTask(ctx, taskType)
|
||||
defer task.End()
|
||||
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.handler.UserID(), c.ConnID.DeviceID, req.pos)
|
||||
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
|
||||
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
|
||||
}
|
||||
|
||||
@ -164,7 +160,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
|
||||
c.serverResponses = c.serverResponses[delIndex+1:] // slice out the first delIndex+1 elements
|
||||
|
||||
defer func() {
|
||||
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.handler.UserID())
|
||||
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.UserID)
|
||||
if nextUnACKedResponse != nil {
|
||||
l.Int64("new_pos", nextUnACKedResponse.PosInt())
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
conn = NewConn(cid, h)
|
||||
m.cache.Set(cid.String(), conn)
|
||||
m.connIDToConn[cid.String()] = conn
|
||||
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
|
||||
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
|
||||
return conn, true
|
||||
}
|
||||
|
||||
@ -94,20 +94,20 @@ func (m *ConnMap) closeConn(conn *Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
connID := conn.ConnID.String()
|
||||
logger.Trace().Str("conn", connID).Msg("closing connection")
|
||||
connKey := conn.ConnID.String()
|
||||
logger.Trace().Str("conn", connKey).Msg("closing connection")
|
||||
// remove conn from all the maps
|
||||
delete(m.connIDToConn, connID)
|
||||
delete(m.connIDToConn, connKey)
|
||||
h := conn.handler
|
||||
conns := m.userIDToConn[h.UserID()]
|
||||
conns := m.userIDToConn[conn.UserID]
|
||||
for i := 0; i < len(conns); i++ {
|
||||
if conns[i].ConnID.String() == connID {
|
||||
if conns[i].DeviceID == conn.DeviceID {
|
||||
// delete without preserving order
|
||||
conns[i] = conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
}
|
||||
}
|
||||
m.userIDToConn[h.UserID()] = conns
|
||||
m.userIDToConn[conn.UserID] = conns
|
||||
// remove user cache listeners etc
|
||||
h.Destroy()
|
||||
}
|
||||
|
@ -220,8 +220,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
return herr
|
||||
}
|
||||
requestBody.SetPos(cpos)
|
||||
internal.SetRequestContextUserID(req.Context(), conn.UserID())
|
||||
log := hlog.FromRequest(req).With().Str("user", conn.UserID()).Int64("pos", cpos).Logger()
|
||||
internal.SetRequestContextUserID(req.Context(), conn.UserID)
|
||||
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
|
||||
|
||||
var timeout int
|
||||
if req.URL.Query().Get("timeout") == "" {
|
||||
@ -320,7 +320,10 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
log.Warn().Msg("Unable to update last seen timestamp")
|
||||
}
|
||||
|
||||
connID := connIDFromToken(token)
|
||||
connID := sync3.ConnID{
|
||||
UserID: token.UserID,
|
||||
DeviceID: token.DeviceID,
|
||||
}
|
||||
// client thinks they have a connection
|
||||
if containsPos {
|
||||
// Lookup the connection
|
||||
@ -375,13 +378,6 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func connIDFromToken(token *sync2.Token) sync3.ConnID {
|
||||
return sync3.ConnID{
|
||||
// TODO: change ConnID to be a (user, device) ID pair
|
||||
DeviceID: token.DeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string) (*sync2.Token, *internal.HandlerError) {
|
||||
// We don't recognise the given accessToken. Ask the homeserver who owns it.
|
||||
userID, deviceID, err := h.V2.WhoAmI(accessToken)
|
||||
@ -595,6 +591,7 @@ func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
|
||||
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
|
||||
defer task.End()
|
||||
conn := h.ConnMap.Conn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
if conn == nil {
|
||||
@ -607,6 +604,7 @@ func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
|
||||
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
|
||||
defer task.End()
|
||||
conn := h.ConnMap.Conn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
if conn == nil {
|
||||
@ -703,6 +701,7 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
|
||||
|
||||
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
|
||||
h.ConnMap.CloseConn(sync3.ConnID{
|
||||
UserID: p.UserID,
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user