state: rewrite SELECT ... IN to be SELECT ... ANY

Using ANY allows us to pass a single array parameter containing any number
of entries, which sidesteps the Postgres limit of 65535 parameters per
statement. Without this, large rooms like Matrix HQ, whose current state
exceeds 65535 events, will not be stored correctly.

Add torture test to the events table to assert that we can query >65535
events.
This commit is contained in:
Kegan Dougal 2022-02-22 17:38:18 +00:00
parent 064095c899
commit 10f94336ba
5 changed files with 73 additions and 32 deletions

View File

@ -50,7 +50,7 @@ func (t *AccountDataTable) Insert(txn *sqlx.Tx, accDatas []AccountData) ([]Accou
for _, ad := range keys {
dedupedAccountData = append(dedupedAccountData, *ad)
}
chunks := sqlutil.Chunkify(4, 65535, AccountDataChunker(dedupedAccountData))
chunks := sqlutil.Chunkify(4, MaxPostgresParameters, AccountDataChunker(dedupedAccountData))
for _, chunk := range chunks {
_, err := txn.NamedExec(`
INSERT INTO syncv3_account_data (user_id, room_id, type, data)

View File

@ -6,6 +6,7 @@ import (
"math"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/matrix-org/sync-v3/internal"
"github.com/matrix-org/sync-v3/sqlutil"
"github.com/tidwall/gjson"
@ -135,7 +136,7 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
events[i] = ev
}
chunks := sqlutil.Chunkify(6, 65535, EventChunker(events))
chunks := sqlutil.Chunkify(6, MaxPostgresParameters, EventChunker(events))
var rowsAffected int64
for _, chunk := range chunks {
result, err := txn.NamedExec(`
@ -155,26 +156,15 @@ func (t *EventTable) Insert(txn *sqlx.Tx, events []Event) (int, error) {
// select events in a list of nids or ids, depending on the query. Provides flexibility to query on NID or ID, as well as
// the ability to pull stripped events or normal events
func (t *EventTable) selectIn(txn *sqlx.Tx, numWanted int, queryStr string, args ...interface{}) (events []Event, err error) {
query, args, err := sqlx.In(
queryStr, args...,
)
func (t *EventTable) selectAny(txn *sqlx.Tx, numWanted int, queryStr string, pqArray interface{}) (events []Event, err error) {
if txn != nil {
query = txn.Rebind(query)
if err != nil {
return nil, err
}
err = txn.Select(&events, query, args...)
err = txn.Select(&events, queryStr, pqArray)
} else {
query = t.db.Rebind(query)
if err != nil {
return nil, err
}
err = t.db.Select(&events, query, args...)
err = t.db.Select(&events, queryStr, pqArray)
}
if numWanted > 0 {
if numWanted != len(events) {
return nil, fmt.Errorf("Events table query %s got %d events wanted %d", queryStr, len(events), numWanted)
return nil, fmt.Errorf("Events table query %s got %d events wanted %d. err=%s", queryStr, len(events), numWanted, err)
}
}
return
@ -185,9 +175,9 @@ func (t *EventTable) SelectByNIDs(txn *sqlx.Tx, verifyAll bool, nids []int64) (e
if verifyAll {
wanted = len(nids)
}
return t.selectIn(txn, wanted, `
return t.selectAny(txn, wanted, `
SELECT event_nid, event_id, event, event_type, state_key, room_id, before_state_snapshot_id, membership FROM syncv3_events
WHERE event_nid IN (?) ORDER BY event_nid ASC;`, nids)
WHERE event_nid = ANY ($1) ORDER BY event_nid ASC;`, pq.Int64Array(nids))
}
func (t *EventTable) SelectByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (events []Event, err error) {
@ -195,18 +185,15 @@ func (t *EventTable) SelectByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (ev
if verifyAll {
wanted = len(ids)
}
return t.selectIn(txn, wanted, `
return t.selectAny(txn, wanted, `
SELECT event_nid, event_id, event, event_type, state_key, room_id, before_state_snapshot_id, membership FROM syncv3_events
WHERE event_id IN (?) ORDER BY event_nid ASC;`, ids)
WHERE event_id = ANY ($1) ORDER BY event_nid ASC;`, pq.StringArray(ids))
}
// SelectNIDsByIDs looks up the event NIDs for the given event IDs inside txn.
// The NIDs are returned in ascending order. Only rows which exist in the table
// are returned, so the result may be shorter than ids.
func (t *EventTable) SelectNIDsByIDs(txn *sqlx.Tx, ids []string) (nids []int64, err error) {
	// Select NIDs using a single parameter which is a string array, rather than
	// expanding to one bind parameter per ID with IN (?), which fails once the
	// number of IDs exceeds the Postgres parameter limit.
	// https://stackoverflow.com/questions/52712022/what-is-the-most-performant-way-to-rewrite-a-large-in-clause
	err = txn.Select(&nids, "SELECT event_nid FROM syncv3_events WHERE event_id = ANY ($1) ORDER BY event_nid ASC;", pq.StringArray(ids))
	return nids, err
}
@ -216,9 +203,9 @@ func (t *EventTable) SelectStrippedEventsByNIDs(txn *sqlx.Tx, verifyAll bool, ni
wanted = len(nids)
}
// don't include the 'event' column
return t.selectIn(txn, wanted, `
return t.selectAny(txn, wanted, `
SELECT event_nid, event_id, event_type, state_key, room_id, before_state_snapshot_id FROM syncv3_events
WHERE event_nid IN (?) ORDER BY event_nid ASC;`, nids)
WHERE event_nid = ANY ($1) ORDER BY event_nid ASC;`, pq.Int64Array(nids))
}
func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, verifyAll bool, ids []string) (StrippedEvents, error) {
@ -227,9 +214,9 @@ func (t *EventTable) SelectStrippedEventsByIDs(txn *sqlx.Tx, verifyAll bool, ids
wanted = len(ids)
}
// don't include the 'event' column
return t.selectIn(txn, wanted, `
return t.selectAny(txn, wanted, `
SELECT event_nid, event_id, event_type, state_key, room_id, before_state_snapshot_id FROM syncv3_events
WHERE event_id IN (?) ORDER BY event_nid ASC;`, ids)
WHERE event_id = ANY ($1) ORDER BY event_nid ASC;`, pq.StringArray(ids))
}

View File

@ -2,6 +2,7 @@ package state
import (
"bytes"
"fmt"
"testing"
"github.com/jmoiron/sqlx"
@ -642,3 +643,53 @@ func TestEventTableSelectEventsWithTypeStateKey(t *testing.T) {
t.Fatalf("SelectEventsWithTypeStateKeyInRooms missed rooms: %v", wantRooms)
}
}
// TestTortureEventTable does a massive insert/select for event IDs (more than
// MaxPostgresParameters of them) and ensures both operations work.
func TestTortureEventTable(t *testing.T) {
	db, err := sqlx.Open("postgres", postgresConnectionString)
	if err != nil {
		t.Fatalf("failed to open SQL db: %s", err)
	}
	defer db.Close()

	insertTxn, err := db.Beginx()
	if err != nil {
		t.Fatalf("failed to start txn: %s", err)
	}
	// Ensures the transaction is released if the test fails before Commit.
	// After a successful Commit this only returns sql.ErrTxDone, which is ignored.
	defer insertTxn.Rollback()

	roomID := "!0:localhost"
	table := NewEventTable(db)

	// Insert a ton of events: strictly more than fit in one statement's parameters.
	events := make([]Event, 10+MaxPostgresParameters)
	eventIDs := make([]string, len(events))
	for i := range events {
		events[i] = Event{
			ID:     fmt.Sprintf("$%d", i),
			Type:   "my_type",
			RoomID: roomID,
			JSON:   []byte(fmt.Sprintf(`{"type":"my_type","content":{"data":%d}}`, i)),
		}
		eventIDs[i] = events[i].ID
	}
	n, err := table.Insert(insertTxn, events)
	if err != nil {
		t.Fatalf("failed to insert %d events: %s", len(events), err)
	}
	if n != len(events) {
		t.Fatalf("only inserted %d/%d events", n, len(events))
	}
	if err = insertTxn.Commit(); err != nil {
		t.Fatalf("failed to commit insert: %s", err)
	}

	// Now do a massive select
	selectTxn, err := db.Beginx()
	if err != nil {
		t.Fatalf("failed to start txn: %s", err)
	}
	// The select transaction is read-only, so rolling it back is sufficient.
	defer selectTxn.Rollback()

	nids, err := table.SelectNIDsByIDs(selectTxn, eventIDs)
	if err != nil {
		t.Fatalf("SelectNIDsByIDs: %s", err)
	}
	if len(nids) != len(eventIDs) {
		t.Fatalf("failed to retrieve nids for ids, got %d/%d", len(nids), len(eventIDs))
	}
}

View File

@ -11,6 +11,9 @@ import (
"github.com/tidwall/gjson"
)
// MaxPostgresParameters is the max number of parameters in a single SQL command.
// It is passed to sqlutil.Chunkify as the upper bound when splitting bulk inserts.
const MaxPostgresParameters = 65535
type Storage struct {
accumulator *Accumulator
EventsTable *EventTable

View File

@ -94,7 +94,7 @@ func (t *ToDeviceTable) InsertMessages(deviceID string, msgs []json.RawMessage)
}
}
chunks := sqlutil.Chunkify(4, 65535, ToDeviceRowChunker(rows))
chunks := sqlutil.Chunkify(4, MaxPostgresParameters, ToDeviceRowChunker(rows))
for _, chunk := range chunks {
result, err := t.db.NamedQuery(`INSERT INTO syncv3_to_device_messages (device_id, message, event_type, sender)
VALUES (:device_id, :message, :event_type, :sender) RETURNING position`, chunk)