mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Remember since tokens for each device
Tie it to the device so a since-less v3 sync still pulls in stored pollers
This commit is contained in:
parent
29b17b89e5
commit
b0146bb031
61
sessions.go
61
sessions.go
@ -1,61 +0,0 @@
|
||||
package syncv3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `db:"session_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
LastToDeviceACK string `db:"last_to_device_ack"`
|
||||
}
|
||||
|
||||
type Sessions struct {
|
||||
db *sqlx.DB
|
||||
upsertSessionStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSessions(postgresURI string) *Sessions {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_session_id_seq;
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sessions (
|
||||
session_id BIGINT PRIMARY KEY DEFAULT nextval('syncv3_session_id_seq'),
|
||||
-- user_id TEXT NOT NULL, (we don't know the user ID from the access token alone, at least in dendrite!)
|
||||
device_id TEXT NOT NULL,
|
||||
last_to_device_ack TEXT NOT NULL,
|
||||
CONSTRAINT syncv3_sessions_unique UNIQUE (device_id, session_id)
|
||||
);
|
||||
`)
|
||||
return &Sessions{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sessions) NewSession(deviceID string) (*Session, error) {
|
||||
var id string
|
||||
err := s.db.QueryRow("INSERT INTO syncv3_sessions(device_id, last_to_device_ack) VALUES($1,'') RETURNING session_id", deviceID).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
ID: id,
|
||||
DeviceID: deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Sessions) Session(sessionID, deviceID string) (*Session, error) {
|
||||
var result Session
|
||||
err := s.db.Get(&result, "SELECT * FROM syncv3_sessions WHERE session_id=$1 AND device_id=$2", sessionID, deviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return &result, nil
|
||||
}
|
36
sqlutil/sql.go
Normal file
36
sqlutil/sql.go
Normal file
@ -0,0 +1,36 @@
|
||||
package sqlutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// WithTransaction runs a block of code passing in an SQL transaction
|
||||
// If the code returns an error or panics then the transactions is rolled back
|
||||
// Otherwise the transaction is committed.
|
||||
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTransaction.Begin: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
if err == nil && panicErr != nil {
|
||||
err = fmt.Errorf("panic: %v", panicErr)
|
||||
}
|
||||
var txnErr error
|
||||
if err != nil {
|
||||
txnErr = txn.Rollback()
|
||||
} else {
|
||||
txnErr = txn.Commit()
|
||||
}
|
||||
if txnErr != nil && err == nil {
|
||||
err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", txnErr)
|
||||
}
|
||||
}()
|
||||
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/sync-v3/sqlutil"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@ -130,7 +131,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
|
||||
if len(state) == 0 {
|
||||
return nil
|
||||
}
|
||||
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
return sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
// Attempt to short-circuit. This has to be done inside a transaction to make sure
|
||||
// we don't race with multiple calls to Initialise with the same room ID.
|
||||
snapshotID, err := a.roomsTable.CurrentSnapshotID(txn, roomID)
|
||||
@ -211,7 +212,7 @@ func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) erro
|
||||
if len(timeline) == 0 {
|
||||
return nil
|
||||
}
|
||||
return WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
return sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
|
||||
// Insert the events
|
||||
events := make([]Event, len(timeline))
|
||||
for i := range events {
|
||||
@ -321,32 +322,3 @@ func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (event
|
||||
}
|
||||
return eventsJSON, int64(events[len(events)-1].NID), nil
|
||||
}
|
||||
|
||||
// WithTransaction runs a block of code passing in an SQL transaction
|
||||
// If the code returns an error or panics then the transactions is rolled back
|
||||
// Otherwise the transaction is committed.
|
||||
func WithTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTransaction.Begin: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
if err == nil && panicErr != nil {
|
||||
err = fmt.Errorf("panic: %v", panicErr)
|
||||
}
|
||||
var txnErr error
|
||||
if err != nil {
|
||||
txnErr = txn.Rollback()
|
||||
} else {
|
||||
txnErr = txn.Commit()
|
||||
}
|
||||
if txnErr != nil && err == nil {
|
||||
err = fmt.Errorf("WithTransaction failed to commit/rollback: %w", txnErr)
|
||||
}
|
||||
}()
|
||||
|
||||
err = fn(txn)
|
||||
return
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sync-v3/state"
|
||||
"github.com/matrix-org/sync-v3/sync3"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
@ -15,17 +16,19 @@ type Poller struct {
|
||||
DeviceID string
|
||||
Client *Client
|
||||
Accumulator *state.Accumulator
|
||||
Sessions *sync3.Sessions
|
||||
// flag set to true when poll() returns due to expired access tokens
|
||||
Terminated bool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator) *Poller {
|
||||
func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator, sessions *sync3.Sessions) *Poller {
|
||||
return &Poller{
|
||||
AuthorizationHeader: authHeader,
|
||||
DeviceID: deviceID,
|
||||
Client: client,
|
||||
Accumulator: accumulator,
|
||||
Sessions: sessions,
|
||||
Terminated: false,
|
||||
logger: zerolog.New(os.Stdout).With().Timestamp().Logger().With().Str("device", deviceID).Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
@ -37,7 +40,7 @@ func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.A
|
||||
// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
|
||||
// Returns if the access token gets invalidated. Invokes the callback on first success.
|
||||
func (p *Poller) Poll(since string, callback func()) {
|
||||
p.logger.Info().Msg("v2 poll loop started")
|
||||
p.logger.Info().Str("since", since).Msg("v2 poll loop started")
|
||||
failCount := 0
|
||||
firstTime := true
|
||||
for {
|
||||
@ -63,6 +66,13 @@ func (p *Poller) Poll(since string, callback func()) {
|
||||
failCount = 0
|
||||
p.accumulate(resp)
|
||||
since = resp.NextBatch
|
||||
// persist the since token (TODO: this could get slow if we hammer the DB too much)
|
||||
err = p.Sessions.UpdateDeviceSince(p.DeviceID, since)
|
||||
if err != nil {
|
||||
// non-fatal
|
||||
p.logger.Warn().Str("since", since).Err(err).Msg("failed to persist new since value")
|
||||
}
|
||||
|
||||
if firstTime {
|
||||
firstTime = false
|
||||
callback()
|
||||
|
123
sync3/sessions.go
Normal file
123
sync3/sessions.go
Normal file
@ -0,0 +1,123 @@
|
||||
package sync3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/sync-v3/sqlutil"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "15:04:05",
|
||||
})
|
||||
|
||||
// A Session represents a single device's sync stream. One device can have many sessions open at
|
||||
// once. Sessions are created when devices sync without a since token. Sessions are destroyed
|
||||
// after a configurable period of inactivity.
|
||||
type Session struct {
|
||||
ID string `db:"session_id"`
|
||||
DeviceID string `db:"device_id"`
|
||||
LastToDeviceACK string `db:"last_to_device_ack"`
|
||||
|
||||
Since string `db:"since"`
|
||||
}
|
||||
|
||||
type Sessions struct {
|
||||
db *sqlx.DB
|
||||
upsertSessionStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSessions(postgresURI string) *Sessions {
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
// make sure tables are made
|
||||
db.MustExec(`
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_session_id_seq;
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sessions (
|
||||
session_id BIGINT PRIMARY KEY DEFAULT nextval('syncv3_session_id_seq'),
|
||||
device_id TEXT NOT NULL,
|
||||
last_to_device_ack TEXT NOT NULL,
|
||||
CONSTRAINT syncv3_sessions_unique UNIQUE (device_id, session_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS syncv3_sessions_v2devices (
|
||||
-- user_id TEXT NOT NULL, (we don't know the user ID from the access token alone, at least in dendrite!)
|
||||
device_id TEXT PRIMARY KEY,
|
||||
since TEXT NOT NULL
|
||||
);
|
||||
`)
|
||||
return &Sessions{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sessions) NewSession(deviceID string) (*Session, error) {
|
||||
var session *Session
|
||||
err := sqlutil.WithTransaction(s.db, func(txn *sqlx.Tx) error {
|
||||
// make a new session
|
||||
var id string
|
||||
err := txn.QueryRow("INSERT INTO syncv3_sessions(device_id, last_to_device_ack) VALUES($1,'') RETURNING session_id", deviceID).Scan(&id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// make sure there is a device entry for this device ID. If one already exists, don't clobber
|
||||
// the since value else we'll forget our position!
|
||||
result, err := txn.Exec(`
|
||||
INSERT INTO syncv3_sessions_v2devices(device_id, since) VALUES($1,$2)
|
||||
ON CONFLICT (device_id) DO NOTHING`,
|
||||
deviceID, "",
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we inserted a row that means it's a brand new device ergo there is no since token
|
||||
if ra, err := result.RowsAffected(); err == nil && ra == 1 {
|
||||
// we inserted a new row, no need to query the since value
|
||||
session = &Session{
|
||||
ID: id,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the since value as we may start a new poller with this session.
|
||||
var since string
|
||||
err = txn.QueryRow("SELECT since FROM syncv3_sessions_v2devices WHERE device_id = $1", deviceID).Scan(&since)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session = &Session{
|
||||
ID: id,
|
||||
DeviceID: deviceID,
|
||||
Since: since,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return session, err
|
||||
}
|
||||
|
||||
func (s *Sessions) Session(sessionID, deviceID string) (*Session, error) {
|
||||
var result Session
|
||||
err := s.db.Get(&result,
|
||||
`SELECT session_id, device_id, last_to_device_ack, since FROM syncv3_sessions
|
||||
LEFT JOIN syncv3_sessions_v2devices
|
||||
ON syncv3_sessions.device_id = syncv3_sessions_v2devices.device_id
|
||||
WHERE session_id=$1 AND syncv3_sessions.device_id=$2`,
|
||||
sessionID, deviceID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Sessions) UpdateDeviceSince(deviceID, since string) error {
|
||||
_, err := s.db.Exec(`UPDATE syncv3_sessions_v2devices SET since = $1 WHERE device_id = $2`, since, deviceID)
|
||||
return err
|
||||
}
|
@ -5,31 +5,29 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// V3_S1_F2-3-4-5-6_V2_s111_222_333_444
|
||||
// "V3_" $SESSION "_" $FILTERS "_V2_" $V2TOKEN
|
||||
// V3_S1_F2-3-4-5-6
|
||||
// "V3_" $SESSION "_" $FILTERS
|
||||
type Token struct {
|
||||
SessionID string
|
||||
FilterIDs []string
|
||||
V2token string
|
||||
}
|
||||
|
||||
func (t Token) String() string {
|
||||
filters := strings.Join(t.FilterIDs, "-")
|
||||
return fmt.Sprintf("V3_S%s_F%s_V2_%s", t.SessionID, filters, t.V2token)
|
||||
return fmt.Sprintf("V3_S%s_F%s", t.SessionID, filters)
|
||||
}
|
||||
|
||||
func NewSyncToken(since string) (*Token, error) {
|
||||
segments := strings.SplitN(since, "_", 5)
|
||||
if len(segments) != 5 {
|
||||
segments := strings.SplitN(since, "_", 3)
|
||||
if len(segments) != 3 {
|
||||
return nil, fmt.Errorf("not a sync v3 token")
|
||||
}
|
||||
if segments[0] != "V3" || segments[3] != "V2" {
|
||||
if segments[0] != "V3" {
|
||||
return nil, fmt.Errorf("not a sync v3 token: %s", since)
|
||||
}
|
||||
filters := strings.TrimPrefix(segments[2], "F")
|
||||
return &Token{
|
||||
SessionID: strings.TrimPrefix(segments[1], "S"),
|
||||
FilterIDs: strings.Split(filters, "-"),
|
||||
V2token: segments[4],
|
||||
}, nil
|
||||
}
|
||||
|
91
v3.go
91
v3.go
@ -49,7 +49,7 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
|
||||
},
|
||||
DestinationServer: destinationServer,
|
||||
},
|
||||
Sessions: NewSessions(postgresDBURI),
|
||||
Sessions: sync3.NewSessions(postgresDBURI),
|
||||
Accumulator: state.NewAccumulator(postgresDBURI),
|
||||
Pollers: make(map[string]*sync2.Poller),
|
||||
pollerMu: &sync.Mutex{},
|
||||
@ -67,9 +67,18 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
|
||||
}
|
||||
}
|
||||
|
||||
type handlerError struct {
|
||||
StatusCode int
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *handlerError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d : %s", e.StatusCode, e.err.Error())
|
||||
}
|
||||
|
||||
type SyncV3Handler struct {
|
||||
V2 *sync2.Client
|
||||
Sessions *Sessions
|
||||
Sessions *sync3.Sessions
|
||||
Accumulator *state.Accumulator
|
||||
|
||||
pollerMu *sync.Mutex
|
||||
@ -77,16 +86,44 @@ type SyncV3Handler struct {
|
||||
}
|
||||
|
||||
func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
err := h.serve(w, req)
|
||||
if err != nil {
|
||||
w.WriteHeader(err.StatusCode)
|
||||
w.Write(asJSONError(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handlerError {
|
||||
// Get or create a Session
|
||||
session, _, err := h.getOrCreateSession(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info().Str("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")
|
||||
|
||||
// make sure we have a poller for this device
|
||||
h.ensurePolling(req.Header.Get("Authorization"), session)
|
||||
|
||||
// return data based on filters
|
||||
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(sync3.Token{
|
||||
SessionID: session.ID,
|
||||
FilterIDs: []string{},
|
||||
}.String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreateSession retrieves an existing session if ?since= is set, else makes a new session.
|
||||
// Returns a session or an error. Returns a token if and only if there is an existing session.
|
||||
func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *sync3.Token, *handlerError) {
|
||||
var session *sync3.Session
|
||||
var tokv3 *sync3.Token
|
||||
deviceID, err := deviceIDFromRequest(req)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to get device ID from request")
|
||||
w.WriteHeader(400)
|
||||
w.Write(asJSONError(err))
|
||||
return
|
||||
return nil, nil, &handlerError{400, err}
|
||||
}
|
||||
// Get or create a Session
|
||||
var session *Session
|
||||
var tokv3 *sync3.Token
|
||||
sincev3 := req.URL.Query().Get("since")
|
||||
if sincev3 == "" {
|
||||
session, err = h.Sessions.NewSession(deviceID)
|
||||
@ -94,58 +131,36 @@ func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
tokv3, err = sync3.NewSyncToken(sincev3)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to parse sync v3 token")
|
||||
w.WriteHeader(400)
|
||||
w.Write(asJSONError(err))
|
||||
return
|
||||
return nil, nil, &handlerError{400, err}
|
||||
}
|
||||
session, err = h.Sessions.Session(tokv3.SessionID, deviceID)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("device", deviceID).Msg("failed to ensure Session existed for device")
|
||||
w.WriteHeader(500)
|
||||
w.Write(asJSONError(err))
|
||||
return
|
||||
return nil, nil, &handlerError{500, err}
|
||||
}
|
||||
log.Info().Str("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")
|
||||
|
||||
// map sync v3 token to sync v2 token
|
||||
var sincev2 string
|
||||
if tokv3 != nil {
|
||||
sincev2 = tokv3.V2token
|
||||
}
|
||||
|
||||
// make sure we have a poller for this device
|
||||
h.ensurePolling(req.Header.Get("Authorization"), session.DeviceID, sincev2)
|
||||
|
||||
// return data based on filters
|
||||
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(sync3.Token{
|
||||
V2token: "v2tokengoeshere",
|
||||
SessionID: session.ID,
|
||||
FilterIDs: []string{},
|
||||
}.String()))
|
||||
return session, tokv3, nil
|
||||
}
|
||||
|
||||
// ensurePolling makes sure there is a poller for this device, making one if need be.
|
||||
// Blocks until at least 1 sync is done if and only if the poller was just created.
|
||||
// This ensures that calls to the database will return data.
|
||||
func (h *SyncV3Handler) ensurePolling(authHeader, deviceID, since string) {
|
||||
func (h *SyncV3Handler) ensurePolling(authHeader string, session *sync3.Session) {
|
||||
h.pollerMu.Lock()
|
||||
poller, ok := h.Pollers[deviceID]
|
||||
poller, ok := h.Pollers[session.DeviceID]
|
||||
// either no poller exists or it did but it died
|
||||
if ok && !poller.Terminated {
|
||||
h.pollerMu.Unlock()
|
||||
return
|
||||
}
|
||||
// replace the poller
|
||||
poller = sync2.NewPoller(authHeader, deviceID, h.V2, h.Accumulator)
|
||||
poller = sync2.NewPoller(authHeader, session.DeviceID, h.V2, h.Accumulator, h.Sessions)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go poller.Poll(since, func() {
|
||||
go poller.Poll(session.Since, func() {
|
||||
wg.Done()
|
||||
})
|
||||
h.Pollers[deviceID] = poller
|
||||
h.Pollers[session.DeviceID] = poller
|
||||
h.pollerMu.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user