setupConnection: lookup tokens, not devices

This commit is contained in:
David Robertson 2023-04-28 01:35:10 +01:00
parent e71d954030
commit e670812f11
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
3 changed files with 88 additions and 89 deletions

View File

@ -1,24 +1,16 @@
package internal
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
)
func HashedTokenFromRequest(req *http.Request) (hashAccessToken string, accessToken string, err error) {
// return a hash of the access token
func ExtractAccessToken(req *http.Request) (accessToken string, err error) {
ah := req.Header.Get("Authorization")
if ah == "" {
return "", "", fmt.Errorf("missing Authorization header")
return "", fmt.Errorf("missing Authorization header")
}
accessToken = strings.TrimPrefix(ah, "Bearer ")
// important that this is a cryptographically secure hash function to prevent
// preimage attacks where Eve can use a fake token to hash to an existing device ID
// on the server.
hash := sha256.New()
hash.Write([]byte(accessToken))
return hex.EncodeToString(hash.Sum(nil)), accessToken, nil
return accessToken, nil
}

View File

@ -1,24 +0,0 @@
package internal
import (
"net/http"
"testing"
)
func TestDeviceIDFromRequest(t *testing.T) {
req, _ := http.NewRequest("POST", "http://localhost:8008", nil)
req.Header.Set("Authorization", "Bearer A")
deviceIDA, _, err := HashedTokenFromRequest(req)
if err != nil {
t.Fatalf("HashedTokenFromRequest returned %s", err)
}
req.Header.Set("Authorization", "Bearer B")
deviceIDB, _, err := HashedTokenFromRequest(req)
if err != nil {
t.Fatalf("HashedTokenFromRequest returned %s", err)
}
if deviceIDA == deviceIDB {
t.Fatalf("HashedTokenFromRequest: hashed to same device ID: %s", deviceIDA)
}
}

View File

@ -3,6 +3,7 @@ package handler
import "C"
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/getsentry/sentry-go"
@ -194,6 +195,10 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
}
}
hlog.FromRequest(req).UpdateContext(func(c zerolog.Context) zerolog.Context {
c.Str("txn_id", requestBody.TxnID)
return c
})
for listKey, l := range requestBody.Lists {
if l.Ranges != nil && !l.Ranges.Valid() {
return &internal.HandlerError{
@ -276,25 +281,50 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, error) {
log := hlog.FromRequest(req)
var conn *sync3.Conn
// Identify the device
deviceID, accessToken, err := internal.HashedTokenFromRequest(req)
// Extract an access token
accessToken, err := internal.ExtractAccessToken(req)
if err != nil || accessToken == "" {
log.Warn().Err(err).Msg("failed to get device ID from request")
hlog.FromRequest(req).Warn().Err(err).Msg("failed to get access token from request")
return nil, &internal.HandlerError{
StatusCode: 400,
StatusCode: http.StatusUnauthorized,
Err: err,
}
}
// Try to lookup a record of this token
var token *sync2.Token
token, err = h.V2Store.TokensTable.Token(accessToken)
if err != nil {
if err == sql.ErrNoRows {
hlog.FromRequest(req).Info().Msg("Received connection from unknown access token, querying with homeserver")
newToken, herr := h.identifyUnknownAccessToken(accessToken)
if herr != nil {
return nil, herr
}
token = newToken
} else {
hlog.FromRequest(req).Err(err).Msg("Failed to lookup access token")
return nil, &internal.HandlerError{
StatusCode: http.StatusInternalServerError,
Err: err,
}
}
}
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
// Record the fact that we've recieved a request from this token
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
if err != nil {
// Not fatal---log and continue.
log.Warn().Msg("Unable to update last seen timestamp")
}
connID := connIDFromToken(token)
// client thinks they have a connection
if containsPos {
// Lookup the connection
conn = h.ConnMap.Conn(sync3.ConnID{
DeviceID: deviceID,
})
conn = h.ConnMap.Conn(connID)
if conn != nil {
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
return conn, nil
@ -303,55 +333,22 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
return nil, internal.ExpiredSessionError()
}
// We're going to make a new connection
// Ensure we have the v2 side of things hooked up
v2device, err := h.V2Store.DevicesTable.InsertDevice(deviceID, accessToken)
if err != nil {
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to insert v2 device")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
if v2device.UserID == "" {
v2device.UserID, _, err = h.V2.WhoAmI(accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
}
}
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
return nil, &internal.HandlerError{
StatusCode: http.StatusBadGateway,
Err: err,
}
}
if err = h.V2Store.DevicesTable.UpdateUserIDForDevice(deviceID, v2device.UserID); err != nil {
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to persist user ID -> device ID mapping")
// non-fatal, we can still work without doing this
}
}
log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
h.V3Pub.EnsurePolling(v2device.UserID, v2device.DeviceID)
log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")
log.Trace().Msg("checking poller exists and is running")
h.V3Pub.EnsurePolling(token.UserID, token.DeviceID)
log.Trace().Msg("poller exists and is running")
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
// We'll be quicker next time as the poller will already exist.
if req.Context().Err() != nil {
log.Warn().Str("user_id", v2device.UserID).Msg(
"client gave up, not creating connection",
)
log.Warn().Msg("client gave up, not creating connection")
return nil, &internal.HandlerError{
StatusCode: 400,
Err: req.Context().Err(),
}
}
userCache, err := h.userCache(v2device.UserID)
userCache, err := h.userCache(token.UserID)
if err != nil {
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
log.Warn().Err(err).Msg("failed to load user cache")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
@ -366,19 +363,53 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
DeviceID: deviceID,
}, func() sync3.ConnHandler {
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
})
if created {
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
log.Info().Msg("created new connection")
} else {
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("using existing connection")
log.Info().Msg("using existing connection")
}
return conn, nil
}
func connIDFromToken(token *sync2.Token) sync3.ConnID {
return sync3.ConnID{
// TODO: change ConnID to be a (user, device) ID pair
DeviceID: token.DeviceID,
}
}
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string) (*sync2.Token, *internal.HandlerError) {
// We don't recognise the given accessToken. Ask the homeserver who owns it.
userID, deviceID, err := h.V2.WhoAmI(accessToken)
if err != nil {
if err == sync2.HTTP401 {
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
}
}
log.Warn().Err(err).Msg("failed to get user ID from device ID")
return nil, &internal.HandlerError{
StatusCode: http.StatusBadGateway,
Err: err,
}
}
// Create a brand-new row for this token.
token, err := h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
if err != nil {
log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
return nil, &internal.HandlerError{
StatusCode: 500,
Err: err,
}
}
return token, nil
}
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
c, ok := h.userCaches.Load(userID)
if ok {