mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
state: rewrite SELECT ... IN to be SELECT ... ANY
Using ANY allows us to pass a single array parameter containing arbitrarily many entries, which bypasses the Postgres limit of 65535 bind parameters per statement. Without this, large rooms like Matrix HQ, whose current state exceeds 65535 events, will not be stored correctly. Add a torture test for the events table to assert that we can query more than 65535 events.
This commit is contained in:
parent
064095c899
commit
10f94336ba
@ -50,7 +50,7 @@ func (t *AccountDataTable) Insert(txn *sqlx.Tx, accDatas []AccountData) ([]Accou
|
||||
for _, ad := range keys {
|
||||
dedupedAccountData = append(dedupedAccountData, *ad)
|
||||
}
|
||||
chunks := sqlutil.Chunkify(4, 65535, AccountDataChunker(dedupedAccountData))
|
||||
chunks := sqlutil.Chunkify(4, MaxPostgresParameters, AccountDataChunker(dedupedAccountData))
|
||||
for _, chunk := range chunks {
|
||||
_, err := txn.NamedExec(`
|
||||
INSERT INTO syncv3_account_data (user_id, room_id, type, data)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"math"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/sync-v3/internal"
|
||||
"github.com/matrix-org/sync-v3/sqlutil"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -135,7 +136,7 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
|
||||
|
||||
events[i] = ev
|
||||
}
|
||||
chunks := sqlutil.Chunkify(6, 65535, EventChunker(events))
|
||||
chunks := sqlutil.Chunkify(6, MaxPostgresParameters, EventChunker(events))
|
||||
var rowsAffected int64
|
||||
for _, chunk := range chunks {
|
||||
result, err := txn.NamedExec(`
|
||||
@ -155,26 +156,15 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
|
||||
|
||||
// select events in a list of nids or ids, depending on the query. Provides flexibility to query on NID or ID, as well as
|
||||
// the ability to pull stripped events or normal events
|
||||
func (t *EventTable) selectIn(txn *sqlx.Tx, numWanted int, queryStr string, args ...interface{}) (events []Event, err error) {
|
||||
query, args, err := sqlx.In(
|
||||
queryStr, args...,
|
||||
)
|
||||
func (t *EventTable) selectAny(txn *sqlx.Tx, numWanted int, queryStr string, pqArray interface{}) (events []Event, err error) {
|
||||
if txn != nil {
|
||||
query = txn.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = txn.Select(&events, query, args...)
|
||||
err = txn.Select(&events, queryStr, pqArray)
|
||||
} else {
|
||||
query = t.db.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = t.db.Select(&events, query, args...)
|
||||
err = t.db.Select(&events, queryStr, pqArray)
|
||||
}
|
||||
if numWanted > 0 {
|
||||
if numWanted != len(events) {
|
||||
return nil, fmt.Errorf("Events table query %s got %d events wanted %d", queryStr, len(events), numWanted)
|
||||
return nil, fmt.Errorf("Events table query %s got %d events wanted %d. err=%s", queryStr, len(events), numWanted, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
@ -185,9 +175,9 @@ func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, verifyAll bool, nids []int64) (e
|
||||
if verifyAll {
|
||||
wanted = len(nids)
|
||||
}
|
||||
return t.selectIn(txn, wanted, `
|
||||
return t.selectAny(txn, wanted, `
|
||||
SELECT event_nid, event_id, event, event_type, state_key, room_id, before_state_snapshot_id, membership FROM syncv3_events
|
||||
WHERE event_nid IN (?) ORDER BY event_nid ASC;`, nids)
|
||||
WHERE event_nid = ANY ($1) ORDER BY event_nid ASC;`, pq.Int64Array(nids))
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (events []Event, err error) {
|
||||
@ -195,18 +185,15 @@ func (t *EventTable) SelectByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (ev
|
||||
if verifyAll {
|
||||
wanted = len(ids)
|
||||
}
|
||||
return t.selectIn(txn, wanted, `
|
||||
return t.selectAny(txn, wanted, `
|
||||
SELECT event_nid, event_id, event, event_type, state_key, room_id, before_state_snapshot_id, membership FROM syncv3_events
|
||||
WHERE event_id IN (?) ORDER BY event_nid ASC;`, ids)
|
||||
WHERE event_id = ANY ($1) ORDER BY event_nid ASC;`, pq.StringArray(ids))
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectNIDsByIDs(txn *sqlx.Tx, ids []string) (nids []int64, err error) {
|
||||
query, args, err := sqlx.In("SELECT event_nid FROM syncv3_events WHERE event_id IN (?) ORDER BY event_nid ASC;", ids)
|
||||
query = txn.Rebind(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = txn.Select(&nids, query, args...)
|
||||
// Select NIDs using a single parameter which is a string array
|
||||
// https://stackoverflow.com/questions/52712022/what-is-the-most-performant-way-to-rewrite-a-large-in-clause
|
||||
err = txn.Select(&nids, "SELECT event_nid FROM syncv3_events WHERE event_id = ANY ($1) ORDER BY event_nid ASC;", pq.StringArray(ids))
|
||||
return
|
||||
}
|
||||
|
||||
@ -216,9 +203,9 @@ func (t *EventTable) SelectStrippedEventsByNIDs(txn *sqlx.Tx, verifyAll bool, ni
|
||||
wanted = len(nids)
|
||||
}
|
||||
// don't include the 'event' column
|
||||
return t.selectIn(txn, wanted, `
|
||||
return t.selectAny(txn, wanted, `
|
||||
SELECT event_nid, event_id, event_type, state_key, room_id, before_state_snapshot_id FROM syncv3_events
|
||||
WHERE event_nid IN (?) ORDER BY event_nid ASC;`, nids)
|
||||
WHERE event_nid = ANY ($1) ORDER BY event_nid ASC;`, pq.Int64Array(nids))
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (StrippedEvents, error) {
|
||||
@ -227,9 +214,9 @@ func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, verifyAll bool, ids
|
||||
wanted = len(ids)
|
||||
}
|
||||
// don't include the 'event' column
|
||||
return t.selectIn(txn, wanted, `
|
||||
return t.selectAny(txn, wanted, `
|
||||
SELECT event_nid, event_id, event_type, state_key, room_id, before_state_snapshot_id FROM syncv3_events
|
||||
WHERE event_id IN (?) ORDER BY event_nid ASC;`, ids)
|
||||
WHERE event_id = ANY ($1) ORDER BY event_nid ASC;`, pq.StringArray(ids))
|
||||
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package state
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
@ -642,3 +643,53 @@ func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) {
|
||||
t.Fatalf("SelectEventsWithTypeStateKeyInRooms missed rooms: %v", wantRooms)
|
||||
}
|
||||
}
|
||||
|
||||
// Do a massive insert/select for event IDs (greater than postgres limit) and ensure it works.
|
||||
func TestTortureEventTable(t *testing.T) {
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
roomID := "!0:localhost"
|
||||
table := NewEventTable(db)
|
||||
// Insert a ton of events
|
||||
events := make([]Event, 10+MaxPostgresParameters)
|
||||
eventIDs := make([]string, len(events))
|
||||
for i := 0; i < len(events); i++ {
|
||||
events[i] = Event{
|
||||
ID: fmt.Sprintf("$%d", i),
|
||||
Type: "my_type",
|
||||
RoomID: roomID,
|
||||
JSON: []byte(fmt.Sprintf(`{"type":"my_type","content":{"data":%d}}`, i)),
|
||||
}
|
||||
eventIDs[i] = events[i].ID
|
||||
}
|
||||
n, err := table.Insert(txn, events)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to insert %d events: %s", len(events), err)
|
||||
}
|
||||
if n != len(events) {
|
||||
t.Fatalf("only inserted %d/%d events", n, len(events))
|
||||
}
|
||||
if err = txn.Commit(); err != nil {
|
||||
t.Fatalf("failed to commit insert")
|
||||
}
|
||||
|
||||
// Now do a massive select
|
||||
txn, err = db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
nids, err := table.SelectNIDsByIDs(txn, eventIDs)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectNIDsByIDs: %s", err)
|
||||
}
|
||||
if len(nids) != len(eventIDs) {
|
||||
t.Fatalf("failed to retrieve nids for ids, got %d/%d", len(nids), len(eventIDs))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -11,6 +11,9 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MaxPostgresParameters is the maximum number of bind parameters Postgres accepts in a single SQL statement.
|
||||
const MaxPostgresParameters = 65535
|
||||
|
||||
type Storage struct {
|
||||
accumulator *Accumulator
|
||||
EventsTable *EventTable
|
||||
|
@ -94,7 +94,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sqlutil.Chunkify(4, 65535, ToDeviceRowChunker(rows))
|
||||
chunks := sqlutil.Chunkify(4, MaxPostgresParameters, ToDeviceRowChunker(rows))
|
||||
for _, chunk := range chunks {
|
||||
result, err := t.db.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender)
|
||||
VALUES (:device_id, :message, :event_type, :sender) RETURNING position`, chunk)
|
||||
|
Loading…
x
Reference in New Issue
Block a user