bugfix: expire connections when the access token gets invalidated

Includes a regression test. When an access token is invalidated, the behaviour is:
 - The connection is deleted, so an incoming request on the existing session ends up with M_UNKNOWN_POS.
 - The next (fresh) request then returns HTTP 401.

This has knock-on effects:
 - We no longer send HTTP 502 if /whoami returns 401, instead we return 401.
 - When the token is expired (pollers get a 401), the device is deleted from the DB.
This commit is contained in:
Kegan Dougal 2023-03-01 16:40:15 +00:00
parent 9ba2ad7dba
commit 6bdef5feba
9 changed files with 87 additions and 5 deletions

View File

@ -21,6 +21,7 @@ type V2Listener interface {
OnTyping(p *V2Typing)
OnReceipt(p *V2Receipt)
OnDeviceMessages(p *V2DeviceMessages)
OnExpiredToken(p *V2ExpiredToken)
}
type V2Initialise struct {
@ -104,6 +105,12 @@ type V2DeviceMessages struct {
func (*V2DeviceMessages) Type() string { return "V2DeviceMessages" }
// V2ExpiredToken is the payload announcing that the access token used by a
// device has been invalidated.
type V2ExpiredToken struct {
// DeviceID identifies the device whose token expired.
DeviceID string
}
// Type returns the payload's name, "V2ExpiredToken".
func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" }
type V2Sub struct {
listener Listener
receiver V2Listener
@ -144,6 +151,8 @@ func (v *V2Sub) onMessage(p Payload) {
v.receiver.OnTyping(pl)
case *V2DeviceMessages:
v.receiver.OnDeviceMessages(pl)
case *V2ExpiredToken:
v.receiver.OnExpiredToken(pl)
default:
logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type")
}

View File

@ -15,6 +15,7 @@ import (
const AccountDataGlobalRoom = ""
var ProxyVersion = ""
// HTTP401 is the sentinel error returned by WhoAmI when the /whoami request
// comes back with status 401. Callers compare against this exact value.
var HTTP401 error = fmt.Errorf("HTTP 401")
type Client interface {
WhoAmI(accessToken string) (string, error)
@ -28,6 +29,7 @@ type HTTPClient struct {
DestinationServer string
}
// Return sync2.HTTP401 if this request returns 401
func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
req, err := http.NewRequest("GET", v.DestinationServer+"/_matrix/client/r0/account/whoami", nil)
if err != nil {
@ -40,6 +42,9 @@ func (v *HTTPClient) WhoAmI(accessToken string) (string, error) {
return "", err
}
if res.StatusCode != 200 {
if res.StatusCode == 401 {
return "", HTTP401
}
return "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode)
}
defer res.Body.Close()

View File

@ -147,6 +147,14 @@ func (h *Handler) OnTerminated(userID, deviceID string) {
h.updateMetrics()
}
// OnExpiredToken handles a device whose access token has been invalidated.
// It asks the v2 store to remove the device and then publishes a
// V2ExpiredToken payload on pubsub.ChanV2.
func (h *Handler) OnExpiredToken(deviceID string) {
	// NOTE(review): whatever RemoveDevice returns is discarded here, so a
	// failed delete would go unnoticed — confirm this is intended.
	h.v2Store.RemoveDevice(deviceID)

	// Let the v3 side know so it can drop the matching connection from ConnMap.
	expired := &pubsub.V2ExpiredToken{DeviceID: deviceID}
	h.v2Pub.Notify(pubsub.ChanV2, expired)
}
func (h *Handler) addPrometheusMetrics() {
h.numPollers = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "sliding_sync",

View File

@ -40,8 +40,10 @@ type V2DataReceiver interface {
OnLeftRoom(userID, roomID string)
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the upstream homeserver sends back a 401 invalidating the token
// Sent when the poll loop terminates
OnTerminated(userID, deviceID string)
// Sent when the token gets a 401 response
OnExpiredToken(deviceID string)
}
// PollerMap is a map of device ID to Poller
@ -242,6 +244,10 @@ func (h *PollerMap) OnTerminated(userID, deviceID string) {
h.callbacks.OnTerminated(userID, deviceID)
}
// OnExpiredToken forwards the expired-token notification for deviceID to the
// callbacks held by this PollerMap.
func (h *PollerMap) OnExpiredToken(deviceID string) {
h.callbacks.OnExpiredToken(deviceID)
}
func (h *PollerMap) UpdateUnreadCounts(roomID, userID string, highlightCount, notifCount *int) {
var wg sync.WaitGroup
wg.Add(1)
@ -366,6 +372,7 @@ func (p *poller) Poll(since string) {
continue
} else {
p.logger.Warn().Msg("Poller: access token has been invalidated, terminating loop")
p.receiver.OnExpiredToken(p.deviceID)
p.Terminate()
break
}

View File

@ -486,6 +486,7 @@ func (s *mockDataReceiver) OnLeftRoom(userID, roomID string)
// OnE2EEData is an empty stub on the mock receiver.
func (s *mockDataReceiver) OnE2EEData(userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
// OnTerminated is an empty stub on the mock receiver.
func (s *mockDataReceiver) OnTerminated(userID, deviceID string) {}
// OnExpiredToken is an empty stub on the mock receiver.
func (s *mockDataReceiver) OnExpiredToken(deviceID string) {}
func newMocks(doSyncV2 func(authHeader, since string) (*SyncResponse, int, error)) (*mockDataReceiver, *mockClient) {
client := &mockClient{

View File

@ -134,6 +134,14 @@ func (s *Storage) AllDevices() (devices []Device, err error) {
return
}
// RemoveDevice deletes the row for deviceID from syncv3_sync2_devices and
// returns the error, if any, reported by the database. The info-level log
// line is emitted after the statement has run, whether or not it succeeded.
func (s *Storage) RemoveDevice(deviceID string) error {
	const deleteDevice = `DELETE FROM syncv3_sync2_devices WHERE device_id = $1`

	_, execErr := s.db.Exec(deleteDevice, deviceID)
	log.Info().Str("device", deviceID).Msg("Deleting device")
	return execErr
}
func (s *Storage) InsertDevice(deviceID, accessToken string) (*Device, error) {
var device Device
device.AccessToken = accessToken

View File

@ -51,6 +51,7 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
return conn
}
// e.g buffer exceeded, close it and remove it from the cache
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
m.closeConn(conn)
return nil
}
@ -63,6 +64,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
conn := m.Conn(cid)
if conn != nil {
// tear down this connection and fallthrough
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to CreateConn called again")
m.closeConn(conn)
}
h := newConnHandler()
@ -74,16 +76,15 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
}
// CloseConn closes the connection identified by connID and removes it from
// the cache.
//
// NOTE(review): this body comes from a diff hunk whose header reports 16 lines
// becoming 15, so removed and added statements may be interleaved here —
// confirm against the committed file which of the statements below are live.
// In particular, verify that m.mu is not held while m.cache.Remove runs: the
// comment on that call says it fires the TTL callback, and closeConnExpires
// (which looks like that callback) takes m.mu itself.
func (m *ConnMap) CloseConn(connID ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
conn := m.Conn(connID)
m.closeConn(conn)
logger.Trace().Str("conn", connID.String()).Msg("closing connection due to CloseConn()")
m.cache.Remove(connID.String()) // this will fire TTL callbacks which calls closeConn
}
// closeConnExpires closes the connection stored under connID while holding
// m.mu. value must be a *Conn; any other dynamic type makes the type
// assertion panic.
// NOTE(review): presumably registered as the cache's expiry/removal callback —
// confirm where it is wired up.
func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
conn := value.(*Conn)
logger.Trace().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
m.closeConn(conn)
}

View File

@ -313,6 +313,12 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
if v2device.UserID == "" {
v2device.UserID, err = h.V2.WhoAmI(accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
}
}
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
return nil, &internal.HandlerError{
StatusCode: http.StatusBadGateway,
@ -641,6 +647,12 @@ func (h *SyncLiveHandler) OnAccountData(p *pubsub.V2AccountData) {
userCache.(*caches.UserCache).OnAccountData(ctx, data)
}
// OnExpiredToken reacts to a V2ExpiredToken payload by closing the connection
// whose ConnID carries the payload's device ID.
func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) {
	expiredConn := sync3.ConnID{DeviceID: p.DeviceID}
	h.ConnMap.CloseConn(expiredConn)
}
func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {
queryPos := u.Query().Get(param)
if queryPos != "" {

View File

@ -623,3 +623,34 @@ func TestSessionExpiryOnBufferFill(t *testing.T) {
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
}
}
// TestExpiredAccessToken checks what a client sees after its access token is
// invalidated upstream: a request that continues the old session is rejected
// with 400, and a brand-new request with the same token is rejected with 401.
func TestExpiredAccessToken(t *testing.T) {
	pqString := testutils.PrepareDBConnectionString()
	v2 := runTestV2Server(t)
	v2.addAccount(alice, aliceToken)
	v3 := runTestServer(t, v2, pqString)

	// Establish a session while the token is still valid.
	subscriptions := map[string]sync3.RoomSubscription{
		"!doesnt:matter": {TimelineLimit: 1},
	}
	initial := v3.mustDoV3Request(t, aliceToken, sync3.Request{
		RoomSubscriptions: subscriptions,
	})

	// Invalidate the token on the v2 side.
	v2.invalidateToken(aliceToken)

	// Continuing the existing session must fail with 400: the session is gone.
	staleReq := sync3.Request{}
	staleReq.SetTimeoutMSecs(1)
	_, staleBody, staleCode := v3.doV3Request(t, context.Background(), aliceToken, initial.Pos, staleReq)
	if staleCode != 400 {
		t.Fatalf("got %d want 400 : %v", staleCode, string(staleBody))
	}

	// Starting from scratch with the same token must fail with 401.
	freshReq := sync3.Request{}
	freshReq.SetTimeoutMSecs(1)
	_, freshBody, freshCode := v3.doV3Request(t, context.Background(), aliceToken, "", freshReq)
	if freshCode != 401 {
		t.Fatalf("got %d want 401 : %v", freshCode, string(freshBody))
	}
}