mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
bugfix: expire connections when the access token gets invalidated
With regression test. The behaviour is: - Delete the connection, such that incoming requests will end up with M_UNKNOWN_POS - The next request will then return HTTP 401. This has knock-on effects: - We no longer send HTTP 502 if /whoami returns 401, instead we return 401. - When the token is expired (pollers get 401, the device is deleted from the DB).
This commit is contained in:
parent
9ba2ad7dba
commit
6bdef5feba
@ -21,6 +21,7 @@ type V2Listener interface {
|
||||
OnTyping(p *V2Typing)
|
||||
OnReceipt(p *V2Receipt)
|
||||
OnDeviceMessages(p *V2DeviceMessages)
|
||||
OnExpiredToken(p *V2ExpiredToken)
|
||||
}
|
||||
|
||||
type V2Initialise struct {
|
||||
@ -104,6 +105,12 @@ type V2DeviceMessages struct {
|
||||
|
||||
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
|
||||
|
||||
type V2ExpiredToken struct {
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" }
|
||||
|
||||
type V2Sub struct {
|
||||
listener Listener
|
||||
receiver V2Listener
|
||||
@ -144,6 +151,8 @@ func (v *V2Sub) onMessage(p Payload) {
|
||||
v.receiver.OnTyping(pl)
|
||||
case *V2DeviceMessages:
|
||||
v.receiver.OnDeviceMessages(pl)
|
||||
case *V2ExpiredToken:
|
||||
v.receiver.OnExpiredToken(pl)
|
||||
default:
|
||||
logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type")
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
const AccountDataGlobalRoom = ""
|
||||
|
||||
var ProxyVersion = ""
|
||||
var HTTP401 error = fmt.Errorf("HTTP 401")
|
||||
|
||||
type Client interface {
|
||||
WhoAmI(accessToken string) (string, error)
|
||||
@ -28,6 +29,7 @@ type HTTPClient struct {
|
||||
DestinationServer string
|
||||
}
|
||||
|
||||
// Return sync2.HTTP401 if this request returns 401
|
||||
func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
|
||||
req, err := http.NewRequest("GET", v.DestinationServer+"/_matrix/client/r0/account/whoami", nil)
|
||||
if err != nil {
|
||||
@ -40,6 +42,9 @@ func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
if res.StatusCode == 401 {
|
||||
return "", HTTP401
|
||||
}
|
||||
return "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
@ -147,6 +147,14 @@ func (h *Handler) OnTerminated(userID, deviceID string) {
|
||||
h.updateMetrics()
|
||||
}
|
||||
|
||||
func (h *Handler) OnExpiredToken(deviceID string) {
|
||||
h.v2Store.RemoveDevice(deviceID)
|
||||
// also notify v3 side so it can remove the connection from ConnMap
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2ExpiredToken{
|
||||
DeviceID: deviceID,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) addPrometheusMetrics() {
|
||||
h.numPollers = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "sliding_sync",
|
||||
|
@ -40,8 +40,10 @@ type V2DataReceiver interface {
|
||||
OnLeftRoom(userID, roomID string)
|
||||
// Sent when there is a _change_ in E2EE data, not all the time
|
||||
OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
|
||||
// Sent when the upstream homeserver sends back a 401 invalidating the token
|
||||
// Sent when the poll loop terminates
|
||||
OnTerminated(userID, deviceID string)
|
||||
// Sent when the token gets a 401 response
|
||||
OnExpiredToken(deviceID string)
|
||||
}
|
||||
|
||||
// PollerMap is a map of device ID to Poller
|
||||
@ -242,6 +244,10 @@ func (h *PollerMap) OnTerminated(userID, deviceID string) {
|
||||
h.callbacks.OnTerminated(userID, deviceID)
|
||||
}
|
||||
|
||||
func (h *PollerMap) OnExpiredToken(deviceID string) {
|
||||
h.callbacks.OnExpiredToken(deviceID)
|
||||
}
|
||||
|
||||
func (h *PollerMap) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@ -366,6 +372,7 @@ func (p *poller) Poll(since string) {
|
||||
continue
|
||||
} else {
|
||||
p.logger.Warn().Msg("Poller: access token has been invalidated, terminating loop")
|
||||
p.receiver.OnExpiredToken(p.deviceID)
|
||||
p.Terminate()
|
||||
break
|
||||
}
|
||||
|
@ -486,6 +486,7 @@ func (s *mockDataReceiver) OnLeftRoom(userID, roomID string)
|
||||
func (s *mockDataReceiver) OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
||||
}
|
||||
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
|
||||
func (s *mockDataReceiver) OnExpiredToken(deviceID string) {}
|
||||
|
||||
func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error)) (*mockDataReceiver, *mockClient) {
|
||||
client := &mockClient{
|
||||
|
@ -134,6 +134,14 @@ func (s *Storage) AllDevices() (devices []Device, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Storage) RemoveDevice(deviceID string) error {
|
||||
_, err := s.db.Exec(
|
||||
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`, deviceID,
|
||||
)
|
||||
log.Info().Str("device", deviceID).Msg("Deleting device")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
|
||||
var device Device
|
||||
device.AccessToken = accessToken
|
||||
|
@ -51,6 +51,7 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
|
||||
return conn
|
||||
}
|
||||
// e.g buffer exceeded, close it and remove it from the cache
|
||||
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
|
||||
m.closeConn(conn)
|
||||
return nil
|
||||
}
|
||||
@ -63,6 +64,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
conn := m.Conn(cid)
|
||||
if conn != nil {
|
||||
// tear down this connection and fallthrough
|
||||
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to CreateConn called again")
|
||||
m.closeConn(conn)
|
||||
}
|
||||
h := newConnHandler()
|
||||
@ -74,16 +76,15 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
}
|
||||
|
||||
func (m *ConnMap) CloseConn(connID ConnID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn := m.Conn(connID)
|
||||
m.closeConn(conn)
|
||||
logger.Trace().Str("conn", connID.String()).Msg("closing connection due to CloseConn()")
|
||||
m.cache.Remove(connID.String()) // this will fire TTL callbacks which calls closeConn
|
||||
}
|
||||
|
||||
func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn := value.(*Conn)
|
||||
logger.Trace().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
|
||||
m.closeConn(conn)
|
||||
}
|
||||
|
||||
|
@ -313,6 +313,12 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
if v2device.UserID == "" {
|
||||
v2device.UserID, err = h.V2.WhoAmI(accessToken)
|
||||
if err != nil {
|
||||
if err == sync2.HTTP401 {
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 401,
|
||||
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
||||
}
|
||||
}
|
||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
@ -641,6 +647,12 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
|
||||
userCache.(*caches.UserCache).OnAccountData(ctx, data)
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
|
||||
h.ConnMap.CloseConn(sync3.ConnID{
|
||||
DeviceID: p.DeviceID,
|
||||
})
|
||||
}
|
||||
|
||||
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
|
||||
queryPos := u.Query().Get(param)
|
||||
if queryPos != "" {
|
||||
|
@ -623,3 +623,34 @@ func TestSessionExpiryOnBufferFill(t *testing.T) {
|
||||
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredAccessToken(t *testing.T) {
|
||||
pqString := testutils.PrepareDBConnectionString()
|
||||
v2 := runTestV2Server(t)
|
||||
v2.addAccount(alice, aliceToken)
|
||||
v3 := runTestServer(t, v2, pqString)
|
||||
roomID := "!doesnt:matter"
|
||||
res := v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
// now expire the token
|
||||
v2.invalidateToken(aliceToken)
|
||||
// now do another request, this should 400 as it expires the session
|
||||
req := sync3.Request{}
|
||||
req.SetTimeoutMSecs(1)
|
||||
_, body, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, req)
|
||||
if statusCode != 400 {
|
||||
t.Fatalf("got %d want 400 : %v", statusCode, string(body))
|
||||
}
|
||||
// do a fresh request, this should 401
|
||||
req = sync3.Request{}
|
||||
req.SetTimeoutMSecs(1)
|
||||
_, body, statusCode = v3.doV3Request(t, context.Background(), aliceToken, "", req)
|
||||
if statusCode != 401 {
|
||||
t.Fatalf("got %d want 401 : %v", statusCode, string(body))
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user