mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Merge branch 'main' of github.com:matrix-org/sliding-sync into s7evink/httptimeout
This commit is contained in:
commit
317a722286
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
@ -89,7 +88,6 @@ func defaulting(in, dft string) string {
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
fmt.Printf("Sync v3 [%s] (%s)\n", version, GitCommit)
|
||||
sync2.ProxyVersion = version
|
||||
syncv3.Version = fmt.Sprintf("%s (%s)", version, GitCommit)
|
||||
@ -196,11 +194,6 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
maxConnsInt, err := strconv.Atoi(args[EnvMaxConns])
|
||||
if err != nil {
|
||||
panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns])
|
||||
|
@ -1,32 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
// Required fields
|
||||
EnvServer = "SYNCV3_SERVER"
|
||||
EnvDB = "SYNCV3_DB"
|
||||
EnvSecret = "SYNCV3_SECRET"
|
||||
|
||||
// Migration test only
|
||||
EnvMigrationCommit = "SYNCV3_TEST_MIGRATION_COMMIT"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
args := map[string]string{
|
||||
EnvServer: os.Getenv(EnvServer),
|
||||
EnvDB: os.Getenv(EnvDB),
|
||||
EnvSecret: os.Getenv(EnvSecret),
|
||||
EnvMigrationCommit: os.Getenv(EnvMigrationCommit),
|
||||
}
|
||||
|
||||
err := sync2.MigrateDeviceIDs(ctx, args[EnvServer], args[EnvDB], args[EnvSecret], args[EnvMigrationCommit] != "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
@ -155,19 +155,21 @@ func (m *RoomMetadata) IsSpace() bool {
|
||||
}
|
||||
|
||||
type Hero struct {
|
||||
ID string
|
||||
Name string
|
||||
Avatar string
|
||||
ID string `json:"user_id"`
|
||||
Name string `json:"displayname,omitempty"`
|
||||
Avatar string `json:"avatar_url,omitempty"`
|
||||
}
|
||||
|
||||
func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
|
||||
// CalculateRoomName calculates the room name. Returns the name and if the name was actually calculated
|
||||
// based on room heroes.
|
||||
func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) (name string, calculated bool) {
|
||||
// If the room has an m.room.name state event with a non-empty name field, use the name given by that field.
|
||||
if heroInfo.NameEvent != "" {
|
||||
return heroInfo.NameEvent
|
||||
return heroInfo.NameEvent, false
|
||||
}
|
||||
// If the room has an m.room.canonical_alias state event with a valid alias field, use the alias given by that field as the name.
|
||||
if heroInfo.CanonicalAlias != "" {
|
||||
return heroInfo.CanonicalAlias
|
||||
return heroInfo.CanonicalAlias, false
|
||||
}
|
||||
// If none of the above conditions are met, a name should be composed based on the members of the room.
|
||||
disambiguatedNames := disambiguate(heroInfo.Heroes)
|
||||
@ -178,7 +180,7 @@ func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
|
||||
// the client should use the rules BELOW to indicate that the room was empty. For example, "Empty Room (was Alice)",
|
||||
// "Empty Room (was Alice and 1234 others)", or "Empty Room" if there are no heroes.
|
||||
if len(heroInfo.Heroes) == 0 && isAlone {
|
||||
return "Empty Room"
|
||||
return "Empty Room", false
|
||||
}
|
||||
|
||||
// If the number of m.heroes for the room are greater or equal to m.joined_member_count + m.invited_member_count - 1,
|
||||
@ -186,13 +188,13 @@ func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
|
||||
// and concatenating them.
|
||||
if len(heroInfo.Heroes) >= totalNumOtherUsers {
|
||||
if len(disambiguatedNames) == 1 {
|
||||
return disambiguatedNames[0]
|
||||
return disambiguatedNames[0], true
|
||||
}
|
||||
calculatedRoomName := strings.Join(disambiguatedNames[:len(disambiguatedNames)-1], ", ") + " and " + disambiguatedNames[len(disambiguatedNames)-1]
|
||||
if isAlone {
|
||||
return fmt.Sprintf("Empty Room (was %s)", calculatedRoomName)
|
||||
return fmt.Sprintf("Empty Room (was %s)", calculatedRoomName), true
|
||||
}
|
||||
return calculatedRoomName
|
||||
return calculatedRoomName, true
|
||||
}
|
||||
|
||||
// if we're here then len(heroes) < (joinedCount + invitedCount - 1)
|
||||
@ -208,13 +210,13 @@ func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
|
||||
// and m.joined_member_count + m.invited_member_count is greater than 1, the client should use the heroes to calculate
|
||||
// display names for the users (disambiguating them if required) and concatenating them alongside a count of the remaining users.
|
||||
if (heroInfo.JoinCount + heroInfo.InviteCount) > 1 {
|
||||
return calculatedRoomName
|
||||
return calculatedRoomName, true
|
||||
}
|
||||
|
||||
// If m.joined_member_count + m.invited_member_count is less than or equal to 1 (indicating the member is alone),
|
||||
// the client should use the rules above to indicate that the room was empty. For example, "Empty Room (was Alice)",
|
||||
// "Empty Room (was Alice and 1234 others)", or "Empty Room" if there are no heroes.
|
||||
return fmt.Sprintf("Empty Room (was %s)", calculatedRoomName)
|
||||
return fmt.Sprintf("Empty Room (was %s)", calculatedRoomName), true
|
||||
}
|
||||
|
||||
func disambiguate(heroes []Hero) []string {
|
||||
|
@ -11,7 +11,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
invitedCount int
|
||||
maxNumNamesPerRoom int
|
||||
|
||||
wantRoomName string
|
||||
wantRoomName string
|
||||
wantCalculated bool
|
||||
}{
|
||||
// Room name takes precedence
|
||||
{
|
||||
@ -65,7 +66,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Bob",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice, Bob and 3 others",
|
||||
wantRoomName: "Alice, Bob and 3 others",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// Small group chat
|
||||
{
|
||||
@ -86,7 +88,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Charlie",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice, Bob and Charlie",
|
||||
wantRoomName: "Alice, Bob and Charlie",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// DM room
|
||||
{
|
||||
@ -99,7 +102,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Alice",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice",
|
||||
wantRoomName: "Alice",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// 3-way room
|
||||
{
|
||||
@ -116,7 +120,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Bob",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice and Bob",
|
||||
wantRoomName: "Alice and Bob",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// 3-way room, one person invited with no display name
|
||||
{
|
||||
@ -132,7 +137,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
ID: "@bob:localhost",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice and @bob:localhost",
|
||||
wantRoomName: "Alice and @bob:localhost",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// 3-way room, no display names
|
||||
{
|
||||
@ -147,7 +153,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
ID: "@bob:localhost",
|
||||
},
|
||||
},
|
||||
wantRoomName: "@alice:localhost and @bob:localhost",
|
||||
wantRoomName: "@alice:localhost and @bob:localhost",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// disambiguation all
|
||||
{
|
||||
@ -168,7 +175,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Alice",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice (@alice:localhost), Alice (@bob:localhost), Alice (@charlie:localhost) and 6 others",
|
||||
wantRoomName: "Alice (@alice:localhost), Alice (@bob:localhost), Alice (@charlie:localhost) and 6 others",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// disambiguation some
|
||||
{
|
||||
@ -189,7 +197,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Alice",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Alice (@alice:localhost), Bob, Alice (@charlie:localhost) and 6 others",
|
||||
wantRoomName: "Alice (@alice:localhost), Bob, Alice (@charlie:localhost) and 6 others",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// disambiguation, faking user IDs as display names
|
||||
{
|
||||
@ -205,7 +214,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
ID: "@alice:localhost",
|
||||
},
|
||||
},
|
||||
wantRoomName: "@alice:localhost (@evil:localhost) and @alice:localhost (@alice:localhost)",
|
||||
wantRoomName: "@alice:localhost (@evil:localhost) and @alice:localhost (@alice:localhost)",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// left room
|
||||
{
|
||||
@ -222,7 +232,8 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
Name: "Bob",
|
||||
},
|
||||
},
|
||||
wantRoomName: "Empty Room (was Alice and Bob)",
|
||||
wantRoomName: "Empty Room (was Alice and Bob)",
|
||||
wantCalculated: true,
|
||||
},
|
||||
// empty room
|
||||
{
|
||||
@ -235,7 +246,7 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
gotName := CalculateRoomName(&RoomMetadata{
|
||||
gotName, gotCalculated := CalculateRoomName(&RoomMetadata{
|
||||
NameEvent: tc.roomName,
|
||||
CanonicalAlias: tc.canonicalAlias,
|
||||
Heroes: tc.heroes,
|
||||
@ -245,6 +256,9 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
if gotName != tc.wantRoomName {
|
||||
t.Errorf("got %s want %s for test case: %+v", gotName, tc.wantRoomName, tc)
|
||||
}
|
||||
if gotCalculated != tc.wantCalculated {
|
||||
t.Errorf("got %v want %v for test case: %+v", gotCalculated, tc.wantCalculated, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,448 +0,0 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
)
|
||||
|
||||
// MigrateDeviceIDs performs a one-off DB migration from the old device ids (hash of
|
||||
// access token) to the new device ids (actual device ids from the homeserver). This is
|
||||
// not backwards compatible. If the migration has already taken place, this function is
|
||||
// a no-op.
|
||||
//
|
||||
// This code will be removed in a future version of the proxy.
|
||||
func MigrateDeviceIDs(ctx context.Context, destHomeserver, postgresURI, secret string, commit bool) error {
|
||||
whoamiClient := NewHTTPClient(5*time.Minute, 30*time.Minute, destHomeserver)
|
||||
db, err := sqlx.Open("postgres", postgresURI)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
|
||||
// Ensure the new table exists.
|
||||
NewTokensTable(db, secret)
|
||||
|
||||
return sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
migrated, err := isMigrated(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if migrated {
|
||||
logger.Debug().Msg("MigrateDeviceIDs: migration has already taken place")
|
||||
return nil
|
||||
}
|
||||
logger.Info().Msgf("MigrateDeviceIDs: starting (commit=%t)", commit)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
logger.Debug().Msgf("MigrateDeviceIDs: took %s", elapsed)
|
||||
}()
|
||||
err = alterTables(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = runMigration(ctx, txn, secret, whoamiClient)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = finish(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s := NewStore(postgresURI, secret)
|
||||
tokens, err := s.TokensTable.TokenForEachDevice(txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
logger.Debug().Msgf("Got %d tokens after migration", len(tokens))
|
||||
|
||||
if !commit {
|
||||
err = fmt.Errorf("MigrateDeviceIDs: migration succeeded without errors, but commit is false - rolling back anyway")
|
||||
} else {
|
||||
logger.Info().Msg("MigrateDeviceIDs: migration succeeded - committing")
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
func isMigrated(txn *sqlx.Tx) (bool, error) {
|
||||
// Keep this dead simple for now. This is a one-off migration, before version 1.0.
|
||||
// In the future we'll rip this out and tell people that it's their job to ensure
|
||||
// this migration has run before they upgrade beyond the rip-out point.
|
||||
|
||||
// We're going to detect if the migration has run by testing for the existence of
|
||||
// a column added by the migration. First, check that the table exists.
|
||||
var tableExists bool
|
||||
err := txn.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'syncv3_txns'
|
||||
);
|
||||
`).Scan(&tableExists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("isMigrated: %s", err)
|
||||
}
|
||||
if !tableExists {
|
||||
// The proxy has never been run before and its tables have never been created.
|
||||
// We do not need to run the migration.
|
||||
logger.Debug().Msg("isMigrated: no syncv3_txns table, no migration needed")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var migrated bool
|
||||
err = txn.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'syncv3_txns' AND column_name = 'device_id'
|
||||
);
|
||||
`).Scan(&migrated)
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("isMigrated: %s", err)
|
||||
}
|
||||
return migrated, nil
|
||||
}
|
||||
|
||||
func alterTables(txn *sqlx.Tx) (err error) {
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_sync2_devices
|
||||
DROP CONSTRAINT syncv3_sync2_devices_pkey;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_messages
|
||||
ADD COLUMN user_id TEXT;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_ack_pos
|
||||
DROP CONSTRAINT syncv3_to_device_ack_pos_pkey,
|
||||
ADD COLUMN user_id TEXT;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_txns
|
||||
DROP CONSTRAINT syncv3_txns_user_id_event_id_key,
|
||||
ADD COLUMN device_id TEXT;
|
||||
`)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type oldDevice struct {
|
||||
AccessToken string // not a DB row, but it's convenient to write to here
|
||||
AccessTokenHash string `db:"device_id"`
|
||||
UserID string `db:"user_id"`
|
||||
AccessTokenEncrypted string `db:"v2_token_encrypted"`
|
||||
Since string `db:"since"`
|
||||
}
|
||||
|
||||
func runMigration(ctx context.Context, txn *sqlx.Tx, secret string, whoamiClient Client) error {
|
||||
logger.Info().Msg("Loading old-style devices into memory")
|
||||
var devices []oldDevice
|
||||
err := txn.Select(
|
||||
&devices,
|
||||
`SELECT device_id, user_id, v2_token_encrypted, since FROM syncv3_sync2_devices;`,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runMigration: failed to select devices: %s", err)
|
||||
}
|
||||
|
||||
logger.Info().Msgf("Got %d devices to migrate", len(devices))
|
||||
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(secret))
|
||||
key := hasher.Sum(nil)
|
||||
|
||||
// This migration runs sequentially, one device at a time. We have found this to be
|
||||
// quick enough in practice.
|
||||
numErrors := 0
|
||||
for i, device := range devices {
|
||||
device.AccessToken, err = decrypt(device.AccessTokenEncrypted, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runMigration: failed to decrypt device: %s", err)
|
||||
}
|
||||
userID := device.UserID
|
||||
if userID == "" {
|
||||
userID = "<unknown user>"
|
||||
}
|
||||
logger.Info().Msgf(
|
||||
"%4d/%4d migrating device %s %s",
|
||||
i+1, len(devices), userID, device.AccessTokenHash,
|
||||
)
|
||||
err = migrateDevice(ctx, txn, whoamiClient, &device)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("runMigration: failed to migrate device %s", device.AccessTokenHash)
|
||||
numErrors++
|
||||
}
|
||||
}
|
||||
if numErrors > 0 {
|
||||
return fmt.Errorf("runMigration: there were %d failures", numErrors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateDevice(ctx context.Context, txn *sqlx.Tx, whoamiClient Client, device *oldDevice) (err error) {
|
||||
gotUserID, gotDeviceID, err := whoamiClient.WhoAmI(ctx, device.AccessToken)
|
||||
if err == HTTP401 {
|
||||
userID := device.UserID
|
||||
if userID == "" {
|
||||
userID = "<unknown user>"
|
||||
}
|
||||
logger.Warn().Msgf(
|
||||
"migrateDevice: access token for %s %s has expired. Dropping device and metadata.",
|
||||
userID, device.AccessTokenHash,
|
||||
)
|
||||
return cleanupDevice(txn, device)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Sanity check the user ID from the HS matches our records
|
||||
if gotUserID != device.UserID {
|
||||
return fmt.Errorf(
|
||||
"/whoami response was for the wrong user. Queried for %s, but got response for %s",
|
||||
device.UserID, gotUserID,
|
||||
)
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)`,
|
||||
expectOneRowAffected,
|
||||
device.AccessTokenHash, device.AccessTokenEncrypted, gotUserID, gotDeviceID, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For these first four tables:
|
||||
// - use the actual device ID instead of the access token hash, and
|
||||
// - ensure a user ID is set.
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_sync2_devices SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_to_device_messages SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_to_device_ack_pos SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAtMostOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_device_data SET user_id = $1, device_id = $2 WHERE device_id = $3`,
|
||||
expectAtMostOneRowAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Confusingly, the txns table used to store access token hashes under the user_id
|
||||
// column. Write the actual user ID to the user_id column, and the actual device ID
|
||||
// to the device_id column.
|
||||
err = exec(
|
||||
txn,
|
||||
`UPDATE syncv3_txns SET user_id = $1, device_id = $2 WHERE user_id = $3`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
gotUserID, gotDeviceID, device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func cleanupDevice(txn *sqlx.Tx, device *oldDevice) (err error) {
|
||||
// The homeserver does not recognise this access token. Because we have no
|
||||
// record of the device_id from the homeserver, it will never be possible to
|
||||
// spot that a future refreshed access token belongs to the device we're
|
||||
// handling here. Therefore this device is not useful to the proxy.
|
||||
//
|
||||
// If we leave this device's rows in situ, we may end up with rows in
|
||||
// syncv3_to_device_messages, syncv3_to_device_ack_pos and syncv3_txns which have
|
||||
// null values for the new fields, which will mean we fail to impose the uniqueness
|
||||
// constraints at the end of the migration. Instead, drop those rows.
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_sync2_devices WHERE device_id = $1`,
|
||||
expectOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_to_device_messages WHERE device_id = $1`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_to_device_ack_pos WHERE device_id = $1`,
|
||||
expectAtMostOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_device_data WHERE device_id = $1`,
|
||||
expectAtMostOneRowAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`DELETE FROM syncv3_txns WHERE user_id = $1`,
|
||||
expectAnyNumberOfRowsAffected,
|
||||
device.AccessTokenHash,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func exec(txn *sqlx.Tx, query string, checkRowsAffected func(ra int64) bool, args ...any) error {
|
||||
res, err := txn.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ra, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !checkRowsAffected(ra) {
|
||||
return fmt.Errorf("query \"%s\" unexpectedly affected %d rows", query, ra)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func expectOneRowAffected(ra int64) bool { return ra == 1 }
|
||||
func expectAnyNumberOfRowsAffected(ra int64) bool { return true }
|
||||
func logRowsAffected(msg string) func(ra int64) bool {
|
||||
return func(ra int64) bool {
|
||||
logger.Info().Msgf(msg, ra)
|
||||
return true
|
||||
}
|
||||
}
|
||||
func expectAtMostOneRowAffected(ra int64) bool { return ra == 0 || ra == 1 }
|
||||
|
||||
func finish(txn *sqlx.Tx) (err error) {
|
||||
// OnExpiredToken used to delete from the devices and to-device tables, but not from
|
||||
// the to-device ack pos or the txn tables. Fix this up by deleting orphaned rows.
|
||||
err = exec(
|
||||
txn,
|
||||
`
|
||||
DELETE FROM syncv3_to_device_ack_pos
|
||||
WHERE device_id NOT IN (SELECT device_id FROM syncv3_sync2_devices)
|
||||
;`,
|
||||
logRowsAffected("Deleted %d stale rows from syncv3_to_device_ack_pos"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = exec(
|
||||
txn,
|
||||
`
|
||||
DELETE FROM syncv3_txns WHERE device_id IS NULL;
|
||||
`,
|
||||
logRowsAffected("Deleted %d stale rows from syncv3_txns"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_sync2_devices
|
||||
DROP COLUMN v2_token_encrypted,
|
||||
ADD PRIMARY KEY (user_id, device_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_messages
|
||||
ALTER COLUMN user_id SET NOT NULL;
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_to_device_ack_pos
|
||||
ALTER COLUMN user_id SET NOT NULL,
|
||||
ADD PRIMARY KEY (user_id, device_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
ALTER TABLE syncv3_txns
|
||||
ALTER COLUMN device_id SET NOT NULL,
|
||||
ADD UNIQUE(user_id, device_id, event_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"github.com/matrix-org/sliding-sync/state"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConsideredMigratedOnFirstStartup(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
var migrated bool
|
||||
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
// Attempt to make this test independent of others by dropping the table whose
|
||||
// columns we probe.
|
||||
_, err = txn.Exec("DROP TABLE IF EXISTS syncv3_txns;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
migrated, err = isMigrated(txn)
|
||||
return
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error calling isMigrated: %s", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatalf("Expected a non-existent DB to be considered migrated, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSchemaIsConsideredMigrated(t *testing.T) {
|
||||
NewStore(postgresConnectionString, "my_secret")
|
||||
state.NewStorage(postgresConnectionString)
|
||||
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
var migrated bool
|
||||
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
migrated, err = isMigrated(txn)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error calling isMigrated: %s", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatalf("Expected a new DB to be considered migrated, but it was not")
|
||||
}
|
||||
}
|
@ -664,8 +664,9 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
}
|
||||
}
|
||||
|
||||
rooms[roomID] = sync3.Room{
|
||||
Name: internal.CalculateRoomName(metadata, 5), // TODO: customisable?
|
||||
roomName, calculated := internal.CalculateRoomName(metadata, 5) // TODO: customisable?
|
||||
room := sync3.Room{
|
||||
Name: roomName,
|
||||
AvatarChange: sync3.NewAvatarChange(internal.CalculateAvatar(metadata)),
|
||||
NotificationCount: int64(userRoomData.NotificationCount),
|
||||
HighlightCount: int64(userRoomData.HighlightCount),
|
||||
@ -679,6 +680,10 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch,
|
||||
Timestamp: maxTs,
|
||||
}
|
||||
if roomSub.IncludeHeroes() && calculated {
|
||||
room.Heroes = metadata.Heroes
|
||||
}
|
||||
rooms[roomID] = room
|
||||
}
|
||||
|
||||
if rsm.IsLazyLoading() {
|
||||
|
@ -259,7 +259,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
|
||||
if delta.RoomNameChanged {
|
||||
metadata := roomUpdate.GlobalRoomMetadata()
|
||||
metadata.RemoveHero(s.userID)
|
||||
thisRoom.Name = internal.CalculateRoomName(metadata, 5) // TODO: customisable?
|
||||
roomName, calculated := internal.CalculateRoomName(metadata, 5) // TODO: customisable?
|
||||
|
||||
thisRoom.Name = roomName
|
||||
|
||||
if calculated && s.shouldIncludeHeroes(roomUpdate.RoomID()) {
|
||||
thisRoom.Heroes = metadata.Heroes
|
||||
}
|
||||
}
|
||||
if delta.RoomAvatarChanged {
|
||||
metadata := roomUpdate.GlobalRoomMetadata()
|
||||
@ -438,3 +444,20 @@ func (s *connStateLive) resort(
|
||||
}
|
||||
return ops, hasUpdates
|
||||
}
|
||||
|
||||
// shouldIncludeHeroes returns whether the given roomID is in a list or direct
|
||||
// subscription which should return heroes.
|
||||
func (s *connStateLive) shouldIncludeHeroes(roomID string) bool {
|
||||
if s.roomSubscriptions[roomID].IncludeHeroes() {
|
||||
return true
|
||||
}
|
||||
roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists)
|
||||
for _, listKey := range roomIDsToLists[roomID] {
|
||||
// check if this list should include heroes
|
||||
if !s.muxedReq.Lists[listKey].IncludeHeroes() {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
90
sync3/handler/connstate_live_test.go
Normal file
90
sync3/handler/connstate_live_test.go
Normal file
@ -0,0 +1,90 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
)
|
||||
|
||||
func Test_connStateLive_shouldIncludeHeroes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
list := sync3.NewInternalRequestLists()
|
||||
|
||||
m1 := sync3.RoomConnMetadata{
|
||||
RoomMetadata: internal.RoomMetadata{
|
||||
RoomID: "!abc",
|
||||
},
|
||||
}
|
||||
m2 := sync3.RoomConnMetadata{
|
||||
RoomMetadata: internal.RoomMetadata{
|
||||
RoomID: "!def",
|
||||
},
|
||||
}
|
||||
list.SetRoom(m1)
|
||||
list.SetRoom(m2)
|
||||
|
||||
list.AssignList(ctx, "all_rooms", &sync3.RequestFilters{}, []string{sync3.SortByName}, false)
|
||||
list.AssignList(ctx, "visible_rooms", &sync3.RequestFilters{}, []string{sync3.SortByName}, false)
|
||||
|
||||
boolTrue := true
|
||||
tests := []struct {
|
||||
name string
|
||||
ConnState *ConnState
|
||||
roomID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "neither in subscription nor in list",
|
||||
roomID: "!abc",
|
||||
ConnState: &ConnState{
|
||||
muxedReq: &sync3.Request{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in room subscription",
|
||||
want: true,
|
||||
roomID: "!abc",
|
||||
ConnState: &ConnState{
|
||||
muxedReq: &sync3.Request{},
|
||||
roomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
"!abc": {
|
||||
Heroes: &boolTrue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in list all_rooms",
|
||||
roomID: "!def",
|
||||
want: true,
|
||||
ConnState: &ConnState{
|
||||
muxedReq: &sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"all_rooms": {
|
||||
SlowGetAllRooms: &boolTrue,
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
Heroes: &boolTrue,
|
||||
},
|
||||
},
|
||||
"visible_rooms": {
|
||||
SlowGetAllRooms: &boolTrue,
|
||||
},
|
||||
},
|
||||
},
|
||||
lists: list,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &connStateLive{
|
||||
ConnState: tt.ConnState,
|
||||
}
|
||||
if got := s.shouldIncludeHeroes(tt.roomID); got != tt.want {
|
||||
t.Errorf("shouldIncludeHeroes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -70,8 +70,9 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
|
||||
delta.RoomNameChanged = !existing.SameRoomName(&r.RoomMetadata)
|
||||
if delta.RoomNameChanged {
|
||||
// update the canonical name to allow room name sorting to continue to work
|
||||
roomName, _ := internal.CalculateRoomName(&r.RoomMetadata, 5)
|
||||
r.CanonicalisedName = strings.ToLower(
|
||||
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
|
||||
strings.Trim(roomName, "#!():_@"),
|
||||
)
|
||||
} else {
|
||||
// XXX: during TestConnectionTimeoutNotReset there is some situation where
|
||||
@ -109,8 +110,9 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
|
||||
}
|
||||
} else {
|
||||
// set the canonical name to allow room name sorting to work
|
||||
roomName, _ := internal.CalculateRoomName(&r.RoomMetadata, 5)
|
||||
r.CanonicalisedName = strings.ToLower(
|
||||
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
|
||||
strings.Trim(roomName, "#!():_@"),
|
||||
)
|
||||
r.ResolvedAvatarURL = internal.CalculateAvatar(&r.RoomMetadata)
|
||||
// We'll automatically use the LastInterestedEventTimestamps provided by the
|
||||
|
@ -394,12 +394,17 @@ func (r *Request) ApplyDelta(nextReq *Request) (result *Request, delta *RequestD
|
||||
if bumpEventTypes == nil {
|
||||
bumpEventTypes = existingList.BumpEventTypes
|
||||
}
|
||||
heroes := nextList.Heroes
|
||||
if heroes == nil {
|
||||
heroes = existingList.Heroes
|
||||
}
|
||||
|
||||
calculatedLists[listKey] = RequestList{
|
||||
RoomSubscription: RoomSubscription{
|
||||
RequiredState: reqState,
|
||||
TimelineLimit: timelineLimit,
|
||||
IncludeOldRooms: includeOldRooms,
|
||||
Heroes: heroes,
|
||||
},
|
||||
Ranges: rooms,
|
||||
Sort: sort,
|
||||
@ -517,7 +522,8 @@ func (rf *RequestFilters) Include(r *RoomConnMetadata, finder RoomFinder) bool {
|
||||
if rf.IsInvite != nil && *rf.IsInvite != r.IsInvite {
|
||||
return false
|
||||
}
|
||||
if rf.RoomNameFilter != "" && !strings.Contains(strings.ToLower(internal.CalculateRoomName(&r.RoomMetadata, 5)), strings.ToLower(rf.RoomNameFilter)) {
|
||||
roomName, _ := internal.CalculateRoomName(&r.RoomMetadata, 5)
|
||||
if rf.RoomNameFilter != "" && !strings.Contains(strings.ToLower(roomName), strings.ToLower(rf.RoomNameFilter)) {
|
||||
return false
|
||||
}
|
||||
if len(rf.NotTags) > 0 {
|
||||
@ -563,6 +569,7 @@ type RoomSubscription struct {
|
||||
RequiredState [][2]string `json:"required_state"`
|
||||
TimelineLimit int64 `json:"timeline_limit"`
|
||||
IncludeOldRooms *RoomSubscription `json:"include_old_rooms"`
|
||||
Heroes *bool `json:"include_heroes"`
|
||||
}
|
||||
|
||||
func (rs RoomSubscription) RequiredStateChanged(other RoomSubscription) bool {
|
||||
@ -586,6 +593,10 @@ func (rs RoomSubscription) LazyLoadMembers() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (rs RoomSubscription) IncludeHeroes() bool {
|
||||
return rs.Heroes != nil && *rs.Heroes
|
||||
}
|
||||
|
||||
// Combine this subcription with another, returning a union of both as a copy.
|
||||
func (rs RoomSubscription) Combine(other RoomSubscription) RoomSubscription {
|
||||
return rs.combineRecursive(other, true)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
type Room struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
AvatarChange AvatarChange `json:"avatar,omitempty"`
|
||||
Heroes []internal.Hero `json:"heroes,omitempty"`
|
||||
RequiredState []json.RawMessage `json:"required_state,omitempty"`
|
||||
Timeline []json.RawMessage `json:"timeline,omitempty"`
|
||||
InviteState []json.RawMessage `json:"invite_state,omitempty"`
|
||||
|
@ -695,6 +695,7 @@ func (c *CSAPI) SlidingSyncUntilMembership(t *testing.T, pos string, roomID stri
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 10,
|
||||
Heroes: &boolTrue,
|
||||
},
|
||||
},
|
||||
}, func(r *sync3.Response) error {
|
||||
@ -713,6 +714,7 @@ func (c *CSAPI) SlidingSyncUntilMembership(t *testing.T, pos string, roomID stri
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 10,
|
||||
Heroes: &boolTrue,
|
||||
},
|
||||
},
|
||||
}, func(r *sync3.Response) error {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestRoomStateTransitions(t *testing.T) {
|
||||
@ -532,6 +533,150 @@ func TestRejectingInviteReturnsOneEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test to check that room heroes are returned if the membership changes
|
||||
func TestHeroesOnMembershipChanges(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
bob := registerNewUser(t)
|
||||
charlie := registerNewUser(t)
|
||||
|
||||
t.Run("nameless room uses heroes to calculate roomname", func(t *testing.T) {
|
||||
// create a room without a name, to ensure we calculate the room name based on
|
||||
// room heroes
|
||||
roomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
|
||||
|
||||
bob.JoinRoom(t, roomID, []string{})
|
||||
|
||||
res := alice.SlidingSyncUntilMembership(t, "", roomID, bob, "join")
|
||||
// we expect to see Bob as a hero
|
||||
if c := len(res.Rooms[roomID].Heroes); c > 1 {
|
||||
t.Errorf("expected 1 room hero, got %d", c)
|
||||
}
|
||||
if gotUserID := res.Rooms[roomID].Heroes[0].ID; gotUserID != bob.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, bob.UserID)
|
||||
}
|
||||
|
||||
// Now join with Charlie, the heroes and the room name should change
|
||||
charlie.JoinRoom(t, roomID, []string{})
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, roomID, charlie, "join")
|
||||
|
||||
// we expect to see Bob as a hero
|
||||
if c := len(res.Rooms[roomID].Heroes); c > 2 {
|
||||
t.Errorf("expected 2 room hero, got %d", c)
|
||||
}
|
||||
if gotUserID := res.Rooms[roomID].Heroes[0].ID; gotUserID != bob.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, bob.UserID)
|
||||
}
|
||||
if gotUserID := res.Rooms[roomID].Heroes[1].ID; gotUserID != charlie.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, charlie.UserID)
|
||||
}
|
||||
|
||||
// Send a message, the heroes shouldn't change
|
||||
msgEv := bob.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.room.roomID",
|
||||
Content: map[string]interface{}{"body": "Hello world", "msgtype": "m.text"},
|
||||
})
|
||||
|
||||
res = alice.SlidingSyncUntilEventID(t, res.Pos, roomID, msgEv)
|
||||
if len(res.Rooms[roomID].Heroes) > 0 {
|
||||
t.Errorf("expected no change to room heros")
|
||||
}
|
||||
|
||||
// Now leave with Charlie, only Bob should be in the heroes list
|
||||
charlie.LeaveRoom(t, roomID)
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, roomID, charlie, "leave")
|
||||
if c := len(res.Rooms[roomID].Heroes); c > 1 {
|
||||
t.Errorf("expected 1 room hero, got %d", c)
|
||||
}
|
||||
if gotUserID := res.Rooms[roomID].Heroes[0].ID; gotUserID != bob.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, bob.UserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("named rooms don't have heroes", func(t *testing.T) {
|
||||
namedRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat", "name": "my room without heroes"})
|
||||
// this makes sure that even if bob is joined, we don't return any heroes
|
||||
bob.JoinRoom(t, namedRoomID, []string{})
|
||||
|
||||
res := alice.SlidingSyncUntilMembership(t, "", namedRoomID, bob, "join")
|
||||
if len(res.Rooms[namedRoomID].Heroes) > 0 {
|
||||
t.Errorf("expected no heroes, got %#v", res.Rooms[namedRoomID].Heroes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rooms with aliases don't have heroes", func(t *testing.T) {
|
||||
aliasRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
|
||||
|
||||
alias := fmt.Sprintf("#%s-%d:%s", t.Name(), time.Now().Unix(), alice.Domain)
|
||||
alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "directory", "room", alias},
|
||||
WithJSONBody(t, map[string]any{"room_id": aliasRoomID}),
|
||||
)
|
||||
alice.SetState(t, aliasRoomID, "m.room.canonical_alias", "", map[string]any{
|
||||
"alias": alias,
|
||||
})
|
||||
|
||||
bob.JoinRoom(t, aliasRoomID, []string{})
|
||||
|
||||
res := alice.SlidingSyncUntilMembership(t, "", aliasRoomID, bob, "join")
|
||||
if len(res.Rooms[aliasRoomID].Heroes) > 0 {
|
||||
t.Errorf("expected no heroes, got %#v", res.Rooms[aliasRoomID].Heroes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("can set heroes=true on room subscriptions", func(t *testing.T) {
|
||||
subRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
|
||||
bob.JoinRoom(t, subRoomID, []string{})
|
||||
|
||||
res := alice.SlidingSyncUntilMembership(t, "", subRoomID, bob, "join")
|
||||
if c := len(res.Rooms[subRoomID].Heroes); c > 1 {
|
||||
t.Errorf("expected 1 room hero, got %d", c)
|
||||
}
|
||||
if gotUserID := res.Rooms[subRoomID].Heroes[0].ID; gotUserID != bob.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, bob.UserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("can set heroes=true in lists", func(t *testing.T) {
|
||||
listRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
|
||||
bob.JoinRoom(t, listRoomID, []string{})
|
||||
|
||||
res := alice.SlidingSyncUntil(t, "", sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"all_rooms": {
|
||||
Ranges: sync3.SliceRanges{{0, 20}},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
Heroes: &boolTrue,
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(response *sync3.Response) error {
|
||||
r, ok := response.Rooms[listRoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room %q not in response", listRoomID)
|
||||
}
|
||||
// wait for bob to be joined
|
||||
for _, ev := range r.Timeline {
|
||||
if gjson.GetBytes(ev, "type").Str != "m.room.member" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(ev, "state_key").Str != bob.UserID {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(ev, "content.membership").Str == "join" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s is not joined to room %q", bob.UserID, listRoomID)
|
||||
})
|
||||
if c := len(res.Rooms[listRoomID].Heroes); c > 1 {
|
||||
t.Errorf("expected 1 room hero, got %d", c)
|
||||
}
|
||||
if gotUserID := res.Rooms[listRoomID].Heroes[0].ID; gotUserID != bob.UserID {
|
||||
t.Errorf("expected userID %q, got %q", gotUserID, bob.UserID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// test invite/join counts update and are accurate
|
||||
func TestMemberCounts(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
|
Loading…
x
Reference in New Issue
Block a user