Mirror of https://github.com/matrix-org/sliding-sync.git (synced 2025-03-10 13:37:11 +00:00)

Add v2 polling; chunk event insertion to support Matrix HQ

This commit is contained in:
parent 502274b475
commit 159530bed1
@@ -153,7 +153,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
 	}
 	numNew, err := a.eventsTable.Insert(txn, events)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to insert events: %w", err)
 	}
 	if numNew == 0 {
 		// we don't have a current snapshot for this room but yet no events are new,
@@ -171,7 +171,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
 	}
 	nids, err := a.eventsTable.SelectNIDsByIDs(txn, eventIDs)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to select NIDs for inserted events: %w", err)
 	}
 
 	// Make a current snapshot
@@ -181,7 +181,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
 	}
 	err = a.snapshotTable.Insert(txn, snapshot)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to insert snapshot: %w", err)
 	}
 
 	// Increment the ref counter
@@ -199,14 +199,14 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
 // received from the server.
 //
 // This function does several things:
-// - It ensures all events are persisted in the database. This is shared amongst users.
-// - If all events have been stored before, then it short circuits and returns.
-// This is because we must have already processed this part of the timeline in order for the event
-// to exist in the database, and the sync stream is already linearised for us.
-// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
-// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
-// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
-// The server itself also holds a ref to the current state, which is then moved to the new current state.
+// - It ensures all events are persisted in the database. This is shared amongst users.
+// - If all events have been stored before, then it short circuits and returns.
+// This is because we must have already processed this part of the timeline in order for the event
+// to exist in the database, and the sync stream is already linearised for us.
+// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
+// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
+// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
+// The server itself also holds a ref to the current state, which is then moved to the new current state.
 func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) error {
 	if len(timeline) == 0 {
 		return nil
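For context, a small self-contained sketch (not part of the commit) of what wrapping errors with %w, as the Initialise changes above do, buys callers compared to the bare "return err" it replaces; the database error string here is hypothetical:

    package main

    import (
        "errors"
        "fmt"
    )

    func main() {
        // Hypothetical database failure, standing in for whatever eventsTable.Insert returns.
        base := errors.New("pq: connection refused")

        // %w wraps the original error: the message gains context about which step failed,
        // while errors.Is / errors.As still see the underlying error.
        wrapped := fmt.Errorf("failed to insert events: %w", base)

        fmt.Println(wrapped)                  // failed to insert events: pq: connection refused
        fmt.Println(errors.Is(wrapped, base)) // true: the original error is still reachable
    }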
@@ -75,13 +75,21 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
 		}
 		events[i] = ev
 	}
-	result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
-	VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
-	if err != nil {
-		return 0, err
+	chunks := chunkify(3, 65535, events)
+	var rowsAffected int64
+	for _, chunk := range chunks {
+		result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
+		VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, chunk)
+		if err != nil {
+			return 0, err
+		}
+		ra, err := result.RowsAffected()
+		if err != nil {
+			return 0, err
+		}
+		rowsAffected += ra
 	}
-	ra, err := result.RowsAffected()
-	return int(ra), err
+	return int(rowsAffected), nil
 }
 
 func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, nids []int64) (events []Event, err error) {
@@ -143,3 +151,30 @@ func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclu
 	)
 	return events, err
 }
+
+// chunkify will break up things to be inserted based on the number of params in the statement.
+// It is required because postgres has a limit on the number of params in a single statement (65535).
+// Inserting events using NamedExec involves 3n params (n=number of events), meaning it's easy to hit
+// the limit in rooms like Matrix HQ. This function breaks up the events into chunks which can be
+// batch inserted in multiple statements. Without this, you'll see errors like:
+// "pq: got 95331 parameters but PostgreSQL only supports 65535 parameters"
+func chunkify(numParamsPerStmt, maxParamsPerCall int, entries []Event) [][]Event {
+	// common case, most things are small
+	if (len(entries) * numParamsPerStmt) <= maxParamsPerCall {
+		return [][]Event{
+			entries,
+		}
+	}
+	var chunks [][]Event
+	// work out how many events can fit in a chunk
+	numEntriesPerChunk := (maxParamsPerCall / numParamsPerStmt)
+	for i := 0; i < len(entries); i += numEntriesPerChunk {
+		endIndex := i + numEntriesPerChunk
+		if endIndex > len(entries) {
+			endIndex = len(entries)
+		}
+		chunks = append(chunks, entries[i:endIndex])
+	}
+
+	return chunks
+}
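To make the parameter arithmetic concrete, here is a small standalone sketch (not part of the commit) that reproduces the numbers in the quoted Postgres error; the constants mirror the chunkify(3, 65535, events) call in Insert above:

    package main

    import "fmt"

    func main() {
        const numParamsPerStmt = 3     // :event_id, :room_id, :event per row
        const maxParamsPerCall = 65535 // Postgres bind-parameter limit per statement

        numEvents := 95331 / numParamsPerStmt                     // 31777 events triggered the quoted error
        numEntriesPerChunk := maxParamsPerCall / numParamsPerStmt // 21845 events fit in one INSERT

        // Ceiling division: how many INSERT statements chunkify would issue instead of one failing one.
        numChunks := (numEvents + numEntriesPerChunk - 1) / numEntriesPerChunk
        fmt.Printf("%d events -> %d chunks of at most %d events\n", numEvents, numChunks, numEntriesPerChunk)
        // Output: 31777 events -> 2 chunks of at most 21845 events
    }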
@@ -249,3 +249,71 @@ func TestEventTableSelectEventsBetween(t *testing.T) {
 		}
 	})
 }
+
+func TestChunkify(t *testing.T) {
+	// Make 100 dummy events
+	events := make([]Event, 100)
+	for i := 0; i < len(events); i++ {
+		events[i] = Event{
+			NID: i,
+		}
+	}
+	testCases := []struct {
+		name             string
+		numParamsPerStmt int
+		maxParamsPerCall int
+		chunkSizes       []int // length = number of chunks wanted, ints = events in that chunk
+	}{
+		{
+			name:             "below chunk limit returns 1 chunk",
+			numParamsPerStmt: 3,
+			maxParamsPerCall: 400,
+			chunkSizes:       []int{100},
+		},
+		{
+			name:             "just above chunk limit returns 2 chunks",
+			numParamsPerStmt: 3,
+			maxParamsPerCall: 297,
+			chunkSizes:       []int{99, 1},
+		},
+		{
+			name:             "way above chunk limit returns many chunks",
+			numParamsPerStmt: 3,
+			maxParamsPerCall: 30,
+			chunkSizes:       []int{10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
+		},
+		{
+			name:             "fractional division rounds down",
+			numParamsPerStmt: 3,
+			maxParamsPerCall: 298,
+			chunkSizes:       []int{99, 1},
+		},
+		{
+			name:             "fractional division rounds down",
+			numParamsPerStmt: 3,
+			maxParamsPerCall: 299,
+			chunkSizes:       []int{99, 1},
+		},
+	}
+	for _, tc := range testCases {
+		testCase := tc
+		t.Run(testCase.name, func(t *testing.T) {
+			chunks := chunkify(testCase.numParamsPerStmt, testCase.maxParamsPerCall, events)
+			if len(chunks) != len(testCase.chunkSizes) {
+				t.Fatalf("got %d chunks, want %d", len(chunks), len(testCase.chunkSizes))
+			}
+			eventNID := 0
+			for i := 0; i < len(chunks); i++ {
+				if len(chunks[i]) != testCase.chunkSizes[i] {
+					t.Errorf("chunk %d got %d elements, want %d", i, len(chunks[i]), testCase.chunkSizes[i])
+				}
+				for j, ev := range chunks[i] {
+					if ev.NID != eventNID {
+						t.Errorf("chunk %d got wrong event in position %d: got NID %d want NID %d", i, j, ev.NID, eventNID)
+					}
+					eventNID += 1
+				}
+			}
+		})
+	}
+}
@@ -1,4 +1,4 @@
-package syncv3
+package v2
 
 import (
 	"encoding/json"
@@ -8,40 +8,44 @@ import (
 	"github.com/matrix-org/gomatrixserverlib"
 )
 
-type V2 struct {
+// Client represents a Sync v2 Client.
+// One client can be shared among many users.
+type Client struct {
 	Client            *http.Client
 	DestinationServer string
 }
 
-func (v *V2) DoSyncV2(authHeader, since string) (*SyncV2Response, error) {
-	qps := ""
+// DoSyncV2 performs a sync v2 request. Returns the sync response and the response status code
+// or an error
+func (v *Client) DoSyncV2(authHeader, since string) (*SyncResponse, int, error) {
+	qps := "?timeout=30000"
 	if since != "" {
-		qps = "?since=" + since
+		qps += "&since=" + since
 	}
 	req, err := http.NewRequest(
 		"GET", v.DestinationServer+"/_matrix/client/r0/sync"+qps, nil,
 	)
 	req.Header.Set("Authorization", authHeader)
 	if err != nil {
-		return nil, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
+		return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
 	}
 	res, err := v.Client.Do(req)
 	if err != nil {
-		return nil, fmt.Errorf("DoSyncV2: request failed: %w", err)
+		return nil, 0, fmt.Errorf("DoSyncV2: request failed: %w", err)
 	}
 	switch res.StatusCode {
 	case 200:
-		var svr SyncV2Response
+		var svr SyncResponse
 		if err := json.NewDecoder(res.Body).Decode(&svr); err != nil {
-			return nil, fmt.Errorf("DoSyncV2: response body decode JSON failed: %w", err)
+			return nil, 0, fmt.Errorf("DoSyncV2: response body decode JSON failed: %w", err)
		}
-		return &svr, nil
+		return &svr, 200, nil
 	default:
-		return nil, fmt.Errorf("DoSyncV2: response returned %s", res.Status)
+		return nil, res.StatusCode, fmt.Errorf("DoSyncV2: response returned %s", res.Status)
 	}
 }
 
-type SyncV2Response struct {
+type SyncResponse struct {
 	NextBatch   string `json:"next_batch"`
 	AccountData struct {
 		Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
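For orientation, a minimal usage sketch of the renamed client API (not part of the commit); the homeserver URL and access token are placeholders, and the import path is the one v3.go uses later in this commit:

    package main

    import (
        "fmt"
        "net/http"
        "time"

        v2 "github.com/matrix-org/sync-v3/v2"
    )

    func main() {
        client := &v2.Client{
            Client:            &http.Client{Timeout: 120 * time.Second},
            DestinationServer: "https://matrix-client.example.org", // placeholder homeserver URL
        }

        // "Bearer ACCESS_TOKEN" is a placeholder; since="" asks for an initial sync.
        resp, code, err := client.DoSyncV2("Bearer ACCESS_TOKEN", "")
        if err != nil {
            fmt.Println("sync failed with HTTP status", code, ":", err)
            return
        }
        fmt.Println("next_batch:", resp.NextBatch)
    }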
v2/poller.go (new file, 90 lines)
@@ -0,0 +1,90 @@
+package v2
+
+import (
+	"math"
+	"os"
+	"time"
+
+	"github.com/matrix-org/sync-v3/state"
+	"github.com/rs/zerolog"
+)
+
+// Poller can automatically poll the sync v2 endpoint and accumulate the responses in the accumulator
+type Poller struct {
+	AuthorizationHeader string
+	DeviceID            string
+	Client              *Client
+	Accumulator         *state.Accumulator
+	// flag set to true when poll() returns due to expired access tokens
+	Terminated bool
+	logger     zerolog.Logger
+}
+
+func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator) *Poller {
+	return &Poller{
+		AuthorizationHeader: authHeader,
+		DeviceID:            deviceID,
+		Client:              client,
+		Accumulator:         accumulator,
+		Terminated:          false,
+		logger: zerolog.New(os.Stdout).With().Timestamp().Logger().With().Str("device", deviceID).Logger().Output(zerolog.ConsoleWriter{
+			Out:        os.Stderr,
+			TimeFormat: "15:04:05",
+		}),
+	}
+}
+
+// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
+// Returns if the access token gets invalidated. Invokes the callback on first success.
+func (p *Poller) Poll(since string, callback func()) {
+	p.logger.Info().Msg("v2 poll loop started")
+	failCount := 0
+	firstTime := true
+	for {
+		if failCount > 0 {
+			waitTime := time.Duration(math.Pow(2, float64(failCount))) * time.Second
+			p.logger.Warn().Str("duration", waitTime.String()).Msg("waiting before next poll")
+			time.Sleep(waitTime)
+		}
+		p.logger.Info().Str("since", since).Msg("requesting data")
+		resp, statusCode, err := p.Client.DoSyncV2(p.AuthorizationHeader, since)
+		if err != nil {
+			// check if temporary
+			if statusCode != 401 {
+				p.logger.Warn().Int("code", statusCode).Err(err).Msg("sync v2 poll returned temporary error")
+				failCount += 1
+				continue
+			} else {
+				p.logger.Warn().Msg("access token has been invalidated, terminating loop")
+				p.Terminated = true
+				return
+			}
+		}
+		failCount = 0
+		p.accumulate(resp)
+		since = resp.NextBatch
+		if firstTime {
+			firstTime = false
+			callback()
+		}
+	}
+}
+
+func (p *Poller) accumulate(res *SyncResponse) {
+	if len(res.Rooms.Join) == 0 {
+		return
+	}
+	for roomID, roomData := range res.Rooms.Join {
+		if len(roomData.State.Events) > 0 {
+			err := p.Accumulator.Initialise(roomID, roomData.State.Events)
+			if err != nil {
+				p.logger.Err(err).Str("room_id", roomID).Int("num_state_events", len(roomData.State.Events)).Msg("Accumulator.Initialise failed")
+			}
+		}
+		err := p.Accumulator.Accumulate(roomID, roomData.Timeline.Events)
+		if err != nil {
+			p.logger.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Accumulator.Accumulate failed")
+		}
+	}
+	p.logger.Info().Int("num_rooms", len(res.Rooms.Join)).Msg("accumulated data")
+}
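A minimal wiring sketch for the new Poller (not part of the commit), with a placeholder access token, device ID, homeserver URL, and Postgres DSN; it mirrors the pattern ensurePolling in v3.go below uses to block until the first response has been accumulated:

    package main

    import (
        "net/http"
        "sync"
        "time"

        "github.com/matrix-org/sync-v3/state"
        v2 "github.com/matrix-org/sync-v3/v2"
    )

    func main() {
        client := &v2.Client{
            Client:            &http.Client{Timeout: 120 * time.Second},
            DestinationServer: "https://matrix-client.example.org", // placeholder homeserver URL
        }
        // Hypothetical Postgres DSN; the accumulator persists events and snapshots there.
        accumulator := state.NewAccumulator("postgres://user:pass@localhost/syncv3?sslmode=disable")

        poller := v2.NewPoller("Bearer ACCESS_TOKEN", "MYDEVICE", client, accumulator)

        // Poll blocks forever, so run it in a goroutine and wait only for the
        // first successful sync via the callback.
        var wg sync.WaitGroup
        wg.Add(1)
        go poller.Poll("", func() {
            wg.Done()
        })
        wg.Wait()
        // At this point at least one v2 response has been accumulated into the database.
    }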
v3.go (71 lines changed)
@@ -8,11 +8,13 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/gorilla/mux"
 	"github.com/justinas/alice"
 	"github.com/matrix-org/sync-v3/state"
+	v2 "github.com/matrix-org/sync-v3/v2"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/hlog"
 )
@@ -40,7 +42,7 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
 
 	// dependency inject all components together
 	sh := &SyncV3Handler{
-		V2: &V2{
+		V2: &v2.Client{
 			Client: &http.Client{
 				Timeout: 120 * time.Second,
 			},
@@ -48,6 +50,8 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
 		},
 		Sessions:    NewSessions(postgresDBURI),
 		Accumulator: state.NewAccumulator(postgresDBURI),
+		Pollers:     make(map[string]*v2.Poller),
+		pollerMu:    &sync.Mutex{},
 	}
 
 	// HTTP path routing
@@ -63,9 +67,12 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
 }
 
 type SyncV3Handler struct {
-	V2          *V2
+	V2          *v2.Client
 	Sessions    *Sessions
 	Accumulator *state.Accumulator
+
+	pollerMu *sync.Mutex
+	Pollers  map[string]*v2.Poller // device_id -> poller
 }
 
 func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -106,48 +113,40 @@ func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 		sincev2 = tokv3.v2token
 	}
 
-	// query for the data
-	v2res, err := h.V2.DoSyncV2(req.Header.Get("Authorization"), sincev2)
-	if err != nil {
-		log.Warn().Err(err).Msg("DoSyncV2 failed")
-		w.WriteHeader(502)
-		w.Write(asJSONError(err))
-		return
-	}
-
-	h.accumulate(v2res)
+	// make sure we have a poller for this device
+	h.ensurePolling(req.Header.Get("Authorization"), session.DeviceID, sincev2)
 
 	// return data based on filters
 
-	v3res, err := json.Marshal(v2res)
-	if err != nil {
-		w.WriteHeader(500)
-		w.Write(asJSONError(err))
-		return
-	}
-	w.Header().Set("X-Matrix-Sync-V3", v3token{
-		v2token:   v2res.NextBatch,
+	w.WriteHeader(200)
+	w.Write([]byte(v3token{
+		v2token:   "v2tokengoeshere",
 		sessionID: session.ID,
 		filterIDs: []string{},
-	}.String())
-	w.WriteHeader(200)
-	w.Write(v3res)
+	}.String()))
 }
 
-func (h *SyncV3Handler) accumulate(res *SyncV2Response) {
-	for roomID, roomData := range res.Rooms.Join {
-		if len(roomData.State.Events) > 0 {
-			err := h.Accumulator.Initialise(roomID, roomData.State.Events)
-			if err != nil {
-				log.Err(err).Str("room_id", roomID).Int("num_state_events", len(roomData.State.Events)).Msg("Accumulator.Initialise failed")
-			}
-		}
-		err := h.Accumulator.Accumulate(roomID, roomData.Timeline.Events)
-		if err != nil {
-			log.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Accumulator.Accumulate failed")
-		}
-	}
-	log.Info().Int("num_rooms", len(res.Rooms.Join)).Msg("accumulated data")
+// ensurePolling makes sure there is a poller for this device, making one if need be.
+// Blocks until at least 1 sync is done if and only if the poller was just created.
+// This ensures that calls to the database will return data.
+func (h *SyncV3Handler) ensurePolling(authHeader, deviceID, since string) {
+	h.pollerMu.Lock()
+	poller, ok := h.Pollers[deviceID]
+	// either no poller exists or it did but it died
+	if ok && !poller.Terminated {
+		h.pollerMu.Unlock()
+		return
+	}
+	// replace the poller
+	poller = v2.NewPoller(authHeader, deviceID, h.V2, h.Accumulator)
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go poller.Poll(since, func() {
+		wg.Done()
+	})
+	h.Pollers[deviceID] = poller
+	h.pollerMu.Unlock()
+	wg.Wait()
 }
 
 type jsonError struct {