mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Cancel outstanding requests when destroying conns
This commit is contained in:
parent
72c3415561
commit
f3037861a7
@ -35,6 +35,7 @@ type ConnHandler interface {
|
||||
PublishEventsUpTo(roomID string, nid int64)
|
||||
Destroy()
|
||||
Alive() bool
|
||||
SetCancelCallback(cancel context.CancelFunc)
|
||||
}
|
||||
|
||||
// Conn is an abstraction of a long-poll connection. It automatically handles the position values
|
||||
@ -245,3 +246,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.T
|
||||
// return the oldest value
|
||||
return nextUnACKedResponse, nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetCancelCallback(cancel context.CancelFunc) {
|
||||
c.handler.SetCancelCallback(cancel)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sync3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -131,7 +132,7 @@ func (m *ConnMap) getConn(cid ConnID) *Conn {
|
||||
}
|
||||
|
||||
// Atomically gets or creates a connection with this connection ID. Calls newConn if a new connection is required.
|
||||
func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Conn, bool) {
|
||||
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) (*Conn, bool) {
|
||||
// atomically check if a conn exists already and nuke it if it exists
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@ -149,6 +150,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
m.closeConn(conn)
|
||||
}
|
||||
h := newConnHandler()
|
||||
h.SetCancelCallback(cancel)
|
||||
conn = NewConn(cid, h)
|
||||
m.cache.Set(cid.String(), conn)
|
||||
m.connIDToConn[cid.String()] = conn
|
||||
|
@ -23,8 +23,9 @@ type ConnState struct {
|
||||
userID string
|
||||
deviceID string
|
||||
// the only thing that can touch these data structures is the conn goroutine
|
||||
muxedReq *sync3.Request
|
||||
lists *sync3.InternalRequestLists
|
||||
muxedReq *sync3.Request
|
||||
cancelLatestReq context.CancelFunc
|
||||
lists *sync3.InternalRequestLists
|
||||
|
||||
// Confirmed room subscriptions. Entries in this list have been checked for things like
|
||||
// "is the user joined to this room?" whereas subscriptions in muxedReq are untrusted.
|
||||
@ -723,6 +724,8 @@ func (s *ConnState) trackProcessDuration(ctx context.Context, dur time.Duration,
|
||||
// Called when the connection is torn down
|
||||
func (s *ConnState) Destroy() {
|
||||
s.userCache.Unsubscribe(s.userCacheID)
|
||||
logger.Debug().Str("user_id", s.userID).Str("device_id", s.deviceID).Msg("cancelling any in-flight requests")
|
||||
s.cancelLatestReq()
|
||||
}
|
||||
|
||||
func (s *ConnState) Alive() bool {
|
||||
@ -764,6 +767,10 @@ func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) {
|
||||
s.txnIDWaiter.PublishUpToNID(roomID, nid)
|
||||
}
|
||||
|
||||
func (s *ConnState) SetCancelCallback(cancel context.CancelFunc) {
|
||||
s.cancelLatestReq = cancel
|
||||
}
|
||||
|
||||
// clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges.
|
||||
//
|
||||
// Suppose the client asks for a window on positions [10, 19]. If the list
|
||||
|
@ -73,7 +73,7 @@ func (s *connStateLive) liveUpdate(
|
||||
log.Trace().Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
|
||||
select {
|
||||
case <-ctx.Done(): // client has given up
|
||||
log.Trace().Msg("liveUpdate: client gave up")
|
||||
log.Trace().Msg("liveUpdate: client gave up, or we killed the connection")
|
||||
internal.Logf(ctx, "liveUpdate", "context cancelled")
|
||||
return
|
||||
case <-time.After(timeLeftToWait): // we've timed out
|
||||
|
@ -249,7 +249,9 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
}
|
||||
}
|
||||
|
||||
req, conn, herr := h.setupConnection(req, &requestBody, req.URL.Query().Get("pos") != "")
|
||||
cancelCtx, cancel := context.WithCancel(req.Context())
|
||||
req = req.WithContext(cancelCtx)
|
||||
req, conn, herr := h.setupConnection(req, cancel, &requestBody, req.URL.Query().Get("pos") != "")
|
||||
if herr != nil {
|
||||
logErrorOrWarning("failed to get or create Conn", herr)
|
||||
return herr
|
||||
@ -326,7 +328,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
// setupConnection associates this request with an existing connection or makes a new connection.
|
||||
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
|
||||
// When this function returns, the connection is alive and active.
|
||||
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
|
||||
func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.CancelFunc, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
|
||||
ctx, task := internal.StartTask(req.Context(), "setupConnection")
|
||||
req = req.WithContext(ctx)
|
||||
defer task.End()
|
||||
@ -386,6 +388,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
if containsPos {
|
||||
// Lookup the connection
|
||||
conn = h.ConnMap.Conn(connID)
|
||||
conn.SetCancelCallback(cancel)
|
||||
if conn != nil {
|
||||
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
|
||||
return req, conn, nil
|
||||
@ -434,7 +437,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
||||
// to check for an existing connection though, as it's possible for the client to call /sync
|
||||
// twice for a new connection.
|
||||
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
|
||||
conn, created := h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
|
||||
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
|
||||
})
|
||||
if created {
|
||||
|
@ -74,8 +74,10 @@ func TestRequestCancelledWhenItsConnIsDestroyed(t *testing.T) {
|
||||
t.Log("Alice waits for her second sync to return.")
|
||||
response := <-done
|
||||
|
||||
assertEqual(t, "status code", response.res.StatusCode, http.StatusBadRequest)
|
||||
assertEqual(t, "response errcode", response.body.Get("errcode").Str, "M_UNKNOWN_POS")
|
||||
// TODO: At first I expected that this the cancelled request should return 400 M_UNKNOWN_POS.
|
||||
// But I think that is best handled on the next incoming request.
|
||||
// assertEqual(t, "status code", response.res.StatusCode, http.StatusBadRequest)
|
||||
// assertEqual(t, "response errcode", response.body.Get("errcode").Str, "M_UNKNOWN_POS")
|
||||
|
||||
if response.duration > cancelWithin {
|
||||
t.Errorf("Waited for %s, but expected second sync to cancel after at most %s", response.duration, cancelWithin)
|
||||
|
Loading…
x
Reference in New Issue
Block a user