Add v2 polling; chunk event insertion to support Matrix HQ

Kegan Dougal 2021-06-04 13:02:28 +01:00
parent 502274b475
commit 159530bed1
6 changed files with 261 additions and 65 deletions


@@ -153,7 +153,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
}
numNew, err := a.eventsTable.Insert(txn, events)
if err != nil {
return err
return fmt.Errorf("failed to insert events: %w", err)
}
if numNew == 0 {
// we don't have a current snapshot for this room but yet no events are new,
@@ -171,7 +171,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
}
nids, err := a.eventsTable.SelectNIDsByIDs(txn, eventIDs)
if err != nil {
return err
return fmt.Errorf("failed to select NIDs for inserted events: %w", err)
}
// Make a current snapshot
@@ -181,7 +181,7 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
}
err = a.snapshotTable.Insert(txn, snapshot)
if err != nil {
return err
return fmt.Errorf("failed to insert snapshot: %w", err)
}
// Increment the ref counter
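
The error wrapping added above uses the %w verb, which keeps the underlying database error in the chain while adding context about which step failed. A minimal illustration of why that matters for callers; the sentinel error here is a hypothetical stand-in, not anything from this repository:

package main

import (
	"errors"
	"fmt"
)

// errConstraint is a hypothetical stand-in for an underlying database error.
var errConstraint = errors.New("constraint violation")

func insertSnapshot() error {
	// Wrapping with %w keeps the original error in the chain.
	return fmt.Errorf("failed to insert snapshot: %w", errConstraint)
}

func main() {
	err := insertSnapshot()
	// errors.Is walks the wrap chain, so the added context does not break
	// error matching for callers.
	fmt.Println(errors.Is(err, errConstraint)) // true
	fmt.Println(err)                           // failed to insert snapshot: constraint violation
}
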
@@ -199,14 +199,14 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) error {
// received from the server.
//
// This function does several things:
// - It ensures all events are persisted in the database. This is shared amongst users.
// - If all events have been stored before, then it short circuits and returns.
// This is because we must have already processed this part of the timeline in order for the event
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It checks if there are outstanding references for the previous snapshot, and if not, removes the old snapshot from the database.
// References are made when clients have synced up to a given snapshot (hence may paginate at that point).
// The server itself also holds a ref to the current state, which is then moved to the new current state.
func (a *Accumulator) Accumulate(roomID string, timeline []json.RawMessage) error {
if len(timeline) == 0 {
return nil
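
The doc comment above lays out the whole Accumulate flow; the sketch below restates it as code purely to make the branching easier to follow. Every helper in it is a hypothetical stand-in for the table operations the real Accumulator performs inside a transaction, not a function from this repository:

package main

import "encoding/json"

// Hypothetical stand-ins for the events/snapshot table operations; they exist
// only to make the control flow readable.
func persistEvents(timeline []json.RawMessage) (int, error) { return len(timeline), nil }
func containsStateEvents(timeline []json.RawMessage) bool   { return false }
func makeNewSnapshot(roomID string) error                   { return nil }
func dropOldSnapshotIfUnreferenced(roomID string)           {}

// accumulateSketch mirrors the steps listed in the Accumulate doc comment.
func accumulateSketch(roomID string, timeline []json.RawMessage) error {
	if len(timeline) == 0 {
		return nil
	}
	// Persist events; storage is shared amongst users and duplicates are ignored.
	numNew, err := persistEvents(timeline)
	if err != nil {
		return err
	}
	if numNew == 0 {
		// This part of the timeline has already been processed and the sync
		// stream is linearised, so short circuit.
		return nil
	}
	if containsStateEvents(timeline) {
		// State events in the timeline move current state forward: make a new
		// snapshot and move the server's ref from the old snapshot to it.
		if err := makeNewSnapshot(roomID); err != nil {
			return err
		}
		// Drop the previous snapshot if no client still holds a reference.
		dropOldSnapshotIfUnreferenced(roomID)
	}
	return nil
}

func main() {
	_ = accumulateSketch("!room:example.org", nil)
}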


@@ -75,13 +75,21 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
}
events[i] = ev
}
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, events)
if err != nil {
return 0, err
chunks := chunkify(3, 65535, events)
var rowsAffected int64
for _, chunk := range chunks {
result, err := txn.NamedExec(`INSERT INTO syncv3_events (event_id, room_id, event)
VALUES (:event_id, :room_id, :event) ON CONFLICT (event_id) DO NOTHING`, chunk)
if err != nil {
return 0, err
}
ra, err := result.RowsAffected()
if err != nil {
return 0, err
}
rowsAffected += ra
}
ra, err := result.RowsAffected()
return int(ra), err
return int(rowsAffected), nil
}
func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, nids []int64) (events []Event, err error) {
@@ -143,3 +151,30 @@ func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclu
)
return events, err
}
// chunkify will break up things to be inserted based on the number of params in the statement.
// It is required because postgres has a limit on the number of params in a single statement (65535).
// Inserting events using NamedExec involves 3n params (n=number of events), meaning it's easy to hit
// the limit in rooms like Matrix HQ. This function breaks up the events into chunks which can be
// batch inserted in multiple statements. Without this, you'll see errors like:
// "pq: got 95331 parameters but PostgreSQL only supports 65535 parameters"
func chunkify(numParamsPerStmt, maxParamsPerCall int, entries []Event) [][]Event {
// common case, most things are small
if (len(entries) * numParamsPerStmt) <= maxParamsPerCall {
return [][]Event{
entries,
}
}
var chunks [][]Event
// work out how many events can fit in a chunk
numEntriesPerChunk := (maxParamsPerCall / numParamsPerStmt)
for i := 0; i < len(entries); i += numEntriesPerChunk {
endIndex := i + numEntriesPerChunk
if endIndex > len(entries) {
endIndex = len(entries)
}
chunks = append(chunks, entries[i:endIndex])
}
return chunks
}
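
To make the arithmetic concrete: with 3 parameters per event and PostgreSQL's 65535-parameter cap, a chunk holds at most 65535 / 3 = 21845 events, so the 95331-parameter batch from the error message above (31777 events) splits into two statements. A tiny worked example (numbers only, not code from this commit):

package main

import "fmt"

func main() {
	const numParamsPerStmt = 3 // event_id, room_id, event
	const maxParamsPerCall = 65535

	// Largest number of events a single INSERT can carry.
	numEntriesPerChunk := maxParamsPerCall / numParamsPerStmt
	fmt.Println(numEntriesPerChunk) // 21845

	// The failing batch from the error message: 95331 params / 3 = 31777 events.
	events := 95331 / numParamsPerStmt
	chunks := (events + numEntriesPerChunk - 1) / numEntriesPerChunk
	fmt.Println(events, chunks) // 31777 2
}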


@@ -249,3 +249,71 @@ func TestEventTableSelectEventsBetween(t *testing.T) {
}
})
}
func TestChunkify(t *testing.T) {
// Make 100 dummy events
events := make([]Event, 100)
for i := 0; i < len(events); i++ {
events[i] = Event{
NID: i,
}
}
testCases := []struct {
name string
numParamsPerStmt int
maxParamsPerCall int
chunkSizes []int // length = number of chunks wanted, ints = events in that chunk
}{
{
name: "below chunk limit returns 1 chunk",
numParamsPerStmt: 3,
maxParamsPerCall: 400,
chunkSizes: []int{100},
},
{
name: "just above chunk limit returns 2 chunks",
numParamsPerStmt: 3,
maxParamsPerCall: 297,
chunkSizes: []int{99, 1},
},
{
name: "way above chunk limit returns many chunks",
numParamsPerStmt: 3,
maxParamsPerCall: 30,
chunkSizes: []int{10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
},
{
name: "fractional division rounds down",
numParamsPerStmt: 3,
maxParamsPerCall: 298,
chunkSizes: []int{99, 1},
},
{
name: "fractional division rounds down",
numParamsPerStmt: 3,
maxParamsPerCall: 299,
chunkSizes: []int{99, 1},
},
}
for _, tc := range testCases {
testCase := tc
t.Run(testCase.name, func(t *testing.T) {
chunks := chunkify(testCase.numParamsPerStmt, testCase.maxParamsPerCall, events)
if len(chunks) != len(testCase.chunkSizes) {
t.Fatalf("got %d chunks, want %d", len(chunks), len(testCase.chunkSizes))
}
eventNID := 0
for i := 0; i < len(chunks); i++ {
if len(chunks[i]) != testCase.chunkSizes[i] {
t.Errorf("chunk %d got %d elements, want %d", i, len(chunks[i]), testCase.chunkSizes[i])
}
for j, ev := range chunks[i] {
if ev.NID != eventNID {
t.Errorf("chunk %d got wrong event in position %d: got NID %d want NID %d", i, j, ev.NID, eventNID)
}
eventNID += 1
}
}
})
}
}


@@ -1,4 +1,4 @@
package syncv3
package v2
import (
"encoding/json"
@@ -8,40 +8,44 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
type V2 struct {
// Client represents a Sync v2 Client.
// One client can be shared among many users.
type Client struct {
Client *http.Client
DestinationServer string
}
func (v *V2) DoSyncV2(authHeader, since string) (*SyncV2Response, error) {
qps := ""
// DoSyncV2 performs a sync v2 request. Returns the sync response and the response status code
// or an error
func (v *Client) DoSyncV2(authHeader, since string) (*SyncResponse, int, error) {
qps := "?timeout=30000"
if since != "" {
qps = "?since=" + since
qps += "&since=" + since
}
req, err := http.NewRequest(
"GET", v.DestinationServer+"/_matrix/client/r0/sync"+qps, nil,
)
req.Header.Set("Authorization", authHeader)
if err != nil {
return nil, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
}
res, err := v.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("DoSyncV2: request failed: %w", err)
return nil, 0, fmt.Errorf("DoSyncV2: request failed: %w", err)
}
switch res.StatusCode {
case 200:
var svr SyncV2Response
var svr SyncResponse
if err := json.NewDecoder(res.Body).Decode(&svr); err != nil {
return nil, fmt.Errorf("DoSyncV2: response body decode JSON failed: %w", err)
return nil, 0, fmt.Errorf("DoSyncV2: response body decode JSON failed: %w", err)
}
return &svr, nil
return &svr, 200, nil
default:
return nil, fmt.Errorf("DoSyncV2: response returned %s", res.Status)
return nil, res.StatusCode, fmt.Errorf("DoSyncV2: response returned %s", res.Status)
}
}
type SyncV2Response struct {
type SyncResponse struct {
NextBatch string `json:"next_batch"`
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
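
For reference, a minimal sketch of driving the new (response, status code, error) return shape from a caller: retry on transient failures, stop on 401. The homeserver URL and access token are placeholders, and this mirrors what the poller below does:

package main

import (
	"fmt"
	"net/http"
	"time"

	v2 "github.com/matrix-org/sync-v3/v2"
)

func main() {
	client := &v2.Client{
		Client:            &http.Client{Timeout: 120 * time.Second},
		DestinationServer: "https://matrix-client.example.org", // placeholder homeserver
	}
	since := ""
	for {
		resp, statusCode, err := client.DoSyncV2("Bearer ACCESS_TOKEN", since)
		if err != nil {
			if statusCode == 401 {
				return // access token invalidated: give up
			}
			time.Sleep(2 * time.Second) // transient failure: back off and retry
			continue
		}
		fmt.Println("next_batch:", resp.NextBatch)
		since = resp.NextBatch
	}
}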

v2/poller.go (new file)

@@ -0,0 +1,90 @@
package v2
import (
"math"
"os"
"time"
"github.com/matrix-org/sync-v3/state"
"github.com/rs/zerolog"
)
// Poller can automatically poll the sync v2 endpoint and accumulate the responses in the accumulator
type Poller struct {
AuthorizationHeader string
DeviceID string
Client *Client
Accumulator *state.Accumulator
// flag set to true when poll() returns due to expired access tokens
Terminated bool
logger zerolog.Logger
}
func NewPoller(authHeader, deviceID string, client *Client, accumulator *state.Accumulator) *Poller {
return &Poller{
AuthorizationHeader: authHeader,
DeviceID: deviceID,
Client: client,
Accumulator: accumulator,
Terminated: false,
logger: zerolog.New(os.Stdout).With().Timestamp().Logger().With().Str("device", deviceID).Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
}),
}
}
// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
// Returns if the access token gets invalidated. Invokes the callback on first success.
func (p *Poller) Poll(since string, callback func()) {
p.logger.Info().Msg("v2 poll loop started")
failCount := 0
firstTime := true
for {
if failCount > 0 {
waitTime := time.Duration(math.Pow(2, float64(failCount))) * time.Second
p.logger.Warn().Str("duration", waitTime.String()).Msg("waiting before next poll")
time.Sleep(waitTime)
}
p.logger.Info().Str("since", since).Msg("requesting data")
resp, statusCode, err := p.Client.DoSyncV2(p.AuthorizationHeader, since)
if err != nil {
// check if temporary
if statusCode != 401 {
p.logger.Warn().Int("code", statusCode).Err(err).Msg("sync v2 poll returned temporary error")
failCount += 1
continue
} else {
p.logger.Warn().Msg("access token has been invalidated, terminating loop")
p.Terminated = true
return
}
}
failCount = 0
p.accumulate(resp)
since = resp.NextBatch
if firstTime {
firstTime = false
callback()
}
}
}
func (p *Poller) accumulate(res *SyncResponse) {
if len(res.Rooms.Join) == 0 {
return
}
for roomID, roomData := range res.Rooms.Join {
if len(roomData.State.Events) > 0 {
err := p.Accumulator.Initialise(roomID, roomData.State.Events)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Int("num_state_events", len(roomData.State.Events)).Msg("Accumulator.Initialise failed")
}
}
err := p.Accumulator.Accumulate(roomID, roomData.Timeline.Events)
if err != nil {
p.logger.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Accumulator.Accumulate failed")
}
}
p.logger.Info().Int("num_rooms", len(res.Rooms.Join)).Msg("accumulated data")
}
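
A short usage sketch for the new poller, assuming a Postgres-backed accumulator is available (the DSN, homeserver URL and token are placeholders). This is essentially what ensurePolling in v3.go further down does, including blocking on the first successful sync via the callback:

package main

import (
	"net/http"
	"sync"
	"time"

	"github.com/matrix-org/sync-v3/state"
	v2 "github.com/matrix-org/sync-v3/v2"
)

func main() {
	client := &v2.Client{
		Client:            &http.Client{Timeout: 120 * time.Second},
		DestinationServer: "https://matrix-client.example.org", // placeholder
	}
	accumulator := state.NewAccumulator("user=postgres dbname=syncv3 sslmode=disable") // placeholder DSN
	poller := v2.NewPoller("Bearer ACCESS_TOKEN", "DEVICEID", client, accumulator)

	// Poll blocks forever, so run it in a goroutine and use the callback to
	// wait for the first successful sync before serving any requests.
	var wg sync.WaitGroup
	wg.Add(1)
	go poller.Poll("", func() {
		wg.Done()
	})
	wg.Wait()
}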

v3.go

@@ -8,11 +8,13 @@ import (
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/justinas/alice"
"github.com/matrix-org/sync-v3/state"
v2 "github.com/matrix-org/sync-v3/v2"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
)
@@ -40,7 +42,7 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
// dependency inject all components together
sh := &SyncV3Handler{
V2: &V2{
V2: &v2.Client{
Client: &http.Client{
Timeout: 120 * time.Second,
},
@@ -48,6 +50,8 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
},
Sessions: NewSessions(postgresDBURI),
Accumulator: state.NewAccumulator(postgresDBURI),
Pollers: make(map[string]*v2.Poller),
pollerMu: &sync.Mutex{},
}
// HTTP path routing
@@ -63,9 +67,12 @@ func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
}
type SyncV3Handler struct {
V2 *V2
V2 *v2.Client
Sessions *Sessions
Accumulator *state.Accumulator
pollerMu *sync.Mutex
Pollers map[string]*v2.Poller // device_id -> poller
}
func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -106,48 +113,40 @@ func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
sincev2 = tokv3.v2token
}
// query for the data
v2res, err := h.V2.DoSyncV2(req.Header.Get("Authorization"), sincev2)
if err != nil {
log.Warn().Err(err).Msg("DoSyncV2 failed")
w.WriteHeader(502)
w.Write(asJSONError(err))
return
}
h.accumulate(v2res)
// make sure we have a poller for this device
h.ensurePolling(req.Header.Get("Authorization"), session.DeviceID, sincev2)
// return data based on filters
v3res, err := json.Marshal(v2res)
if err != nil {
w.WriteHeader(500)
w.Write(asJSONError(err))
return
}
w.Header().Set("X-Matrix-Sync-V3", v3token{
v2token: v2res.NextBatch,
w.WriteHeader(200)
w.Write([]byte(v3token{
v2token: "v2tokengoeshere",
sessionID: session.ID,
filterIDs: []string{},
}.String())
w.WriteHeader(200)
w.Write(v3res)
}.String()))
}
func (h *SyncV3Handler) accumulate(res *SyncV2Response) {
for roomID, roomData := range res.Rooms.Join {
if len(roomData.State.Events) > 0 {
err := h.Accumulator.Initialise(roomID, roomData.State.Events)
if err != nil {
log.Err(err).Str("room_id", roomID).Int("num_state_events", len(roomData.State.Events)).Msg("Accumulator.Initialise failed")
}
}
err := h.Accumulator.Accumulate(roomID, roomData.Timeline.Events)
if err != nil {
log.Err(err).Str("room_id", roomID).Int("num_timeline_events", len(roomData.Timeline.Events)).Msg("Accumulator.Accumulate failed")
}
// ensurePolling makes sure there is a poller for this device, making one if need be.
// Blocks until at least 1 sync is done if and only if the poller was just created.
// This ensures that calls to the database will return data.
func (h *SyncV3Handler) ensurePolling(authHeader, deviceID, since string) {
h.pollerMu.Lock()
poller, ok := h.Pollers[deviceID]
// either no poller exists or it did but it died
if ok && !poller.Terminated {
h.pollerMu.Unlock()
return
}
log.Info().Int("num_rooms", len(res.Rooms.Join)).Msg("accumulated data")
// replace the poller
poller = v2.NewPoller(authHeader, deviceID, h.V2, h.Accumulator)
var wg sync.WaitGroup
wg.Add(1)
go poller.Poll(since, func() {
wg.Done()
})
h.Pollers[deviceID] = poller
h.pollerMu.Unlock()
wg.Wait()
}
type jsonError struct {