Make ConnID hold a UserID

This commit is contained in:
David Robertson 2023-04-28 13:43:45 +01:00
parent adb3ba318a
commit ca8a2d72c4
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
5 changed files with 25 additions and 26 deletions

View File

@ -78,6 +78,7 @@ type V2InitialSyncComplete struct {
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
type V2DeviceData struct {
UserID string
DeviceID string
Pos int64
}
@ -106,6 +107,7 @@ type V2DeviceMessages struct {
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
type V2ExpiredToken struct {
UserID string
DeviceID string
}

View File

@ -160,6 +160,7 @@ func (h *Handler) OnExpiredToken(userID, deviceID string) {
h.Store.DeviceDataTable.DeleteDevice(userID, deviceID)
// also notify v3 side so it can remove the connection from ConnMap
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
UserID: userID,
DeviceID: deviceID,
})
}
@ -201,6 +202,7 @@ func (h *Handler) OnE2EEData(userID, deviceID string, otkCounts map[string]int,
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
UserID: userID,
DeviceID: deviceID,
Pos: nextPos,
})

View File

@ -11,11 +11,12 @@ import (
)
type ConnID struct {
UserID string
DeviceID string
}
func (c *ConnID) String() string {
return c.DeviceID
return c.UserID + "|" + c.DeviceID
}
type ConnHandler interface {
@ -24,7 +25,6 @@ type ConnHandler interface {
// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
OnUpdate(ctx context.Context, update caches.Update)
UserID() string
Destroy()
Alive() bool
}
@ -33,7 +33,7 @@ type ConnHandler interface {
// of the /sync request, including sending cached data in the event of retries. It does not handle
// the contents of the data at all.
type Conn struct {
ConnID ConnID
ConnID
handler ConnHandler
@ -65,10 +65,6 @@ func NewConn(connID ConnID, h ConnHandler) *Conn {
}
}
func (c *Conn) UserID() string {
return c.handler.UserID()
}
func (c *Conn) Alive() bool {
return c.handler.Alive()
}
@ -105,7 +101,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
}
ctx, task := internal.StartTask(ctx, taskType)
defer task.End()
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.handler.UserID(), c.ConnID.DeviceID, req.pos)
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
}
@ -164,7 +160,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
c.serverResponses = c.serverResponses[delIndex+1:] // slice out the first delIndex+1 elements
defer func() {
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.handler.UserID())
l := logger.Trace().Int("num_res_acks", delIndex+1).Bool("is_retransmit", isRetransmit).Bool("is_first", isFirstRequest).Bool("is_same", isSameRequest).Int64("pos", req.pos).Str("user", c.UserID)
if nextUnACKedResponse != nil {
l.Int64("new_pos", nextUnACKedResponse.PosInt())
}

View File

@ -71,7 +71,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
m.userIDToConn[h.UserID()] = append(m.userIDToConn[h.UserID()], conn)
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
return conn, true
}
@ -94,20 +94,20 @@ func (m *ConnMap) closeConn(conn *Conn) {
return
}
connID := conn.ConnID.String()
logger.Trace().Str("conn", connID).Msg("closing connection")
connKey := conn.ConnID.String()
logger.Trace().Str("conn", connKey).Msg("closing connection")
// remove conn from all the maps
delete(m.connIDToConn, connID)
delete(m.connIDToConn, connKey)
h := conn.handler
conns := m.userIDToConn[h.UserID()]
conns := m.userIDToConn[conn.UserID]
for i := 0; i < len(conns); i++ {
if conns[i].ConnID.String() == connID {
if conns[i].DeviceID == conn.DeviceID {
// delete without preserving order
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
}
}
m.userIDToConn[h.UserID()] = conns
m.userIDToConn[conn.UserID] = conns
// remove user cache listeners etc
h.Destroy()
}

View File

@ -220,8 +220,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
return herr
}
requestBody.SetPos(cpos)
internal.SetRequestContextUserID(req.Context(), conn.UserID())
log := hlog.FromRequest(req).With().Str("user", conn.UserID()).Int64("pos", cpos).Logger()
internal.SetRequestContextUserID(req.Context(), conn.UserID)
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
var timeout int
if req.URL.Query().Get("timeout") == "" {
@ -320,7 +320,10 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
log.Warn().Msg("Unable to update last seen timestamp")
}
connID := connIDFromToken(token)
connID := sync3.ConnID{
UserID: token.UserID,
DeviceID: token.DeviceID,
}
// client thinks they have a connection
if containsPos {
// Lookup the connection
@ -375,13 +378,6 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
return conn, nil
}
func connIDFromToken(token *sync2.Token) sync3.ConnID {
return sync3.ConnID{
// TODO: change ConnID to be a (user, device) ID pair
DeviceID: token.DeviceID,
}
}
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string) (*sync2.Token, *internal.HandlerError) {
// We don't recognise the given accessToken. Ask the homeserver who owns it.
userID, deviceID, err := h.V2.WhoAmI(accessToken)
@ -595,6 +591,7 @@ func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
@ -607,6 +604,7 @@ func (h *SyncLiveHandler) OnDeviceMessages(p *pubsub.V2DeviceMessages) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceMessages")
defer task.End()
conn := h.ConnMap.Conn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
if conn == nil {
@ -703,6 +701,7 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
h.ConnMap.CloseConn(sync3.ConnID{
UserID: p.UserID,
DeviceID: p.DeviceID,
})
}