Use a CURSOR

This commit is contained in:
Kegan Dougal 2024-05-17 14:48:08 +01:00
parent b383ed0d82
commit b6f2f9d273
2 changed files with 127 additions and 33 deletions

View File

@ -3,11 +3,13 @@ package migrations
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/state"
"github.com/pressly/goose/v3"
)
@ -52,13 +54,7 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
target_state SMALLINT NOT NULL,
bucket SMALLINT NOT NULL,
UNIQUE(user_id, device_id, target_user_id, bucket)
);
-- make an index so selecting all the rows is faster
CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket);
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
-- change the data, not anything indexed like the id)
ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90);
`)
);`)
if err != nil {
return err
}
@ -69,12 +65,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
}
logger.Info().Int("count", count).Msg("transferring device list data for devices")
// scan for existing CBOR (streaming as the CBOR can be large) and for each row:
rows, err := tx.Query(`SELECT user_id, device_id, data FROM syncv3_device_data`)
// scan for existing CBOR, streaming it with a cursor as the CBOR can be large
_, err = tx.Exec(`DECLARE device_data_migration_cursor CURSOR FOR SELECT user_id, device_id, data FROM syncv3_device_data`)
if err != nil {
return err
}
defer rows.Close()
defer tx.Exec("CLOSE device_data_migration_cursor")
var userID string
var deviceID string
var data []byte
@ -82,42 +78,73 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
updateFrequency := time.Second * 2
lastUpdate := time.Now()
i := 0
for rows.Next() {
for {
// logging
i++
if time.Since(lastUpdate) > updateFrequency {
logger.Info().Msgf("%d/%d process device list data", i, count)
lastUpdate = time.Now()
}
// * deserialise the CBOR
if err := rows.Scan(&userID, &deviceID, &data); err != nil {
if err := tx.QueryRow(
`FETCH NEXT FROM device_data_migration_cursor`,
).Scan(&userID, &deviceID, &data); err != nil {
if err == sql.ErrNoRows {
// End of rows.
break
}
return err
}
// * deserialise the CBOR
result, err := deserialiseCBOR(data)
if err != nil {
return err
}
// * transfer the device lists to the new device lists table
// uses a bulk copy that lib/pq supports
stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates", "user_id", "device_id", "target_user_id", "target_state", "bucket"))
if err != nil {
return err
}
var deviceListRows []state.DeviceListRow
for targetUser, targetState := range result.DeviceLists.New {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketNew); err != nil {
return err
}
deviceListRows = append(deviceListRows, state.DeviceListRow{
UserID: userID,
DeviceID: deviceID,
TargetUserID: targetUser,
TargetState: targetState,
Bucket: state.BucketNew,
})
}
for targetUser, targetState := range result.DeviceLists.Sent {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketSent); err != nil {
return err
deviceListRows = append(deviceListRows, state.DeviceListRow{
UserID: userID,
DeviceID: deviceID,
TargetUserID: targetUser,
TargetState: targetState,
Bucket: state.BucketSent,
})
}
chunks := sqlutil.Chunkify(5, state.MaxPostgresParameters, state.DeviceListChunker(deviceListRows))
for _, chunk := range chunks {
var placeholders []string
var vals []interface{}
listChunk := chunk.(state.DeviceListChunker)
for i, deviceListRow := range listChunk {
placeholders = append(placeholders, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d)",
i*5+1,
i*5+2,
i*5+3,
i*5+4,
i*5+5,
))
vals = append(vals, deviceListRow.UserID, deviceListRow.DeviceID, deviceListRow.TargetUserID, deviceListRow.TargetState, deviceListRow.Bucket)
}
query := fmt.Sprintf(
`INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES %s`,
strings.Join(placeholders, ","),
)
_, err = tx.ExecContext(ctx, query, vals...)
if err != nil {
return fmt.Errorf("failed to bulk insert: %s", err)
}
}
if _, err = stmt.Exec(); err != nil {
return err
}
if err = stmt.Close(); err != nil {
return err
}
// * delete the device lists from the CBOR and update
@ -129,12 +156,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
if err != nil {
return err
}
_, err = tx.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
_, err = tx.ExecContext(ctx, `UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
if err != nil {
return err
}
}
return rows.Err()
return nil
}
func downDeviceListTable(ctx context.Context, tx *sql.Tx) error {

View File

@ -1,7 +1,74 @@
package migrations
import "testing"
import (
"context"
"testing"
"github.com/fxamacker/cbor/v2"
)
// TestDeviceListTableMigration checks that upDeviceListTable transfers the
// device list entries stored in the old CBOR blobs into the new
// syncv3_device_list_updates table without losing any rows.
func TestDeviceListTableMigration(t *testing.T) {
	ctx := context.Background()
	db, closeDB := connectToDB(t)
	defer closeDB()
	// Create the table in the old format (data is a CBOR-encoded BYTEA blob)
	// and insert some data: we'll make sure that this data is preserved
	// after migrating.
	_, err := db.Exec(`CREATE TABLE IF NOT EXISTS syncv3_device_data (
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	data BYTEA NOT NULL,
	UNIQUE(user_id, device_id)
);`)
	if err != nil {
		t.Fatalf("failed to create table: %s", err)
	}
	// insert old data
	rowData := []OldDeviceData{
		{
			DeviceLists: OldDeviceLists{
				New:  map[string]int{"@bob:localhost": 2},
				Sent: map[string]int{},
			},
			ChangedBits:      2,
			OTKCounts:        map[string]int{"bar": 42},
			FallbackKeyTypes: []string{"narp"},
			DeviceID:         "ALICE",
			UserID:           "@alice:localhost",
		},
		{
			DeviceLists: OldDeviceLists{
				New:  map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
				Sent: map[string]int{"@sent:localhost": 1},
			},
			OTKCounts:        map[string]int{"foo": 100},
			FallbackKeyTypes: []string{"yep"},
			DeviceID:         "BOB",
			UserID:           "@bob:localhost",
		},
	}
	for _, data := range rowData {
		blob, err := cbor.Marshal(data)
		if err != nil {
			t.Fatal(err)
		}
		_, err = db.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, data.UserID, data.DeviceID, blob)
		if err != nil {
			t.Fatal(err)
		}
	}
	// now migrate and ensure we didn't lose any data
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	err = upDeviceListTable(ctx, tx)
	if err != nil {
		t.Fatal(err)
	}
	// the migration only takes effect once committed, so check the error
	if err := tx.Commit(); err != nil {
		t.Fatalf("failed to commit migration: %s", err)
	}
	// Verify the device list entries were transferred to the new table:
	// ALICE has 1 entry in New; BOB has 2 in New plus 1 in Sent = 3.
	wantCounts := map[string]int{
		"ALICE": 1,
		"BOB":   3,
	}
	for deviceID, wantCount := range wantCounts {
		var gotCount int
		err := db.QueryRowContext(ctx,
			`SELECT COUNT(*) FROM syncv3_device_list_updates WHERE device_id=$1`, deviceID,
		).Scan(&gotCount)
		if err != nil {
			t.Fatalf("failed to count device list rows for device %s: %s", deviceID, err)
		}
		if gotCount != wantCount {
			t.Errorf("device %s: got %d device list rows, want %d", deviceID, gotCount, wantCount)
		}
	}
}