Cancel outstanding requests when destroying conns

This commit is contained in:
David Robertson 2023-10-26 15:58:06 +01:00
parent 72c3415561
commit f3037861a7
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
6 changed files with 28 additions and 9 deletions

View File

@ -35,6 +35,7 @@ type ConnHandler interface {
PublishEventsUpTo(roomID string, nid int64)
Destroy()
Alive() bool
SetCancelCallback(cancel context.CancelFunc)
}
// Conn is an abstraction of a long-poll connection. It automatically handles the position values
@ -245,3 +246,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.T
// return the oldest value
return nextUnACKedResponse, nil
}
func (c *Conn) SetCancelCallback(cancel context.CancelFunc) {
c.handler.SetCancelCallback(cancel)
}

View File

@ -1,6 +1,7 @@
package sync3
import (
"context"
"sync"
"time"
@ -131,7 +132,7 @@ func (m *ConnMap) getConn(cid ConnID) *Conn {
}
// Atomically gets or creates a connection with this connection ID. Calls newConn if a new connection is required.
func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Conn, bool) {
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) (*Conn, bool) {
// atomically check if a conn exists already and nuke it if it exists
m.mu.Lock()
defer m.mu.Unlock()
@ -149,6 +150,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
m.closeConn(conn)
}
h := newConnHandler()
h.SetCancelCallback(cancel)
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn

View File

@ -23,8 +23,9 @@ type ConnState struct {
userID string
deviceID string
// the only thing that can touch these data structures is the conn goroutine
muxedReq *sync3.Request
lists *sync3.InternalRequestLists
muxedReq *sync3.Request
cancelLatestReq context.CancelFunc
lists *sync3.InternalRequestLists
// Confirmed room subscriptions. Entries in this list have been checked for things like
// "is the user joined to this room?" whereas subscriptions in muxedReq are untrusted.
@ -723,6 +724,8 @@ func (s *ConnState) trackProcessDuration(ctx context.Context, dur time.Duration,
// Called when the connection is torn down
func (s *ConnState) Destroy() {
s.userCache.Unsubscribe(s.userCacheID)
logger.Debug().Str("user_id", s.userID).Str("device_id", s.deviceID).Msg("cancelling any in-flight requests")
s.cancelLatestReq()
}
func (s *ConnState) Alive() bool {
@ -764,6 +767,10 @@ func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) {
s.txnIDWaiter.PublishUpToNID(roomID, nid)
}
func (s *ConnState) SetCancelCallback(cancel context.CancelFunc) {
s.cancelLatestReq = cancel
}
// clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges.
//
// Suppose the client asks for a window on positions [10, 19]. If the list

View File

@ -73,7 +73,7 @@ func (s *connStateLive) liveUpdate(
log.Trace().Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
select {
case <-ctx.Done(): // client has given up
log.Trace().Msg("liveUpdate: client gave up")
log.Trace().Msg("liveUpdate: client gave up, or we killed the connection")
internal.Logf(ctx, "liveUpdate", "context cancelled")
return
case <-time.After(timeLeftToWait): // we've timed out

View File

@ -249,7 +249,9 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
}
req, conn, herr := h.setupConnection(req, &requestBody, req.URL.Query().Get("pos") != "")
cancelCtx, cancel := context.WithCancel(req.Context())
req = req.WithContext(cancelCtx)
req, conn, herr := h.setupConnection(req, cancel, &requestBody, req.URL.Query().Get("pos") != "")
if herr != nil {
logErrorOrWarning("failed to get or create Conn", herr)
return herr
@ -326,7 +328,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.CancelFunc, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
ctx, task := internal.StartTask(req.Context(), "setupConnection")
req = req.WithContext(ctx)
defer task.End()
@ -386,6 +388,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
if containsPos {
// Lookup the connection
conn = h.ConnMap.Conn(connID)
conn.SetCancelCallback(cancel)
if conn != nil {
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
return req, conn, nil
@ -434,7 +437,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
conn, created := h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
})
if created {

View File

@ -74,8 +74,10 @@ func TestRequestCancelledWhenItsConnIsDestroyed(t *testing.T) {
t.Log("Alice waits for her second sync to return.")
response := <-done
assertEqual(t, "status code", response.res.StatusCode, http.StatusBadRequest)
assertEqual(t, "response errcode", response.body.Get("errcode").Str, "M_UNKNOWN_POS")
// TODO: At first I expected that this the cancelled request should return 400 M_UNKNOWN_POS.
// But I think that is best handled on the next incoming request.
// assertEqual(t, "status code", response.res.StatusCode, http.StatusBadRequest)
// assertEqual(t, "response errcode", response.body.Get("errcode").Str, "M_UNKNOWN_POS")
if response.duration > cancelWithin {
t.Errorf("Waited for %s, but expected second sync to cancel after at most %s", response.duration, cancelWithin)