mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
setupConnection: lookup tokens, not devices
This commit is contained in:
parent
e71d954030
commit
e670812f11
@ -1,24 +1,16 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HashedTokenFromRequest(req *http.Request) (hashAccessToken string, accessToken string, err error) {
|
func ExtractAccessToken(req *http.Request) (accessToken string, err error) {
|
||||||
// return a hash of the access token
|
|
||||||
ah := req.Header.Get("Authorization")
|
ah := req.Header.Get("Authorization")
|
||||||
if ah == "" {
|
if ah == "" {
|
||||||
return "", "", fmt.Errorf("missing Authorization header")
|
return "", fmt.Errorf("missing Authorization header")
|
||||||
}
|
}
|
||||||
accessToken = strings.TrimPrefix(ah, "Bearer ")
|
accessToken = strings.TrimPrefix(ah, "Bearer ")
|
||||||
// important that this is a cryptographically secure hash function to prevent
|
return accessToken, nil
|
||||||
// preimage attacks where Eve can use a fake token to hash to an existing device ID
|
|
||||||
// on the server.
|
|
||||||
hash := sha256.New()
|
|
||||||
hash.Write([]byte(accessToken))
|
|
||||||
return hex.EncodeToString(hash.Sum(nil)), accessToken, nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDeviceIDFromRequest(t *testing.T) {
|
|
||||||
req, _ := http.NewRequest("POST", "http://localhost:8008", nil)
|
|
||||||
req.Header.Set("Authorization", "Bearer A")
|
|
||||||
deviceIDA, _, err := HashedTokenFromRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("HashedTokenFromRequest returned %s", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer B")
|
|
||||||
deviceIDB, _, err := HashedTokenFromRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("HashedTokenFromRequest returned %s", err)
|
|
||||||
}
|
|
||||||
if deviceIDA == deviceIDB {
|
|
||||||
t.Fatalf("HashedTokenFromRequest: hashed to same device ID: %s", deviceIDA)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -3,6 +3,7 @@ package handler
|
|||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
@ -194,6 +195,10 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
hlog.FromRequest(req).UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
|
c.Str("txn_id", requestBody.TxnID)
|
||||||
|
return c
|
||||||
|
})
|
||||||
for listKey, l := range requestBody.Lists {
|
for listKey, l := range requestBody.Lists {
|
||||||
if l.Ranges != nil && !l.Ranges.Valid() {
|
if l.Ranges != nil && !l.Ranges.Valid() {
|
||||||
return &internal.HandlerError{
|
return &internal.HandlerError{
|
||||||
@ -276,25 +281,50 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
|||||||
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
|
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
|
||||||
// When this function returns, the connection is alive and active.
|
// When this function returns, the connection is alive and active.
|
||||||
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, error) {
|
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, error) {
|
||||||
log := hlog.FromRequest(req)
|
|
||||||
var conn *sync3.Conn
|
var conn *sync3.Conn
|
||||||
|
// Extract an access token
|
||||||
// Identify the device
|
accessToken, err := internal.ExtractAccessToken(req)
|
||||||
deviceID, accessToken, err := internal.HashedTokenFromRequest(req)
|
|
||||||
if err != nil || accessToken == "" {
|
if err != nil || accessToken == "" {
|
||||||
log.Warn().Err(err).Msg("failed to get device ID from request")
|
hlog.FromRequest(req).Warn().Err(err).Msg("failed to get access token from request")
|
||||||
return nil, &internal.HandlerError{
|
return nil, &internal.HandlerError{
|
||||||
StatusCode: 400,
|
StatusCode: http.StatusUnauthorized,
|
||||||
Err: err,
|
Err: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to lookup a record of this token
|
||||||
|
var token *sync2.Token
|
||||||
|
token, err = h.V2Store.TokensTable.Token(accessToken)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
hlog.FromRequest(req).Info().Msg("Received connection from unknown access token, querying with homeserver")
|
||||||
|
newToken, herr := h.identifyUnknownAccessToken(accessToken)
|
||||||
|
if herr != nil {
|
||||||
|
return nil, herr
|
||||||
|
}
|
||||||
|
token = newToken
|
||||||
|
} else {
|
||||||
|
hlog.FromRequest(req).Err(err).Msg("Failed to lookup access token")
|
||||||
|
return nil, &internal.HandlerError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
|
||||||
|
|
||||||
|
// Record the fact that we've recieved a request from this token
|
||||||
|
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
// Not fatal---log and continue.
|
||||||
|
log.Warn().Msg("Unable to update last seen timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
connID := connIDFromToken(token)
|
||||||
// client thinks they have a connection
|
// client thinks they have a connection
|
||||||
if containsPos {
|
if containsPos {
|
||||||
// Lookup the connection
|
// Lookup the connection
|
||||||
conn = h.ConnMap.Conn(sync3.ConnID{
|
conn = h.ConnMap.Conn(connID)
|
||||||
DeviceID: deviceID,
|
|
||||||
})
|
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
|
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@ -303,55 +333,22 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
|||||||
return nil, internal.ExpiredSessionError()
|
return nil, internal.ExpiredSessionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// We're going to make a new connection
|
log.Trace().Msg("checking poller exists and is running")
|
||||||
// Ensure we have the v2 side of things hooked up
|
h.V3Pub.EnsurePolling(token.UserID, token.DeviceID)
|
||||||
v2device, err := h.V2Store.DevicesTable.InsertDevice(deviceID, accessToken)
|
log.Trace().Msg("poller exists and is running")
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to insert v2 device")
|
|
||||||
return nil, &internal.HandlerError{
|
|
||||||
StatusCode: 500,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v2device.UserID == "" {
|
|
||||||
v2device.UserID, _, err = h.V2.WhoAmI(accessToken)
|
|
||||||
if err != nil {
|
|
||||||
if err == sync2.HTTP401 {
|
|
||||||
return nil, &internal.HandlerError{
|
|
||||||
StatusCode: 401,
|
|
||||||
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to get user ID from device ID")
|
|
||||||
return nil, &internal.HandlerError{
|
|
||||||
StatusCode: http.StatusBadGateway,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = h.V2Store.DevicesTable.UpdateUserIDForDevice(deviceID, v2device.UserID); err != nil {
|
|
||||||
log.Warn().Err(err).Str("device_id", deviceID).Msg("failed to persist user ID -> device ID mapping")
|
|
||||||
// non-fatal, we can still work without doing this
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().Str("user", v2device.UserID).Msg("checking poller exists and is running")
|
|
||||||
h.V3Pub.EnsurePolling(v2device.UserID, v2device.DeviceID)
|
|
||||||
log.Trace().Str("user", v2device.UserID).Msg("poller exists and is running")
|
|
||||||
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
|
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
|
||||||
// We'll be quicker next time as the poller will already exist.
|
// We'll be quicker next time as the poller will already exist.
|
||||||
if req.Context().Err() != nil {
|
if req.Context().Err() != nil {
|
||||||
log.Warn().Str("user_id", v2device.UserID).Msg(
|
log.Warn().Msg("client gave up, not creating connection")
|
||||||
"client gave up, not creating connection",
|
|
||||||
)
|
|
||||||
return nil, &internal.HandlerError{
|
return nil, &internal.HandlerError{
|
||||||
StatusCode: 400,
|
StatusCode: 400,
|
||||||
Err: req.Context().Err(),
|
Err: req.Context().Err(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userCache, err := h.userCache(v2device.UserID)
|
userCache, err := h.userCache(token.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("user_id", v2device.UserID).Msg("failed to load user cache")
|
log.Warn().Err(err).Msg("failed to load user cache")
|
||||||
return nil, &internal.HandlerError{
|
return nil, &internal.HandlerError{
|
||||||
StatusCode: 500,
|
StatusCode: 500,
|
||||||
Err: err,
|
Err: err,
|
||||||
@ -366,19 +363,53 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
|||||||
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
|
||||||
// to check for an existing connection though, as it's possible for the client to call /sync
|
// to check for an existing connection though, as it's possible for the client to call /sync
|
||||||
// twice for a new connection.
|
// twice for a new connection.
|
||||||
conn, created := h.ConnMap.CreateConn(sync3.ConnID{
|
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
|
||||||
DeviceID: deviceID,
|
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
|
||||||
}, func() sync3.ConnHandler {
|
|
||||||
return NewConnState(v2device.UserID, v2device.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
|
|
||||||
})
|
})
|
||||||
if created {
|
if created {
|
||||||
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("created new connection")
|
log.Info().Msg("created new connection")
|
||||||
} else {
|
} else {
|
||||||
log.Info().Str("user", v2device.UserID).Str("conn_id", conn.ConnID.String()).Msg("using existing connection")
|
log.Info().Msg("using existing connection")
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func connIDFromToken(token *sync2.Token) sync3.ConnID {
|
||||||
|
return sync3.ConnID{
|
||||||
|
// TODO: change ConnID to be a (user, device) ID pair
|
||||||
|
DeviceID: token.DeviceID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string) (*sync2.Token, *internal.HandlerError) {
|
||||||
|
// We don't recognise the given accessToken. Ask the homeserver who owns it.
|
||||||
|
userID, deviceID, err := h.V2.WhoAmI(accessToken)
|
||||||
|
if err != nil {
|
||||||
|
if err == sync2.HTTP401 {
|
||||||
|
return nil, &internal.HandlerError{
|
||||||
|
StatusCode: 401,
|
||||||
|
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Warn().Err(err).Msg("failed to get user ID from device ID")
|
||||||
|
return nil, &internal.HandlerError{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a brand-new row for this token.
|
||||||
|
token, err := h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
|
||||||
|
return nil, &internal.HandlerError{
|
||||||
|
StatusCode: 500,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
|
func (h *SyncLiveHandler) CacheForUser(userID string) *caches.UserCache {
|
||||||
c, ok := h.userCaches.Load(userID)
|
c, ok := h.userCaches.Load(userID)
|
||||||
if ok {
|
if ok {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user