Use a CURSOR

This commit is contained in:
Kegan Dougal 2024-05-17 14:48:08 +01:00
parent b383ed0d82
commit b6f2f9d273
2 changed files with 127 additions and 33 deletions

View File

@ -3,11 +3,13 @@ package migrations
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"
"time" "time"
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/state" "github.com/matrix-org/sliding-sync/state"
"github.com/pressly/goose/v3" "github.com/pressly/goose/v3"
) )
@ -52,13 +54,7 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
target_state SMALLINT NOT NULL, target_state SMALLINT NOT NULL,
bucket SMALLINT NOT NULL, bucket SMALLINT NOT NULL,
UNIQUE(user_id, device_id, target_user_id, bucket) UNIQUE(user_id, device_id, target_user_id, bucket)
); );`)
-- make an index so selecting all the rows is faster
CREATE INDEX IF NOT EXISTS syncv3_device_list_updates_bucket_idx ON syncv3_device_list_updates(user_id, device_id, bucket);
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
-- change the data, not anything indexed like the id)
ALTER TABLE syncv3_device_list_updates SET (fillfactor = 90);
`)
if err != nil { if err != nil {
return err return err
} }
@ -69,12 +65,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
} }
logger.Info().Int("count", count).Msg("transferring device list data for devices") logger.Info().Int("count", count).Msg("transferring device list data for devices")
// scan for existing CBOR (streaming as the CBOR can be large) and for each row: // scan for existing CBOR (streaming as the CBOR with cursors as it can be large)
rows, err := tx.Query(`SELECT user_id, device_id, data FROM syncv3_device_data`) _, err = tx.Exec(`DECLARE device_data_migration_cursor CURSOR FOR SELECT user_id, device_id, data FROM syncv3_device_data`)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close() defer tx.Exec("CLOSE device_data_migration_cursor")
var userID string var userID string
var deviceID string var deviceID string
var data []byte var data []byte
@ -82,42 +78,73 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
updateFrequency := time.Second * 2 updateFrequency := time.Second * 2
lastUpdate := time.Now() lastUpdate := time.Now()
i := 0 i := 0
for rows.Next() { for {
// logging
i++ i++
if time.Since(lastUpdate) > updateFrequency { if time.Since(lastUpdate) > updateFrequency {
logger.Info().Msgf("%d/%d process device list data", i, count) logger.Info().Msgf("%d/%d process device list data", i, count)
lastUpdate = time.Now() lastUpdate = time.Now()
} }
// * deserialise the CBOR
if err := rows.Scan(&userID, &deviceID, &data); err != nil { if err := tx.QueryRow(
`FETCH NEXT FROM device_data_migration_cursor`,
).Scan(&userID, &deviceID, &data); err != nil {
if err == sql.ErrNoRows {
// End of rows.
break
}
return err return err
} }
// * deserialise the CBOR
result, err := deserialiseCBOR(data) result, err := deserialiseCBOR(data)
if err != nil { if err != nil {
return err return err
} }
// * transfer the device lists to the new device lists table // * transfer the device lists to the new device lists table
// uses a bulk copy that lib/pq supports var deviceListRows []state.DeviceListRow
stmt, err := tx.Prepare(pq.CopyIn("syncv3_device_list_updates", "user_id", "device_id", "target_user_id", "target_state", "bucket"))
if err != nil {
return err
}
for targetUser, targetState := range result.DeviceLists.New { for targetUser, targetState := range result.DeviceLists.New {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketNew); err != nil { deviceListRows = append(deviceListRows, state.DeviceListRow{
return err UserID: userID,
} DeviceID: deviceID,
TargetUserID: targetUser,
TargetState: targetState,
Bucket: state.BucketNew,
})
} }
for targetUser, targetState := range result.DeviceLists.Sent { for targetUser, targetState := range result.DeviceLists.Sent {
if _, err := stmt.Exec(userID, deviceID, targetUser, targetState, state.BucketSent); err != nil { deviceListRows = append(deviceListRows, state.DeviceListRow{
return err UserID: userID,
DeviceID: deviceID,
TargetUserID: targetUser,
TargetState: targetState,
Bucket: state.BucketSent,
})
}
chunks := sqlutil.Chunkify(5, state.MaxPostgresParameters, state.DeviceListChunker(deviceListRows))
for _, chunk := range chunks {
var placeholders []string
var vals []interface{}
listChunk := chunk.(state.DeviceListChunker)
for i, deviceListRow := range listChunk {
placeholders = append(placeholders, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d)",
i*5+1,
i*5+2,
i*5+3,
i*5+4,
i*5+5,
))
vals = append(vals, deviceListRow.UserID, deviceListRow.DeviceID, deviceListRow.TargetUserID, deviceListRow.TargetState, deviceListRow.Bucket)
}
query := fmt.Sprintf(
`INSERT INTO syncv3_device_list_updates(user_id, device_id, target_user_id, target_state, bucket) VALUES %s`,
strings.Join(placeholders, ","),
)
_, err = tx.ExecContext(ctx, query, vals...)
if err != nil {
return fmt.Errorf("failed to bulk insert: %s", err)
} }
}
if _, err = stmt.Exec(); err != nil {
return err
}
if err = stmt.Close(); err != nil {
return err
} }
// * delete the device lists from the CBOR and update // * delete the device lists from the CBOR and update
@ -129,12 +156,12 @@ func upDeviceListTable(ctx context.Context, tx *sql.Tx) error {
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) _, err = tx.ExecContext(ctx, `UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
if err != nil { if err != nil {
return err return err
} }
} }
return rows.Err() return nil
} }
func downDeviceListTable(ctx context.Context, tx *sql.Tx) error { func downDeviceListTable(ctx context.Context, tx *sql.Tx) error {

View File

@ -1,7 +1,74 @@
package migrations package migrations
import "testing" import (
"context"
"testing"
"github.com/fxamacker/cbor/v2"
)
func TestDeviceListTableMigration(t *testing.T) { func TestDeviceListTableMigration(t *testing.T) {
ctx := context.Background()
db, close := connectToDB(t)
defer close()
// Create the table in the old format (data = JSONB instead of BYTEA)
// and insert some data: we'll make sure that this data is preserved
// after migrating.
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS syncv3_device_data (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
data BYTEA NOT NULL,
UNIQUE(user_id, device_id)
);`)
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
// insert old data
rowData := []OldDeviceData{
{
DeviceLists: OldDeviceLists{
New: map[string]int{"@bob:localhost": 2},
Sent: map[string]int{},
},
ChangedBits: 2,
OTKCounts: map[string]int{"bar": 42},
FallbackKeyTypes: []string{"narp"},
DeviceID: "ALICE",
UserID: "@alice:localhost",
},
{
DeviceLists: OldDeviceLists{
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
Sent: map[string]int{"@sent:localhost": 1},
},
OTKCounts: map[string]int{"foo": 100},
FallbackKeyTypes: []string{"yep"},
DeviceID: "BOB",
UserID: "@bob:localhost",
},
}
for _, data := range rowData {
blob, err := cbor.Marshal(data)
if err != nil {
t.Fatal(err)
}
_, err = db.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, data.UserID, data.DeviceID, blob)
if err != nil {
t.Fatal(err)
}
}
// now migrate and ensure we didn't lose any data
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
err = upDeviceListTable(ctx, tx)
if err != nil {
t.Fatal(err)
}
tx.Commit()
} }