Merge remote-tracking branch 'origin/main' into dmr/extension-scoping-fix

David Robertson 2023-08-15 17:41:56 +01:00
commit a51230d852
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
86 changed files with 6569 additions and 1150 deletions

View File

@ -26,7 +26,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to GitHub Containers

View File

@ -59,7 +59,7 @@ jobs:
- name: Test
run: |
set -euo pipefail
go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt
go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
shell: bash
env:
POSTGRES_HOST: localhost
@ -88,6 +88,10 @@ jobs:
if-no-files-found: error
end_to_end:
runs-on: ubuntu-latest
strategy:
matrix:
# test with unlimited + 1 + 2 max db conns. If we end up double-transacting anywhere in the tests, the conn=1 runs will fail (see the sketch after this workflow file).
max_db_conns: [0,1,2]
services:
synapse:
# Custom image built from https://github.com/matrix-org/synapse/tree/v1.72.0/docker/complement with a dummy /complement/ca set
@ -97,6 +101,11 @@ jobs:
SERVER_NAME: synapse
ports:
- 8008:8008
# Set health checks to wait until synapse has started
options: >-
--health-interval 10s
--health-timeout 5s
--health-retries 5
# Label used to access the service container
postgres:
# Docker Hub image
@ -135,13 +144,14 @@ jobs:
- name: Run end-to-end tests
run: |
set -euo pipefail
./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt
./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt -hide all
working-directory: tests-e2e
shell: bash
env:
SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
SYNCV3_SERVER: http://localhost:8008
SYNCV3_SECRET: itsasecret
SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
E2E_TEST_SERVER_STDOUT: test-e2e-server.log
- name: Upload test log
@ -164,11 +174,14 @@ jobs:
- uses: actions/checkout@v3
with:
repository: matrix-org/matrix-react-sdk
ref: "v3.71.0" # later versions break the SS E2E tests which need to be fixed :(
- uses: actions/setup-node@v3
with:
cache: 'yarn'
- name: Fetch layered build
run: scripts/ci/layered.sh
env:
JS_SDK_GITHUB_BASE_REF: "v25.0.0-rc.1"
- name: Copy config
run: cp element.io/develop/config.json config.json
working-directory: ./element-web
@ -195,3 +208,123 @@ jobs:
path: |
cypress/screenshots
cypress/videos
upgrade-test:
runs-on: ubuntu-latest
env:
PREV_VERSION: "v0.99.4"
# Service containers to run with `container-job`
services:
synapse:
# Custom image built from https://github.com/matrix-org/synapse/tree/v1.72.0/docker/complement with a dummy /complement/ca set
image: ghcr.io/matrix-org/synapse-service:v1.72.0
env:
SYNAPSE_COMPLEMENT_DATABASE: sqlite
SERVER_NAME: synapse
ports:
- 8008:8008
# Set health checks to wait until synapse has started
options: >-
--health-interval 10s
--health-timeout 5s
--health-retries 5
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: syncv3
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: 1.19
# E2E tests with ${{env.PREV_VERSION}}
- uses: actions/checkout@v3
with:
ref: ${{env.PREV_VERSION}}
- name: Set up gotestfmt
uses: GoTestTools/gotestfmt-action@v2
with:
# Note: constrained to `packages:read` only at the top of the file
token: ${{ secrets.GITHUB_TOKEN }}
- name: Build ${{env.PREV_VERSION}}
run: go build ./cmd/syncv3
- name: Run end-to-end tests
run: |
set -euo pipefail
./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner-${{env.PREV_VERSION}}.log | gotestfmt -hide all
working-directory: tests-e2e
shell: bash
env:
SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
SYNCV3_SERVER: http://localhost:8008
SYNCV3_SECRET: itsasecret
SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
E2E_TEST_SERVER_STDOUT: test-e2e-server-${{env.PREV_VERSION}}.log
- name: Upload test log ${{env.PREV_VERSION}}
uses: actions/upload-artifact@v3
if: failure()
with:
name: E2E test logs upgrade ${{env.PREV_VERSION}}
path: |
./tests-e2e/test-e2e-runner-${{env.PREV_VERSION}}.log
./tests-e2e/test-e2e-server-${{env.PREV_VERSION}}.log
if-no-files-found: error
# E2E tests with current commit
- uses: actions/checkout@v3
- name: Set up gotestfmt
uses: GoTestTools/gotestfmt-action@v2
with:
# Note: constrained to `packages:read` only at the top of the file
token: ${{ secrets.GITHUB_TOKEN }}
- name: Build
run: go build ./cmd/syncv3
- name: Run end-to-end tests
run: |
set -euo pipefail
./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt -hide all
working-directory: tests-e2e
shell: bash
env:
SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
SYNCV3_SERVER: http://localhost:8008
SYNCV3_SECRET: itsasecret
SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
E2E_TEST_SERVER_STDOUT: test-e2e-server.log
- name: Upload test log
uses: actions/upload-artifact@v3
if: failure()
with:
name: E2E test logs upgrade
path: |
./tests-e2e/test-e2e-runner.log
./tests-e2e/test-e2e-server.log
if-no-files-found: error
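
To make the `max_db_conns` matrix concrete, here is a hedged, standalone sketch (not part of this workflow or repository) of the failure mode the `max_db_conns: 1` entry is meant to surface: with `database/sql` capped at one open connection, any code path that touches the pool while a transaction still holds that connection blocks, so a double-transacting bug shows up as a hung, failing test. The DSN mirrors the one in the job's env.

```go
package main

import (
	"database/sql"
	"fmt"
	"time"

	_ "github.com/lib/pq" // the Postgres driver the proxy already uses
)

func main() {
	db, err := sql.Open("postgres",
		"user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost")
	if err != nil {
		panic(err)
	}
	// Mimic SYNCV3_MAX_DB_CONN=1: only one underlying connection is allowed.
	db.SetMaxOpenConns(1)

	tx, err := db.Begin() // takes the only connection
	if err != nil {
		panic(err)
	}
	defer tx.Rollback()

	// "Double transacting": using the pool again while the transaction still
	// holds the single connection. The query below never gets a connection.
	done := make(chan struct{})
	go func() {
		var now time.Time
		_ = db.QueryRow("SELECT now()").Scan(&now)
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("unexpected: query completed")
	case <-time.After(2 * time.Second):
		fmt.Println("deadlocked: pool exhausted by the open transaction")
	}
}
```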

View File

@ -2,7 +2,11 @@
Run a sliding sync proxy. An implementation of [MSC3575](https://github.com/matrix-org/matrix-doc/blob/kegan/sync-v3/proposals/3575-sync.md).
Proxy version to MSC API specification:
## Proxy version to MSC API specification
This describes which proxy versions implement which version of the API drafted
in MSC3575. See https://github.com/matrix-org/sliding-sync/releases for the
changes in the proxy itself.
- Version 0.1.x: [2022/04/01](https://github.com/matrix-org/matrix-spec-proposals/blob/615e8f5a7bfe4da813bc2db661ed0bd00bccac20/proposals/3575-sync.md)
- First release
@ -21,29 +25,102 @@ Proxy version to MSC API specification:
- Support for `errcode` when sessions expire.
- Version 0.99.1 [2023/01/20](https://github.com/matrix-org/matrix-spec-proposals/blob/b4b4e7ff306920d2c862c6ff4d245110f6fa5bc7/proposals/3575-sync.md)
- Preparing for major v1.x release: lists-as-keys support.
- Version 0.99.2 [2024/07/27](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
- Version 0.99.2 [2023/03/31](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
- Experimental support for `bump_event_types` when ordering rooms by recency.
- Support for opting in to extensions on a per-list and per-room basis.
- Sentry support.
- Version 0.99.3 [2023/05/23](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
- Support for per-list `bump_event_types`.
- Support for [`conn_id`](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md#concurrent-connections) for distinguishing multiple concurrent connections.
- Version 0.99.4 [2023/07/12](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
- Support for `SYNCV3_MAX_DB_CONN`, and a reduction in the number of concurrent connections required during normal operation.
- Add more metrics and logs. Reduce log spam.
- Improve performance when handling changed device lists.
- Responses will consume from the live buffer even when clients change their request parameters, so that new events are sent down more quickly.
- Bugfix: return `invited_count` correctly when it transitions to 0.
- Bugfix: fix a data corruption bug when 2 users join a federated room where the first user was invited to said room.
## Usage
### Setup
Requires Postgres 13+.
First, you must create a Postgres database and secret:
```bash
$ createdb syncv3
$ echo -n "$(openssl rand -hex 32)" > .secret # this MUST remain the same throughout the lifetime of the database created above.
```
Compiling from source and running:
The Sliding Sync proxy requires some environment variables to be set in order to function. They are described when the proxy is run with any of them missing.
Here is a short description of each, at the time of writing:
```
SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org'
SYNCV3_DB Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
SYNCV3_SECRET Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database.
SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on.
SYNCV3_TLS_CERT Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address.
SYNCV3_TLS_KEY Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file.
SYNCV3_PPROF Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen.
SYNCV3_PROM Default: unset. The bind addr for Prometheus metrics, which will be accessible at /metrics at this address.
SYNCV3_JAEGER_URL Default: unset. The Jaeger URL to send spans to e.g http://localhost:14268/api/traces - if unset does not send OTLP traces.
SYNCV3_SENTRY_DSN Default: unset. The Sentry DSN to report events to e.g https://sliding-sync@sentry.example.com/123 - if unset does not send sentry events.
SYNCV3_LOG_LEVEL Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
SYNCV3_MAX_DB_CONN Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
```
It is easiest to host the proxy on a separate hostname from the Matrix server, though it is possible to use the same hostname by forwarding the endpoints it uses.
In both cases, the path `https://example.com/.well-known/matrix/client` must return JSON with at least the following contents:
```json
{
    "m.server": {
        "base_url": "https://example.com"
    },
    "m.homeserver": {
        "base_url": "https://example.com"
    },
    "org.matrix.msc3575.proxy": {
        "url": "https://syncv3.example.com"
    }
}
```
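For illustration, a minimal Go sketch (not part of the proxy) of how a client could consume this well-known document to discover the proxy; the hostnames are the placeholders from the JSON above.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// wellKnown models only the field needed for proxy discovery.
type wellKnown struct {
	Proxy struct {
		URL string `json:"url"`
	} `json:"org.matrix.msc3575.proxy"`
}

func main() {
	resp, err := http.Get("https://example.com/.well-known/matrix/client")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var wk wellKnown
	if err := json.NewDecoder(resp.Body).Decode(&wk); err != nil {
		panic(err)
	}
	// Sliding sync requests are then sent to this base URL, e.g.
	// POST {url}/_matrix/client/unstable/org.matrix.msc3575/sync
	fmt.Println("sliding sync proxy:", wk.Proxy.URL)
}
```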
#### Same hostname
The following nginx configuration can be used to pass the required endpoints to the sync proxy, running on local port 8009 (so as to not conflict with Synapse):
```nginx
location ~ ^/(client/|_matrix/client/unstable/org.matrix.msc3575/sync) {
    proxy_pass http://localhost:8009;
    proxy_set_header X-Forwarded-For $remote_addr;
    proxy_set_header X-Forwarded-Proto $scheme;
    proxy_set_header Host $host;
}

location ~ ^(\/_matrix|\/_synapse\/client) {
    proxy_pass http://localhost:8008;
    proxy_set_header X-Forwarded-For $remote_addr;
    proxy_set_header X-Forwarded-Proto $scheme;
    proxy_set_header Host $host;
}

location /.well-known/matrix/client {
    add_header Access-Control-Allow-Origin *;
}
```
### Running
There are two ways to run the proxy:
- Compiling from source:
```
$ go build ./cmd/syncv3
$ SYNCV3_SECRET=$(cat .secret) SYNCV3_SERVER="https://matrix-client.matrix.org" SYNCV3_DB="user=$(whoami) dbname=syncv3 sslmode=disable" SYNCV3_BINDADDR=0.0.0.0:8008 ./syncv3
$ SYNCV3_SECRET=$(cat .secret) SYNCV3_SERVER="https://matrix-client.matrix.org" SYNCV3_DB="user=$(whoami) dbname=syncv3 sslmode=disable password='DATABASE_PASSWORD_HERE'" SYNCV3_BINDADDR=0.0.0.0:8008 ./syncv3
```
Using a Docker image:
- Using a Docker image:
```
docker run --rm -e "SYNCV3_SERVER=https://matrix-client.matrix.org" -e "SYNCV3_SECRET=$(cat .secret)" -e "SYNCV3_BINDADDR=:8008" -e "SYNCV3_DB=user=$(whoami) dbname=syncv3 sslmode=disable host=host.docker.internal" -p 8008:8008 ghcr.io/matrix-org/sliding-sync:latest
docker run --rm -e "SYNCV3_SERVER=https://matrix-client.matrix.org" -e "SYNCV3_SECRET=$(cat .secret)" -e "SYNCV3_BINDADDR=:8008" -e "SYNCV3_DB=user=$(whoami) dbname=syncv3 sslmode=disable host=host.docker.internal password='DATABASE_PASSWORD_HERE'" -p 8008:8008 ghcr.io/matrix-org/sliding-sync:latest
```
Optionally also set `SYNCV3_TLS_CERT=path/to/cert.pem` and `SYNCV3_TLS_KEY=path/to/key.pem` to listen on HTTPS instead of HTTP.
Make sure to tweak the `SYNCV3_DB` environment variable if the Postgres database isn't running on the host.
@ -157,4 +234,4 @@ Run end-to-end tests:
# to ghcr and pull the image.
docker run --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.72.0
(go build ./cmd/syncv3 && dropdb syncv3_test && createdb syncv3_test && cd tests-e2e && ./run-tests.sh -count=1 .)
```
```

View File

@ -1,27 +1,36 @@
package main
import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/getsentry/sentry-go"
sentryhttp "github.com/getsentry/sentry-go/http"
syncv3 "github.com/matrix-org/sliding-sync"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/pressly/goose/v3"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strings"
"syscall"
"time"
)
var GitCommit string
const version = "0.99.2"
const version = "0.99.5"
var (
flags = flag.NewFlagSet("goose", flag.ExitOnError)
)
const (
// Required fields
@ -39,6 +48,7 @@ const (
EnvJaeger = "SYNCV3_JAEGER_URL"
EnvSentryDsn = "SYNCV3_SENTRY_DSN"
EnvLogLevel = "SYNCV3_LOG_LEVEL"
EnvMaxConns = "SYNCV3_MAX_DB_CONN"
)
var helpMsg = fmt.Sprintf(`
@ -54,7 +64,8 @@ Environment var
%s Default: unset. The Jaeger URL to send spans to e.g http://localhost:14268/api/traces - if unset does not send OTLP traces.
%s Default: unset. The Sentry DSN to report events to e.g https://sliding-sync@sentry.example.com/123 - if unset does not send sentry events.
%s Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvJaeger, EnvSentryDsn, EnvLogLevel)
%s Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvJaeger, EnvSentryDsn, EnvLogLevel, EnvMaxConns)
func defaulting(in, dft string) string {
if in == "" {
@ -67,6 +78,12 @@ func main() {
fmt.Printf("Sync v3 [%s] (%s)\n", version, GitCommit)
sync2.ProxyVersion = version
syncv3.Version = fmt.Sprintf("%s (%s)", version, GitCommit)
if len(os.Args) > 1 && os.Args[1] == "migrate" {
executeMigrations()
return
}
args := map[string]string{
EnvServer: os.Getenv(EnvServer),
EnvDB: os.Getenv(EnvDB),
@ -80,6 +97,7 @@ func main() {
EnvJaeger: os.Getenv(EnvJaeger),
EnvSentryDsn: os.Getenv(EnvSentryDsn),
EnvLogLevel: os.Getenv(EnvLogLevel),
EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
}
requiredEnvVars := []string{EnvServer, EnvDB, EnvSecret, EnvBindAddr}
for _, requiredEnvVar := range requiredEnvVars {
@ -135,6 +153,8 @@ func main() {
}
}
fmt.Printf("Debug=%v LogLevel=%v MaxConns=%v\n", args[EnvDebug] == "1", args[EnvLogLevel], args[EnvMaxConns])
if args[EnvDebug] == "1" {
zerolog.SetGlobalLevel(zerolog.TraceLevel)
} else {
@ -161,8 +181,15 @@ func main() {
panic(err)
}
maxConnsInt, err := strconv.Atoi(args[EnvMaxConns])
if err != nil {
panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns])
}
h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{
AddPrometheusMetrics: args[EnvPrometheus] != "",
AddPrometheusMetrics: args[EnvPrometheus] != "",
DBMaxConns: maxConnsInt,
DBConnMaxIdleTime: time.Hour,
MaxTransactionIDDelay: time.Second,
})
go h2.StartV2Pollers()
@ -203,3 +230,49 @@ func WaitForShutdown(sentryInUse bool) {
fmt.Printf("Exiting now")
}
func executeMigrations() {
envArgs := map[string]string{
EnvDB: os.Getenv(EnvDB),
}
requiredEnvVars := []string{EnvDB}
for _, requiredEnvVar := range requiredEnvVars {
if envArgs[requiredEnvVar] == "" {
fmt.Print(helpMsg)
fmt.Printf("\n%s is not set", requiredEnvVar)
fmt.Printf("\n%s must be set\n", strings.Join(requiredEnvVars, ", "))
os.Exit(1)
}
}
flags.Parse(os.Args[1:])
args := flags.Args()
if len(args) < 2 {
flags.Usage()
return
}
command := args[1]
db, err := goose.OpenDBWithDriver("postgres", envArgs[EnvDB])
if err != nil {
log.Fatalf("goose: failed to open DB: %v\n", err)
}
defer func() {
if err := db.Close(); err != nil {
log.Fatalf("goose: failed to close DB: %v\n", err)
}
}()
arguments := []string{}
if len(args) > 2 {
arguments = append(arguments, args[2:]...)
}
goose.SetBaseFS(syncv3.EmbedMigrations)
if err := goose.Run(command, db, "state/migrations", arguments...); err != nil {
log.Fatalf("goose %v: %v", command, err)
}
}

34
go.mod
View File

@ -7,44 +7,46 @@ require (
github.com/getsentry/sentry-go v0.20.0
github.com/go-logr/zerologr v1.2.3
github.com/gorilla/mux v1.8.0
github.com/hashicorp/golang-lru v0.5.4
github.com/jmoiron/sqlx v1.3.3
github.com/lib/pq v1.10.1
github.com/lib/pq v1.10.9
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/pressly/goose/v3 v3.14.0
github.com/prometheus/client_golang v1.13.0
github.com/rs/zerolog v1.29.0
github.com/tidwall/gjson v1.14.3
github.com/tidwall/sjson v1.2.5
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0
go.opentelemetry.io/otel v1.13.0
go.opentelemetry.io/otel/exporters/jaeger v1.13.0
go.opentelemetry.io/otel/sdk v1.13.0
go.opentelemetry.io/otel/trace v1.13.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0
go.opentelemetry.io/otel v1.16.0
go.opentelemetry.io/otel/exporters/jaeger v1.16.0
go.opentelemetry.io/otel/sdk v1.16.0
go.opentelemetry.io/otel/trace v1.16.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/fxamacker/cbor/v2 v2.5.0 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/prometheus/procfs v0.11.0 // indirect
github.com/rs/xid v1.4.0 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
go.opentelemetry.io/otel/metric v0.36.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/sync v0.0.0-20220907140024-f12130a52804 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.8.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.opentelemetry.io/otel/metric v1.16.0 // indirect
golang.org/x/crypto v0.10.0 // indirect
golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
google.golang.org/protobuf v1.29.1 // indirect
)

91
go.sum
View File

@ -57,12 +57,15 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE=
github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo=
github.com/getsentry/sentry-go v0.20.0 h1:bwXW98iMRIWxn+4FgPW7vMrjmbym6HblXALmhjHmQaQ=
github.com/getsentry/sentry-go v0.20.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
@ -78,14 +81,14 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zerologr v1.2.3 h1:up5N9vcH9Xck3jJkXzgyOxozT14R47IyDODz8LM1KSs=
github.com/go-logr/zerologr v1.2.3/go.mod h1:BxwGo7y5zgSHYR1BjbnHPyF/5ZjVKfKxAZANVu6E8Ho=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@ -140,6 +143,7 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
@ -147,8 +151,6 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jmoiron/sqlx v1.3.3 h1:j82X0bf7oQ27XeqxicSZsTU5suPwKElg3oyxNn43iTk=
github.com/jmoiron/sqlx v1.3.3/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=
@ -161,6 +163,7 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@ -169,8 +172,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.1 h1:6VXZrLU0jHBYyAqrSPa+MgPfnSvTPuMgK+k0o5kVFWo=
github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8=
@ -182,8 +185,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
@ -203,6 +206,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose/v3 v3.14.0 h1:gNrFLLDF+fujdq394rcdYK3WPxp3VKWifTajlZwInJM=
github.com/pressly/goose/v3 v3.14.0/go.mod h1:uwSpREK867PbIsdE9GS6pRk1LUPB7gwMkmvk9/hbIMA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
@ -226,8 +231,9 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/prometheus/procfs v0.11.0 h1:5EAgkfkMl659uZPbe9AS2N68a7Cc1TJbPEuGzFuRbyk=
github.com/prometheus/procfs v0.11.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
@ -236,8 +242,8 @@ github.com/rs/zerolog v1.29.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6us
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
@ -245,7 +251,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@ -255,6 +261,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@ -264,18 +272,18 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0 h1:vFEBG7SieZJzvnRWQ81jxpuEqe6J8Ex+hgc9CqOTzHc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0/go.mod h1:9rgTcOKdIhDOC0IcAu8a+R+FChqSUBihKpM1lVNi6T0=
go.opentelemetry.io/otel v1.13.0 h1:1ZAKnNQKwBBxFtww/GwxNUyTf0AxkZzrukO8MeXqe4Y=
go.opentelemetry.io/otel v1.13.0/go.mod h1:FH3RtdZCzRkJYFTCsAKDy9l/XYjMdNv6QrkFFB8DvVg=
go.opentelemetry.io/otel/exporters/jaeger v1.13.0 h1:VAMoGujbVV8Q0JNM/cEbhzUIWWBxnEqH45HP9iBKN04=
go.opentelemetry.io/otel/exporters/jaeger v1.13.0/go.mod h1:fHwbmle6mBFJA1p2ZIhilvffCdq/dM5UTIiCOmEjS+w=
go.opentelemetry.io/otel/metric v0.36.0 h1:t0lgGI+L68QWt3QtOIlqM9gXoxqxWLhZ3R/e5oOAY0Q=
go.opentelemetry.io/otel/metric v0.36.0/go.mod h1:wKVw57sd2HdSZAzyfOM9gTqqE8v7CbqWsYL6AyrH9qk=
go.opentelemetry.io/otel/sdk v1.13.0 h1:BHib5g8MvdqS65yo2vV1s6Le42Hm6rrw08qU6yz5JaM=
go.opentelemetry.io/otel/sdk v1.13.0/go.mod h1:YLKPx5+6Vx/o1TCUYYs+bpymtkmazOMT6zoRrC7AQ7I=
go.opentelemetry.io/otel/trace v1.13.0 h1:CBgRZ6ntv+Amuj1jDsMhZtlAPT6gbyIRdaIzFhfBSdY=
go.opentelemetry.io/otel/trace v1.13.0/go.mod h1:muCvmmO9KKpvuXSf3KKAXXB2ygNYHQ+ZfI5X08d3tds=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8=
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
go.opentelemetry.io/otel/exporters/jaeger v1.16.0 h1:YhxxmXZ011C0aDZKoNw+juVWAmEfv/0W2XBOv9aHTaA=
go.opentelemetry.io/otel/exporters/jaeger v1.16.0/go.mod h1:grYbBo/5afWlPpdPZYhyn78Bk04hnvxn2+hvxQhKIQM=
go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo=
go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4=
go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE=
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -284,8 +292,8 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -318,7 +326,7 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -351,7 +359,7 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -370,8 +378,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220907140024-f12130a52804 h1:0SH2R3f1b1VmIMG7BXbEZCBUu2dKmHschSmjqGUrW8A=
golang.org/x/sync v0.0.0-20220907140024-f12130a52804/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -415,8 +423,9 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -426,8 +435,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -473,7 +482,7 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -577,6 +586,16 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
lukechampine.com/uint128 v1.3.0 h1:cDdUVfRwDUDovz610ABgFD17nXD4/uDgVHl2sC3+sbo=
modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q=
modernc.org/ccgo/v3 v3.16.14 h1:af6KNtFgsVmnDYrWk3PQCS9XT6BXe7o3ZFJKkIKvXNQ=
modernc.org/libc v1.24.1 h1:uvJSeCKL/AgzBo2yYIPPTy82v21KgGnizcGYfBHaNuM=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/memory v1.6.0 h1:i6mzavxrE9a30whzMfwf7XWVODx2r5OYXvU46cirX7o=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/sqlite v1.24.0 h1:EsClRIWHGhLTCX44p+Ri/JLD+vFGo0QGjasg2/F9TlI=
modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

View File

@ -2,6 +2,8 @@ package internal
import (
"context"
"fmt"
"github.com/getsentry/sentry-go"
"github.com/rs/zerolog"
@ -16,6 +18,9 @@ var (
// logging metadata for a single request
type data struct {
userID string
deviceID string
bufferSummary string
connID string
since int64
next int64
numRooms int
@ -24,6 +29,9 @@ type data struct {
numGlobalAccountData int
numChangedDevices int
numLeftDevices int
numLists int
roomSubs int
roomUnsubs int
}
// prepare a request context so it can contain syncv3 info
@ -37,13 +45,14 @@ func RequestContext(ctx context.Context) context.Context {
}
// add the user ID to this request context. Need to have called RequestContext first.
func SetRequestContextUserID(ctx context.Context, userID string) {
func SetRequestContextUserID(ctx context.Context, userID, deviceID string) {
d := ctx.Value(ctxData)
if d == nil {
return
}
da := d.(*data)
da.userID = userID
da.deviceID = deviceID
if hub := sentry.GetHubFromContext(ctx); hub != nil {
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetUser(sentry.User{Username: userID})
@ -51,9 +60,18 @@ func SetRequestContextUserID(ctx context.Context, userID string) {
}
}
func SetConnBufferInfo(ctx context.Context, bufferLen, nextLen, bufferCap int) {
d := ctx.Value(ctxData)
if d == nil {
return
}
da := d.(*data)
da.bufferSummary = fmt.Sprintf("%d/%d/%d", bufferLen, nextLen, bufferCap)
}
func SetRequestContextResponseInfo(
ctx context.Context, since, next int64, numRooms int, txnID string, numToDeviceEvents, numGlobalAccountData int,
numChangedDevices, numLeftDevices int,
numChangedDevices, numLeftDevices int, connID string, numLists int, roomSubs, roomUnsubs int,
) {
d := ctx.Value(ctxData)
if d == nil {
@ -68,6 +86,10 @@ func SetRequestContextResponseInfo(
da.numGlobalAccountData = numGlobalAccountData
da.numChangedDevices = numChangedDevices
da.numLeftDevices = numLeftDevices
da.connID = connID
da.numLists = numLists
da.roomSubs = roomSubs
da.roomUnsubs = roomUnsubs
}
func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
@ -79,6 +101,9 @@ func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
if da.userID != "" {
l = l.Str("u", da.userID)
}
if da.deviceID != "" {
l = l.Str("dev", da.deviceID)
}
if da.since >= 0 {
l = l.Int64("p", da.since)
}
@ -103,5 +128,19 @@ func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
if da.numLeftDevices > 0 {
l = l.Int("dl-l", da.numLeftDevices)
}
if da.bufferSummary != "" {
l = l.Str("b", da.bufferSummary)
}
if da.roomSubs > 0 {
l = l.Int("sub", da.roomSubs)
}
if da.roomUnsubs > 0 {
l = l.Int("usub", da.roomUnsubs)
}
if da.numLists > 0 {
l = l.Int("l", da.numLists)
}
// always log the connection ID so we know when it isn't set
l = l.Str("c", da.connID)
return l
}
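
A hedged sketch of how a handler might thread these new fields into a log line; the user, device and buffer values are placeholders, not taken from the handler code.

```go
package internal

import (
	"context"

	"github.com/rs/zerolog"
)

// exampleLogRequest is illustrative only: it shows the new setters feeding
// DecorateLogger, which emits the short keys (u, dev, b, c, ...) above.
func exampleLogRequest(ctx context.Context, logger zerolog.Logger) {
	ctx = RequestContext(ctx)
	SetRequestContextUserID(ctx, "@alice:example.com", "DEVICE123")
	SetConnBufferInfo(ctx, 3, 1, 50) // bufferLen, nextLen, bufferCap
	DecorateLogger(ctx, logger.Info()).Msg("processed sync request")
}
```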

View File

@ -23,7 +23,7 @@ func isBitSet(n int, bit int) bool {
type DeviceData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts map[string]int `json:"otk"`
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
FallbackKeyTypes []string `json:"fallback"`

View File

@ -1,5 +1,10 @@
package internal
import (
"database/sql/driver"
"encoding/json"
)
const (
DeviceListChanged = 1
DeviceListLeft = 2
@ -7,8 +12,19 @@ const (
type DeviceLists struct {
// map user_id -> DeviceList enum
New map[string]int `json:"n"`
Sent map[string]int `json:"s"`
New MapStringInt `json:"n"`
Sent MapStringInt `json:"s"`
}
type MapStringInt map[string]int
// Value implements driver.Valuer
func (dl MapStringInt) Value() (driver.Value, error) {
if len(dl) == 0 {
return "{}", nil
}
v, err := json.Marshal(dl)
return v, err
}
func (dl DeviceLists) Combine(newer DeviceLists) DeviceLists {
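
An illustrative test sketch (not part of this commit) of what the new `driver.Valuer` implementation produces, so these maps can be bound directly as query parameters; the key name is just an example one-time-key algorithm.

```go
package internal

import "testing"

func TestMapStringIntValueSketch(t *testing.T) {
	// A populated map is serialised to JSON...
	v, err := MapStringInt{"signed_curve25519": 50}.Value()
	if err != nil {
		t.Fatal(err)
	}
	if string(v.([]byte)) != `{"signed_curve25519":50}` {
		t.Fatalf("got %s", v)
	}
	// ...and an empty (or nil) map becomes the literal "{}", never SQL NULL.
	empty, err := MapStringInt(nil).Value()
	if err != nil {
		t.Fatal(err)
	}
	if empty != "{}" {
		t.Fatalf("got %v", empty)
	}
}
```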

View File

@ -71,10 +71,18 @@ func ExpiredSessionError() *HandlerError {
// Which then produces:
//
// assertion failed: list is not empty
func Assert(msg string, expr bool) {
//
// An optional debugContext map can be provided. If it is present and sentry is configured,
// it is added as context to the sentry events generated for failed assertions.
func Assert(msg string, expr bool, debugContext ...map[string]interface{}) {
assert(msg, expr)
if !expr {
sentry.CaptureException(fmt.Errorf("assertion failed: %s", msg))
sentry.WithScope(func(scope *sentry.Scope) {
if len(debugContext) > 0 {
scope.SetContext(SentryCtxKey, debugContext[0])
}
sentry.CaptureException(fmt.Errorf("assertion failed: %s", msg))
})
}
}
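
A hypothetical call site showing the new optional `debugContext` argument; the keys and values are illustrative only.

```go
package internal

// exampleAssertWithContext is illustrative only. If Sentry is configured, the
// map is attached as context to the captured assertion failure.
func exampleAssertWithContext(userID string, roomIDs []string) {
	Assert("room list is not empty", len(roomIDs) > 0, map[string]interface{}{
		"user_id":   userID,
		"num_rooms": len(roomIDs),
	})
}
```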

67
internal/pool.go Normal file
View File

@ -0,0 +1,67 @@
package internal
type WorkerPool struct {
N int
ch chan func()
}
// Create a new worker pool of size N. Up to N pieces of work can be processed concurrently.
// The size of N depends on the expected frequency of work and contention for
// shared resources. Large values of N allow more frequent work at the cost of
// more contention for shared resources like cpu, memory and fds. Small values
// of N allow less frequent work but control the amount of shared resource contention.
// Ideally this value will be derived from whatever shared resource constraints you
// are hitting up against, rather than set to a fixed value. For example, if you have
// a database connection limit of 100, then setting N to some fraction of the limit is
// preferred to setting this to an arbitrary number < 100. If more than N work is requested,
// eventually WorkerPool.Queue will block until some work is done.
//
// The larger N is, the larger the up front memory costs are due to the implementation of WorkerPool.
func NewWorkerPool(n int) *WorkerPool {
return &WorkerPool{
N: n,
// If we have N workers, we can process N work concurrently.
// If we have >N work, we need to apply backpressure to stop us
// making more and more work which takes up more and more memory.
// By setting the channel size to N, we ensure that backpressure is
// being applied on the producer, stopping it from creating more work,
// and hence bounding memory consumption. Work is still being produced
// upstream on the homeserver, but we will consume it when we're ready
// rather than gobble it all at once.
//
// Note: we aren't forced to set this to N, it just serves as a useful
// metric which scales on the number of workers. The amount of in-flight
// work is N, so it makes sense to allow up to N work to be queued up before
// applying backpressure. If the channel buffer is < N then the channel can
// become the bottleneck in the case where we have lots of instantaneous work
// to do. If the channel buffer is too large, we needlessly consume memory as
// make() will allocate a backing array of whatever size you give it up front (sad face)
ch: make(chan func(), n),
}
}
// Start the workers. Only call this once.
func (wp *WorkerPool) Start() {
for i := 0; i < wp.N; i++ {
go wp.worker()
}
}
// Stop the worker pool. Only really useful for tests as a worker pool should be started once
// and persist for the lifetime of the process, else it causes needless goroutine churn.
// Only call this once.
func (wp *WorkerPool) Stop() {
close(wp.ch)
}
// Queue some work on the pool. May or may not block until some work is processed.
func (wp *WorkerPool) Queue(fn func()) {
wp.ch <- fn
}
// worker impl
func (wp *WorkerPool) worker() {
for fn := range wp.ch {
fn()
}
}
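
A hedged usage sketch: the pool size is derived from an imagined shared-resource limit (a database connection cap), as the `NewWorkerPool` comment recommends, and `Queue` applies backpressure to the producer loop once the workers and buffer are full. `pollOne` is a placeholder for whatever work each item needs.

```go
package internal

import "sync"

func examplePoolUsage(dbConnLimit int, userIDs []string, pollOne func(userID string)) {
	// Size the pool from the shared limit rather than an arbitrary constant.
	n := dbConnLimit / 2
	if n < 1 {
		n = 1
	}
	wp := NewWorkerPool(n)
	wp.Start() // pool persists for the lifetime of the process

	var wg sync.WaitGroup
	for _, userID := range userIDs {
		userID := userID // capture the loop variable
		wg.Add(1)
		// Queue blocks once n workers are busy and the buffer of n is full,
		// which bounds how much pending work we hold in memory.
		wp.Queue(func() {
			defer wg.Done()
			pollOne(userID)
		})
	}
	wg.Wait()
}
```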

186
internal/pool_test.go Normal file
View File

@ -0,0 +1,186 @@
package internal
import (
"sync"
"testing"
"time"
)
// Test basic functions of WorkerPool
func TestWorkerPool(t *testing.T) {
wp := NewWorkerPool(2)
wp.Start()
defer wp.Stop()
// we should process this concurrently as N=2 so it should take 1s not 2s
var wg sync.WaitGroup
wg.Add(2)
start := time.Now()
wp.Queue(func() {
time.Sleep(time.Second)
wg.Done()
})
wp.Queue(func() {
time.Sleep(time.Second)
wg.Done()
})
wg.Wait()
took := time.Since(start)
if took > 2*time.Second {
t.Fatalf("took %v for queued work, it should have been faster than 2s", took)
}
}
func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) {
wp := NewWorkerPool(2)
// return channel to use to see when work is done
ch := make(chan int, 2)
wp.Queue(func() {
ch <- 1
})
wp.Queue(func() {
ch <- 2
})
// the work should not be done yet
time.Sleep(100 * time.Millisecond)
if len(ch) > 0 {
t.Fatalf("Queued work was done before Start()")
}
// the work should be starting now
wp.Start()
defer wp.Stop()
sum := 0
for {
select {
case <-time.After(time.Second):
t.Fatalf("timed out waiting for work to be done")
case val := <-ch:
sum += val
}
if sum == 3 { // 2 + 1
break
}
}
}
type workerState struct {
id int
state int // not running, queued, running, finished
unblock *sync.WaitGroup // decrement to unblock this worker
}
func TestWorkerPoolBackpressure(t *testing.T) {
// this test assumes backpressure starts at n*2+1 due to a chan buffer of size n, and n in-flight work.
n := 2
wp := NewWorkerPool(n)
wp.Start()
defer wp.Stop()
var mu sync.Mutex
stateNotRunning := 0
stateQueued := 1
stateRunning := 2
stateFinished := 3
size := (2 * n) + 1
running := make([]*workerState, size)
go func() {
// we test backpressure by scheduling (n*2)+1 work and ensuring that we see the following running states:
// [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE
// [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued
// [3,3,2,2,1] <-- 2 finished, 2 running , 1 queued
// [3,3,3,2,2] <-- 3 finished, 2 running
for i := 0; i < size; i++ {
// set initial state of this piece of work
wg := &sync.WaitGroup{}
wg.Add(1)
state := &workerState{
id: i,
state: stateNotRunning,
unblock: wg,
}
mu.Lock()
running[i] = state
mu.Unlock()
// queue the work on the pool. The final piece of work will block here and remain in
// stateNotRunning and not transition to stateQueued until the first piece of work is done.
wp.Queue(func() {
mu.Lock()
if running[state.id].state != stateQueued {
// we ran work in the worker faster than the code underneath .Queue, so let it catch up
mu.Unlock()
time.Sleep(10 * time.Millisecond)
mu.Lock()
}
running[state.id].state = stateRunning
mu.Unlock()
running[state.id].unblock.Wait()
mu.Lock()
running[state.id].state = stateFinished
mu.Unlock()
})
// mark this work as queued
mu.Lock()
running[i].state = stateQueued
mu.Unlock()
}
}()
// wait for the workers to be doing work and assert the states of each task
time.Sleep(time.Second)
assertStates(t, &mu, running, []int{
stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning,
})
// now let the first task complete
running[0].unblock.Done()
// wait for the pool to grab more work
time.Sleep(100 * time.Millisecond)
// assert new states
assertStates(t, &mu, running, []int{
stateFinished, stateRunning, stateRunning, stateQueued, stateQueued,
})
// now let the second task complete
running[1].unblock.Done()
// wait for the pool to grab more work
time.Sleep(100 * time.Millisecond)
// assert new states
assertStates(t, &mu, running, []int{
stateFinished, stateFinished, stateRunning, stateRunning, stateQueued,
})
// now let the third task complete
running[2].unblock.Done()
// wait for the pool to grab more work
time.Sleep(100 * time.Millisecond)
// assert new states
assertStates(t, &mu, running, []int{
stateFinished, stateFinished, stateFinished, stateRunning, stateRunning,
})
}
func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) {
t.Helper()
mu.Lock()
defer mu.Unlock()
if len(running) != len(wantStates) {
t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running))
}
for i := range running {
state := running[i]
wantVal := wantStates[i]
if state.state != wantVal {
t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal)
}
}
}

View File

@ -13,7 +13,12 @@ type EventMetadata struct {
Timestamp uint64
}
// RoomMetadata holds room-scoped data. It is primarily used in two places:
// RoomMetadata holds room-scoped data.
// TODO: This is a lie: we sometimes remove a user U from the list of heroes
// when calculating the sync response for that user U. Grep for `RemoveHero`.
//
// It is primarily used in two places:
//
// - in the caches.GlobalCache, to hold the latest version of data that is consistent
// between all users in the room; and
// - in the sync3.RoomConnMetadata struct, to hold the version of data last seen by
@ -25,6 +30,7 @@ type RoomMetadata struct {
RoomID string
Heroes []Hero
NameEvent string // the content of m.room.name, NOT the calculated name
AvatarEvent string // the content of m.room.avatar, NOT the resolved avatar
CanonicalAlias string
JoinCount int
InviteCount int
@ -54,6 +60,32 @@ func NewRoomMetadata(roomID string) *RoomMetadata {
}
}
// CopyHeroes returns a version of the current RoomMetadata whose Heroes field is
// a brand-new copy of the original Heroes. The return value's Heroes field can be
// safely modified by the caller, but it is NOT safe for the caller to modify any other
// fields.
func (m *RoomMetadata) CopyHeroes() *RoomMetadata {
newMetadata := *m
// XXX: We're doing this because we end up calling RemoveHero() to omit the
// currently-syncing user in various places. But this seems smelly. The set of
// heroes in the room is a global, room-scoped fact: it is a property of the room
// state and nothing else, and all users see the same set of heroes.
//
// I think the data model would be cleaner if we made the hero-reading functions
// aware of the currently syncing user, in order to ignore them without having to
// change the underlying data.
//
// copy the heroes or else we may modify the same slice which would be bad :(
newMetadata.Heroes = make([]Hero, len(m.Heroes))
copy(newMetadata.Heroes, m.Heroes)
// ⚠️ NB: there are other pointer fields (e.g. PredecessorRoomID *string) and
// pointer-backed fields (e.g. LatestEventsByType map[string]EventMetadata)
// which are not deepcopied here.
return &newMetadata
}
// SameRoomName checks if the fields relevant for room names have changed between the two metadatas.
// Returns true if there are no changes.
func (m *RoomMetadata) SameRoomName(other *RoomMetadata) bool {
@ -62,7 +94,13 @@ func (m *RoomMetadata) SameRoomName(other *RoomMetadata) bool {
m.CanonicalAlias == other.CanonicalAlias &&
m.JoinCount == other.JoinCount &&
m.InviteCount == other.InviteCount &&
sameHeroes(m.Heroes, other.Heroes))
sameHeroNames(m.Heroes, other.Heroes))
}
// SameRoomAvatar checks if the fields relevant for room avatars have changed between the two metadatas.
// Returns true if there are no changes.
func (m *RoomMetadata) SameRoomAvatar(other *RoomMetadata) bool {
return m.AvatarEvent == other.AvatarEvent && sameHeroAvatars(m.Heroes, other.Heroes)
}
func (m *RoomMetadata) SameJoinCount(other *RoomMetadata) bool {
@ -73,7 +111,7 @@ func (m *RoomMetadata) SameInviteCount(other *RoomMetadata) bool {
return m.InviteCount == other.InviteCount
}
func sameHeroes(a, b []Hero) bool {
func sameHeroNames(a, b []Hero) bool {
if len(a) != len(b) {
return false
}
@ -88,6 +126,21 @@ func sameHeroes(a, b []Hero) bool {
return true
}
func sameHeroAvatars(a, b []Hero) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].ID != b[i].ID {
return false
}
if a[i].Avatar != b[i].Avatar {
return false
}
}
return true
}
func (m *RoomMetadata) RemoveHero(userID string) {
for i, h := range m.Heroes {
if h.ID == userID {
@ -102,8 +155,9 @@ func (m *RoomMetadata) IsSpace() bool {
}
type Hero struct {
ID string
Name string
ID string
Name string
Avatar string
}
func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
@ -190,3 +244,18 @@ func disambiguate(heroes []Hero) []string {
}
return disambiguatedNames
}
const noAvatar = ""
// CalculateAvatar computes the avatar for the room, based on the global room metadata.
// Assumption: metadata.RemoveHero has been called to remove the user who is syncing
// from the list of heroes.
func CalculateAvatar(metadata *RoomMetadata) string {
if metadata.AvatarEvent != "" {
return metadata.AvatarEvent
}
if len(metadata.Heroes) == 1 {
return metadata.Heroes[0].Avatar
}
return noAvatar
}
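
As a usage sketch only (the wrapper below is hypothetical and not part of the repo), the expected call pattern is: copy the heroes, remove the syncing user, then compute the avatar:

```go
// Minimal sketch, assuming it lives in the same package as RoomMetadata.
// avatarForUser is an illustrative helper, not from the codebase.
func avatarForUser(global *RoomMetadata, syncingUserID string) string {
	// CopyHeroes gives us a Heroes slice we can safely mutate.
	scoped := global.CopyHeroes()
	// CalculateAvatar assumes the syncing user has been removed from Heroes.
	scoped.RemoveHero(syncingUserID)
	// An explicit m.room.avatar wins; otherwise fall back to the sole
	// remaining hero's avatar (the DM case), else no avatar.
	return CalculateAvatar(scoped)
}
```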

View File

@ -247,3 +247,45 @@ func TestCalculateRoomName(t *testing.T) {
}
}
}
func TestCopyHeroes(t *testing.T) {
const alice = "@alice:test"
const bob = "@bob:test"
const chris = "@chris:test"
m1 := RoomMetadata{Heroes: []Hero{
{ID: alice},
{ID: bob},
{ID: chris},
}}
m2 := m1.CopyHeroes()
// Uncomment this to see why CopyHeroes is necessary!
//m2 := m1
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
t.Log("Remove chris from m1")
m1.RemoveHero(chris)
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
assertSliceIDs(t, "m1.Heroes", m1.Heroes, []string{alice, bob})
assertSliceIDs(t, "m2.Heroes", m2.Heroes, []string{alice, bob, chris})
t.Log("Remove alice from m1")
m1.RemoveHero(alice)
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
assertSliceIDs(t, "m1.Heroes", m1.Heroes, []string{bob})
assertSliceIDs(t, "m2.Heroes", m2.Heroes, []string{alice, bob, chris})
}
func assertSliceIDs(t *testing.T, desc string, h []Hero, ids []string) {
if len(h) != len(ids) {
t.Errorf("%s has length %d, expected %d", desc, len(h), len(ids))
}
for index, id := range ids {
if h[index].ID != id {
t.Errorf("%s[%d] ID is %s, expected %s", desc, index, h[index].ID, id)
}
}
}

View File

@ -41,12 +41,15 @@ type V2Accumulate struct {
func (*V2Accumulate) Type() string { return "V2Accumulate" }
// V2TransactionID is emitted by a poller when it sees an event with a transaction ID.
// V2TransactionID is emitted by a poller when it sees an event with a transaction ID,
// or when it is certain that no other poller will see a transaction ID for this event
// (the "all-clear").
type V2TransactionID struct {
EventID string
UserID string
RoomID string
UserID string // of the sender
DeviceID string
TransactionID string
TransactionID string // Note: an empty transaction ID represents the all-clear.
NID int64
}
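
A small hedged sketch of how a consumer might interpret these payloads; the handler name is hypothetical, only the struct fields above come from the code:

```go
// handleTxnID is illustrative only: an empty TransactionID is the "all-clear",
// meaning no poller will ever supply a transaction ID for this event.
func handleTxnID(p *V2TransactionID) {
	if p.TransactionID == "" {
		// all-clear: stop waiting and send the event without a transaction ID
		return
	}
	// otherwise, attach p.TransactionID when sending the event to the
	// connection belonging to p.UserID / p.DeviceID
}
```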
@ -70,8 +73,9 @@ type V2AccountData struct {
func (*V2AccountData) Type() string { return "V2AccountData" }
type V2LeaveRoom struct {
UserID string
RoomID string
UserID string
RoomID string
LeaveEvent json.RawMessage
}
func (*V2LeaveRoom) Type() string { return "V2LeaveRoom" }
@ -91,9 +95,7 @@ type V2InitialSyncComplete struct {
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
type V2DeviceData struct {
UserID string
DeviceID string
Pos int64
UserIDToDeviceIDs map[string][]string
}
func (*V2DeviceData) Type() string { return "V2DeviceData" }

View File

@ -207,8 +207,9 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
IsState: true,
}
}
if err := ensureFieldsSet(events); err != nil {
return fmt.Errorf("events malformed: %s", err)
events = filterAndEnsureFieldsSet(events)
if len(events) == 0 {
return fmt.Errorf("failed to insert events, all events were filtered out: %w", err)
}
eventIDToNID, err := a.eventsTable.Insert(txn, events, false)
if err != nil {
@ -292,35 +293,51 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// Insert the events. Check for duplicates which can happen in the real world when joining
// Matrix HQ on Synapse.
dedupedEvents := make([]Event, 0, len(timeline))
seenEvents := make(map[string]struct{})
for i := range timeline {
e := Event{
JSON: timeline[i],
RoomID: roomID,
}
if err := e.ensureFieldsSetOnEvent(); err != nil {
return 0, nil, fmt.Errorf("event malformed: %s", err)
}
if _, ok := seenEvents[e.ID]; ok {
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
)
continue
}
if i == 0 && prevBatch != "" {
// tag the first timeline event with the prev batch token
e.PrevBatch = sql.NullString{
String: prevBatch,
Valid: true,
}
}
dedupedEvents = append(dedupedEvents, e)
seenEvents[e.ID] = struct{}{}
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// The first stage of accumulating events is mostly validation of what the upstream HS sends us. For accumulation to work correctly
// we expect:
// - there to be no duplicate events
// - any events present to be genuinely new, i.e. not old events we have already processed.
// Both of these assumptions can be false, for different reasons.
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
if err != nil {
err = fmt.Errorf("filterTimelineEvents: %w", err)
return
}
if len(dedupedEvents) == 0 {
return 0, nil, err // nothing to do
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
// the room state would just be a single event: this leave event, which is wrong.
if len(dedupedEvents) == 1 &&
dedupedEvents[0].Type == "m.room.member" &&
(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
dedupedEvents[0].StateKey == userID &&
snapID == 0 {
logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Err(err).Msg(
"Accumulator: skipping processing of leave event, as no snapshot exists",
)
return 0, nil, nil
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
@ -352,19 +369,6 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
}
}
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
@ -413,24 +417,90 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
return numNew, timelineNIDs, nil
}
// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`.
// Returns the latest NID of the last event (most recent)
func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) {
txn, err := a.db.Beginx()
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity checks to it:
// - removes duplicate events: this is just a bug which has been seen on Synapse on matrix.org
// - removes old events: this is an edge case when joining rooms over federation, see https://github.com/matrix-org/sliding-sync/issues/192
// - parses the remaining events and returns Event structs.
// - checks which events are unknown: if all events are known already, filters them all out.
func (a *Accumulator) filterAndParseTimelineEvents(txn *sqlx.Tx, roomID string, timeline []json.RawMessage, prevBatch string) ([]Event, error) {
// Check for duplicates which can happen in the real world when joining
// Matrix HQ on Synapse, as well as when you join rooms for the first time over federation.
dedupedEvents := make([]Event, 0, len(timeline))
seenEvents := make(map[string]struct{})
for i := range timeline {
e := Event{
JSON: timeline[i],
RoomID: roomID,
}
if err := e.ensureFieldsSetOnEvent(); err != nil {
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Err(err).Msg(
"Accumulator.filterAndParseTimelineEvents: failed to parse event, ignoring",
)
continue
}
if _, ok := seenEvents[e.ID]; ok {
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
"Accumulator.filterAndParseTimelineEvents: seen the same event ID twice, ignoring",
)
continue
}
if i == 0 && prevBatch != "" {
// tag the first timeline event with the prev batch token
e.PrevBatch = sql.NullString{
String: prevBatch,
Valid: true,
}
}
dedupedEvents = append(dedupedEvents, e)
seenEvents[e.ID] = struct{}{}
}
// if we only have a single timeline event we cannot determine if it is old or not, as we rely on already-seen events
// appearing after it (at a higher index).
if len(dedupedEvents) <= 1 {
return dedupedEvents, nil
}
// Figure out which of these events are unseen and hence brand new live events.
// In some cases, we may have unseen OLD events - see https://github.com/matrix-org/sliding-sync/issues/192
// in which case we need to drop those events.
dedupedEventIDs := make([]string, 0, len(seenEvents))
for evID := range seenEvents {
dedupedEventIDs = append(dedupedEventIDs, evID)
}
unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, dedupedEventIDs)
if err != nil {
return nil, 0, err
return nil, fmt.Errorf("filterAndParseTimelineEvents: failed to SelectUnknownEventIDs: %w", err)
}
defer txn.Commit()
events, err := a.eventsTable.SelectEventsBetween(txn, roomID, lastEventNID, EventsEnd, limit)
if err != nil {
return nil, 0, err
if len(unknownEventIDs) == 0 {
// every event has been seen already, no work to do
return nil, nil
}
if len(events) == 0 {
return nil, lastEventNID, nil
// In the happy case, we expect to see timeline arrays like this: (SEEN=S, UNSEEN=U)
// [S,S,U,U] -> want last 2
// [U,U,U] -> want all
// In the backfill edge case, we might see:
// [U,S,S,S] -> want none
// [U,S,S,U] -> want last 1
// We should never see scenarios like:
// [U,S,S,U,S,S] <- we should only see 1 contiguous block of seen events.
// If we do, we'll just ignore all unseen events less than the highest seen event.
// The algorithm starts at the end and just looks for the first S event, returning the subslice after that S event (which may be [])
seenIndex := -1
for i := len(dedupedEvents) - 1; i >= 0; i-- {
_, unseen := unknownEventIDs[dedupedEvents[i].ID]
if !unseen {
seenIndex = i
break
}
}
eventsJSON = make([]json.RawMessage, len(events))
for i := range events {
eventsJSON[i] = events[i].JSON
}
return eventsJSON, int64(events[len(events)-1].NID), nil
// seenIndex can be -1 if all are unseen, or len-1 if all are seen, either way if we +1 this slices correctly:
// no seen events s[A,B,C] => s[-1+1:] => [A,B,C]
// C is seen event s[A,B,C] => s[2+1:] => []
// B is seen event s[A,B,C] => s[1+1:] => [C]
// A is seen event s[A,B,C] => s[0+1:] => [B,C]
return dedupedEvents[seenIndex+1:], nil
}
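
To restate the subslicing rule in isolation, here is a minimal self-contained sketch (it does not use the real Accumulator or Event types) that drops everything up to and including the last already-seen event:

```go
// keepOnlyNewSuffix is illustrative only. ids is the deduped timeline in order;
// seen reports whether an event ID already exists in the events table.
func keepOnlyNewSuffix(ids []string, seen func(id string) bool) []string {
	seenIndex := -1
	for i := len(ids) - 1; i >= 0; i-- {
		if seen(ids[i]) {
			seenIndex = i
			break
		}
	}
	// seenIndex+1 slices correctly whether nothing or everything was seen:
	// nothing seen  => ids[0:]        => keep all
	// last one seen => ids[len(ids):] => keep none
	return ids[seenIndex+1:]
}
```

For example, with ids = [U1, S2, S3, U4] and seen returning true for S2 and S3, only U4 survives, matching the `[U,S,S,U] -> want last 1` case above.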

View File

@ -11,10 +11,13 @@ import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/tidwall/gjson"
)
var (
userID = "@me:localhost"
)
func TestAccumulatorInitialise(t *testing.T) {
roomID := "!TestAccumulatorInitialise:localhost"
roomEvents := []json.RawMessage{
@ -119,7 +122,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
var numNew int
var latestNIDs []int64
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -193,7 +196,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
// subsequent calls do nothing and are not an error
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
return err
})
if err != nil {
@ -201,59 +204,6 @@ func TestAccumulatorAccumulate(t *testing.T) {
}
}
func TestAccumulatorDelta(t *testing.T) {
roomID := "!TestAccumulatorDelta:localhost"
db, close := connectToDB(t)
defer close()
accumulator := NewAccumulator(db)
_, err := accumulator.Initialise(roomID, nil)
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
roomEvents := []json.RawMessage{
[]byte(`{"event_id":"aD", "type":"m.room.create", "state_key":"", "content":{"creator":"@TestAccumulatorDelta:localhost"}}`),
[]byte(`{"event_id":"aE", "type":"m.room.member", "state_key":"@TestAccumulatorDelta:localhost", "content":{"membership":"join"}}`),
[]byte(`{"event_id":"aF", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"aG", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
[]byte(`{"event_id":"aH", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"aI", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
// Draw the create event, tests limits
events, position, err := accumulator.Delta(roomID, EventsStart, 1)
if err != nil {
t.Fatalf("failed to Delta: %s", err)
}
if len(events) != 1 {
t.Fatalf("failed to get events from Delta, got %d want 1", len(events))
}
if gjson.GetBytes(events[0], "event_id").Str != gjson.GetBytes(roomEvents[0], "event_id").Str {
t.Fatalf("failed to draw first event, got %s want %s", string(events[0]), string(roomEvents[0]))
}
if position == 0 {
t.Errorf("Delta returned zero position")
}
// Draw up to the end
events, position, err = accumulator.Delta(roomID, position, 1000)
if err != nil {
t.Fatalf("failed to Delta: %s", err)
}
if len(events) != len(roomEvents)-1 {
t.Fatalf("failed to get events from Delta, got %d want %d", len(events), len(roomEvents)-1)
}
if position == 0 {
t.Errorf("Delta returned zero position")
}
}
func TestAccumulatorMembershipLogs(t *testing.T) {
roomID := "!TestAccumulatorMembershipLogs:localhost"
db, close := connectToDB(t)
@ -282,7 +232,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
return err
})
if err != nil {
@ -409,7 +359,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
@ -417,86 +367,6 @@ func TestAccumulatorDupeEvents(t *testing.T) {
}
}
// Regression test for corrupt state snapshots.
// This seems to have happened in the wild, whereby the snapshot exhibited 2 things:
// - A message event having an event_replaces_nid. This should be impossible as messages are not state.
// - Duplicate events in the state snapshot.
//
// We can reproduce a message event having a event_replaces_nid by doing the following:
// - Create a room with initial state A,C
// - Accumulate events D, A, B(msg). This should be impossible because we already got A initially but whatever, roll with it, blame state resets or something.
// - This leads to A,B being processed and D ignored if you just take the newest results.
//
// This can then be tested by:
// - Query the current room snapshot. This will include B(msg) when it shouldn't.
func TestAccumulatorMisorderedGraceful(t *testing.T) {
alice := "@alice:localhost"
bob := "@bob:localhost"
eventA := testutils.NewStateEvent(t, "m.room.member", alice, alice, map[string]interface{}{"membership": "join"})
eventC := testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{})
eventD := testutils.NewStateEvent(
t, "m.room.member", bob, "join", map[string]interface{}{"membership": "join"},
)
eventBMsg := testutils.NewEvent(
t, "m.room.message", bob, map[string]interface{}{"body": "hello"},
)
t.Logf("A=member-alice, B=msg, C=create, D=member-bob")
db, close := connectToDB(t)
defer close()
accumulator := NewAccumulator(db)
roomID := "!TestAccumulatorStateReset:localhost"
// Create a room with initial state A,C
_, err := accumulator.Initialise(roomID, []json.RawMessage{
eventA, eventC,
})
if err != nil {
t.Fatalf("failed to Initialise accumulator: %s", err)
}
// Accumulate events D, A, B(msg).
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", []json.RawMessage{eventD, eventA, eventBMsg})
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
eventIDs := []string{
gjson.GetBytes(eventA, "event_id").Str,
gjson.GetBytes(eventBMsg, "event_id").Str,
gjson.GetBytes(eventC, "event_id").Str,
gjson.GetBytes(eventD, "event_id").Str,
}
t.Logf("Events A,B,C,D: %v", eventIDs)
txn := accumulator.db.MustBeginTx(context.Background(), nil)
idsToNIDs, err := accumulator.eventsTable.SelectNIDsByIDs(txn, eventIDs)
if err != nil {
t.Fatalf("Failed to SelectNIDsByIDs: %s", err)
}
if len(idsToNIDs) != len(eventIDs) {
t.Errorf("SelectNIDsByIDs: asked for %v got %v", eventIDs, idsToNIDs)
}
t.Logf("Events: %v", idsToNIDs)
wantEventNIDs := []int64{
idsToNIDs[eventIDs[0]], idsToNIDs[eventIDs[2]], idsToNIDs[eventIDs[3]],
}
sort.Slice(wantEventNIDs, func(i, j int) bool {
return wantEventNIDs[i] < wantEventNIDs[j]
})
// Query the current room snapshot
gotSnapshotEvents := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID)
if len(gotSnapshotEvents) != len(wantEventNIDs) { // events A,C,D
t.Errorf("corrupt snapshot, got %v want %v", gotSnapshotEvents, wantEventNIDs)
}
if !reflect.DeepEqual(wantEventNIDs, gotSnapshotEvents) {
t.Errorf("got %v want %v", gotSnapshotEvents, wantEventNIDs)
}
}
// Regression test for corrupt state snapshots.
// This seems to have happened in the wild, whereby the snapshot exhibited 2 things:
// - A message event having an event_replaces_nid. This should be impossible as messages are not state.
@ -689,7 +559,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
defer wg.Done()
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, _, err := accumulator.Accumulate(txn, roomID, "", subset)
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
totalNumNew += numNew
return err
})

View File

@ -1,10 +1,11 @@
package state
import (
"bytes"
"database/sql"
"encoding/json"
"reflect"
"github.com/fxamacker/cbor/v2"
"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sqlutil"
@ -25,14 +26,15 @@ type DeviceDataTable struct {
func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
db.MustExec(`
CREATE SEQUENCE IF NOT EXISTS syncv3_device_data_seq;
CREATE TABLE IF NOT EXISTS syncv3_device_data (
id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_device_data_seq'),
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
data BYTEA NOT NULL,
UNIQUE(user_id, device_id)
);
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
-- change the data, not anything indexed like the id)
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
`)
return &DeviceDataTable{
db: db,
@ -44,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var row DeviceDataRow
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
if err != nil {
if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error.
@ -53,7 +55,7 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
return err
}
// unmarshal to swap
if err = json.Unmarshal(row.Data, &result); err != nil {
if err = cbor.Unmarshal(row.Data, &result); err != nil {
return err
}
result.UserID = userID
@ -67,18 +69,19 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
writeBack.DeviceLists.New = make(map[string]int)
writeBack.ChangedBits = 0
// re-marshal and write
data, err := json.Marshal(writeBack)
if err != nil {
return err
}
if bytes.Equal(data, row.Data) {
if reflect.DeepEqual(result, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
return nil
}
_, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
// re-marshal and write
data, err := cbor.Marshal(writeBack)
if err != nil {
return err
}
_, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
return err
})
return
@ -90,18 +93,18 @@ func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
}
// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) {
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists
var row DeviceDataRow
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
// unmarshal and combine
var tempDD internal.DeviceData
if len(row.Data) > 0 {
if err = json.Unmarshal(row.Data, &tempDD); err != nil {
if err = cbor.Unmarshal(row.Data, &tempDD); err != nil {
return err
}
}
@ -115,16 +118,19 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)
data, err := json.Marshal(tempDD)
data, err := cbor.Marshal(tempDD)
if err != nil {
return err
}
err = t.db.QueryRow(
_, err = txn.Exec(
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3, id=nextval('syncv3_device_data_seq') RETURNING id`,
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
dd.UserID, dd.DeviceID, data,
).Scan(&pos)
)
return err
})
if err != nil && err != sql.ErrNoRows {
sentry.CaptureException(err)
}
return
}
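
For orientation, a rough usage sketch of the new error-only Upsert together with the swap semantics of Select; the import paths follow the repo layout, but the program itself and the connection string are illustrative:

```go
package main

import (
	"fmt"

	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"

	"github.com/matrix-org/sliding-sync/internal"
	"github.com/matrix-org/sliding-sync/state"
)

func main() {
	// Illustrative connection string; adjust for your environment.
	db := sqlx.MustOpen("postgres", "user=postgres dbname=syncv3 sslmode=disable")
	table := state.NewDeviceDataTable(db)

	// Upsert now only returns an error: the id/position column is gone.
	err := table.Upsert(&internal.DeviceData{
		UserID:   "@alice:localhost",
		DeviceID: "ALICE",
		DeviceLists: internal.DeviceLists{
			New: internal.ToDeviceListChangesMap([]string{"@bob:localhost"}, nil),
		},
	})
	if err != nil {
		panic(err)
	}

	// swap=true moves New into Sent on the stored copy; the first swap still
	// returns the pre-swap data, as exercised by TestDeviceDataTableSwaps.
	dd, err := table.Select("@alice:localhost", "ALICE", true)
	fmt.Println(dd, err)
}
```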

View File

@ -10,7 +10,7 @@ import (
func assertVal(t *testing.T, msg string, got, want interface{}) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("%s: got %v want %v", msg, got, want)
t.Errorf("%s: got\n%#v want\n%#v", msg, got, want)
}
}
@ -21,6 +21,7 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists)
}
func TestDeviceDataTableSwaps(t *testing.T) {
@ -59,12 +60,12 @@ func TestDeviceDataTableSwaps(t *testing.T) {
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"bob"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}
for _, dd := range deltas {
_, err := table.Upsert(&dd)
err := table.Upsert(&dd)
assertNoError(t, err)
}
@ -76,7 +77,8 @@ func TestDeviceDataTableSwaps(t *testing.T) {
},
FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice", "bob"}, nil),
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
Sent: map[string]int{},
},
}
want.SetFallbackKeysChanged()
@ -87,38 +89,43 @@ func TestDeviceDataTableSwaps(t *testing.T) {
assertNoError(t, err)
assertDeviceData(t, *got, want)
}
// now swap-er-roo
// now swap-er-roo, at this point we still expect the "old" data,
// as it is the first time we swap
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
want2 := want
want2.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
New: nil,
}
assertDeviceData(t, *got, want2)
assertDeviceData(t, *got, want)
// changed bits were reset when we swapped
want2 := want
want2.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: map[string]int{},
}
want2.ChangedBits = 0
want.ChangedBits = 0
// this is permanent, read-only views show this too
got, err = table.Select(userID, deviceID, false)
// this is permanent, read-only views show this too.
// Since we have swapped previously, we now expect New to be empty
// and Sent to be set. Swap again to clear Sent.
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, want2)
// another swap causes sent to be cleared out
got, err = table.Select(userID, deviceID, true)
// We now expect empty DeviceLists, as we swapped twice.
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
want3 := want2
want3.DeviceLists = internal.DeviceLists{
Sent: nil,
New: nil,
Sent: map[string]int{},
New: map[string]int{},
}
assertDeviceData(t, *got, want3)
// get back the original state
//err = table.DeleteDevice(userID, deviceID)
assertNoError(t, err)
for _, dd := range deltas {
_, err = table.Upsert(&dd)
err = table.Upsert(&dd)
assertNoError(t, err)
}
want.SetFallbackKeysChanged()
@ -128,13 +135,14 @@ func TestDeviceDataTableSwaps(t *testing.T) {
assertDeviceData(t, *got, want)
// swap once then add once so both sent and new are populated
// Moves alice and 💣 to Sent
_, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
_, err = table.Upsert(&internal.DeviceData{
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"bob"}, []string{"charlie"}),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
},
})
assertNoError(t, err)
@ -143,15 +151,17 @@ func TestDeviceDataTableSwaps(t *testing.T) {
want4 := want
want4.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
New: internal.ToDeviceListChangesMap([]string{"bob"}, []string{"charlie"}),
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
}
// Without swapping, we expect alice and 💣 in Sent, and 💣 and charlie in New
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want4)
// another append then consume
_, err = table.Upsert(&internal.DeviceData{
// This results in dave being added to New
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
@ -163,8 +173,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
assertNoError(t, err)
want5 := want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"bob", "dave"}, []string{"charlie", "dave"}),
New: nil,
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
}
assertDeviceData(t, *got, want5)
// Swapping again clears New
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 = want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
New: map[string]int{},
}
assertDeviceData(t, *got, want5)
@ -190,11 +210,13 @@ func TestDeviceDataTableBitset(t *testing.T) {
"foo": 100,
"bar": 92,
},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
}
fallbakKeyUpdate := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"foo", "bar"},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
}
bothUpdate := internal.DeviceData{
UserID: userID,
@ -203,9 +225,10 @@ func TestDeviceDataTableBitset(t *testing.T) {
OTKCounts: map[string]int{
"both": 100,
},
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
}
_, err := table.Upsert(&otkUpdate)
err := table.Upsert(&otkUpdate)
assertNoError(t, err)
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
@ -217,7 +240,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
otkUpdate.ChangedBits = 0
assertDeviceData(t, *got, otkUpdate)
// now same for fallback keys, but we won't swap them so it should return those diffs
_, err = table.Upsert(&fallbakKeyUpdate)
err = table.Upsert(&fallbakKeyUpdate)
assertNoError(t, err)
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
got, err = table.Select(userID, deviceID, false)
@ -229,7 +252,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
fallbakKeyUpdate.SetFallbackKeysChanged()
assertDeviceData(t, *got, fallbakKeyUpdate)
// updating both works
_, err = table.Upsert(&bothUpdate)
err = table.Upsert(&bothUpdate)
assertNoError(t, err)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)

View File

@ -57,7 +57,7 @@ func (ev *Event) ensureFieldsSetOnEvent() error {
}
if ev.Type == "" {
typeResult := evJSON.Get("type")
if !typeResult.Exists() || typeResult.Str == "" {
if !typeResult.Exists() || typeResult.Type != gjson.String { // empty strings for 'type' are valid apparently
return fmt.Errorf("event JSON missing type key")
}
ev.Type = typeResult.Str
@ -153,7 +153,7 @@ func (t *EventTable) SelectHighestNID() (highest int64, err error) {
// we insert new events A and B in that order, then NID(A) < NID(B).
func (t *EventTable) Insert(txn *sqlx.Tx, events []Event, checkFields bool) (map[string]int64, error) {
if checkFields {
ensureFieldsSet(events)
events = filterAndEnsureFieldsSet(events)
}
result := make(map[string]int64)
for i := range events {
@ -317,10 +317,10 @@ func (t *EventTable) LatestEventInRooms(txn *sqlx.Tx, roomIDs []string, highestN
return
}
func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
func (t *EventTable) LatestEventNIDInRooms(txn *sqlx.Tx, roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
// the position (event nid) may be for a random different room, so we need to find the highest nid <= this position for this room
var events []Event
err = t.db.Select(
err = txn.Select(
&events,
`SELECT event_nid, room_id FROM syncv3_events
WHERE event_nid IN (SELECT max(event_nid) FROM syncv3_events WHERE event_nid <= $1 AND room_id = ANY($2) GROUP BY room_id)`,
@ -336,14 +336,6 @@ func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (
return
}
func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
var events []Event
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`,
lowerExclusive, upperInclusive, roomID, limit,
)
return events, err
}
func (t *EventTable) SelectLatestEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
var events []Event
// do not pull in events which were in the v2 state block
@ -438,8 +430,8 @@ func (t *EventTable) SelectClosestPrevBatchByID(roomID string, eventID string) (
// Select the closest prev batch token for the provided event NID. Returns the empty string if there
// is no closest.
func (t *EventTable) SelectClosestPrevBatch(roomID string, eventNID int64) (prevBatch string, err error) {
err = t.db.QueryRow(
func (t *EventTable) SelectClosestPrevBatch(txn *sqlx.Tx, roomID string, eventNID int64) (prevBatch string, err error) {
err = txn.QueryRow(
`SELECT prev_batch FROM syncv3_events WHERE prev_batch IS NOT NULL AND room_id=$1 AND event_nid >= $2 LIMIT 1`, roomID, eventNID,
).Scan(&prevBatch)
if err == sql.ErrNoRows {
@ -457,14 +449,18 @@ func (c EventChunker) Subslice(i, j int) sqlutil.Chunker {
return c[i:j]
}
func ensureFieldsSet(events []Event) error {
func filterAndEnsureFieldsSet(events []Event) []Event {
result := make([]Event, 0, len(events))
// ensure fields are set
for i := range events {
ev := events[i]
ev := &events[i]
if err := ev.ensureFieldsSetOnEvent(); err != nil {
return err
logger.Warn().Str("event_id", ev.ID).Err(err).Msg(
"filterAndEnsureFieldsSet: failed to parse event, ignoring",
)
continue
}
events[i] = ev
result = append(result, *ev)
}
return nil
return result
}

View File

@ -297,125 +297,6 @@ func TestEventTableDupeInsert(t *testing.T) {
}
}
func TestEventTableSelectEventsBetween(t *testing.T) {
db, close := connectToDB(t)
defer close()
txn, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
table := NewEventTable(db)
searchRoomID := "!0TestEventTableSelectEventsBetween:localhost"
eventIDs := []string{
"100TestEventTableSelectEventsBetween",
"101TestEventTableSelectEventsBetween",
"102TestEventTableSelectEventsBetween",
"103TestEventTableSelectEventsBetween",
"104TestEventTableSelectEventsBetween",
}
events := []Event{
{
JSON: []byte(`{"event_id":"` + eventIDs[0] + `","type": "T1", "state_key":"S1", "room_id":"` + searchRoomID + `"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[1] + `","type": "T2", "state_key":"S2", "room_id":"` + searchRoomID + `"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[2] + `","type": "T3", "state_key":"", "room_id":"` + searchRoomID + `"}`),
},
{
// different room
JSON: []byte(`{"event_id":"` + eventIDs[3] + `","type": "T4", "state_key":"", "room_id":"!1TestEventTableSelectEventsBetween:localhost"}`),
},
{
JSON: []byte(`{"event_id":"` + eventIDs[4] + `","type": "T5", "state_key":"", "room_id":"` + searchRoomID + `"}`),
},
}
idToNID, err := table.Insert(txn, events, true)
if err != nil {
t.Fatalf("Insert failed: %s", err)
}
if len(idToNID) != len(events) {
t.Fatalf("failed to insert events: got %d want %d", len(idToNID), len(events))
}
txn.Commit()
t.Run("subgroup", func(t *testing.T) {
t.Run("selecting multiple events known lower bound", func(t *testing.T) {
t.Parallel()
txn2, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn2.Rollback()
events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]})
if err != nil || len(events) == 0 {
t.Fatalf("failed to extract event for lower bound: %s", err)
}
events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// 3 as 1 is from a different room
if len(events) != 3 {
t.Fatalf("wanted 3 events, got %d", len(events))
}
})
t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) {
t.Parallel()
txn3, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn3.Rollback()
events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]})
if err != nil || len(events) == 0 {
t.Fatalf("failed to extract event for lower/upper bound: %s", err)
}
events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// eventIDs[1] and eventIDs[2]
if len(events) != 2 {
t.Fatalf("wanted 2 events, got %d", len(events))
}
})
t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) {
t.Parallel()
txn4, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn4.Rollback()
gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
// one less as one event is for a different room
if len(gotEvents) != (len(events) - 1) {
t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents))
}
})
t.Run("selecting multiple events hitting the limit", func(t *testing.T) {
t.Parallel()
txn5, err := db.Beginx()
if err != nil {
t.Fatalf("failed to start txn: %s", err)
}
defer txn5.Rollback()
limit := 2
gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit)
if err != nil {
t.Fatalf("failed to SelectEventsBetween: %s", err)
}
if len(gotEvents) != limit {
t.Fatalf("wanted %d events, got %d", limit, len(gotEvents))
}
})
})
}
func TestEventTableMembershipDetection(t *testing.T) {
db, close := connectToDB(t)
defer close()
@ -778,10 +659,14 @@ func TestEventTablePrevBatch(t *testing.T) {
}
assertPrevBatch := func(roomID string, index int, wantPrevBatch string) {
gotPrevBatch, err := table.SelectClosestPrevBatch(roomID, int64(idToNID[events[index].ID]))
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
var gotPrevBatch string
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
gotPrevBatch, err = table.SelectClosestPrevBatch(txn, roomID, int64(idToNID[events[index].ID]))
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
return nil
})
if wantPrevBatch != "" {
if gotPrevBatch == "" || gotPrevBatch != wantPrevBatch {
t.Fatalf("SelectClosestPrevBatch: got %v want %v", gotPrevBatch, wantPrevBatch)
@ -947,7 +832,11 @@ func TestLatestEventNIDInRooms(t *testing.T) {
},
}
for _, tc := range testCases {
gotRoomToNID, err := table.LatestEventNIDInRooms(tc.roomIDs, int64(tc.highestNID))
var gotRoomToNID map[string]int64
err = sqlutil.WithTransaction(table.db, func(txn *sqlx.Tx) error {
gotRoomToNID, err = table.LatestEventNIDInRooms(txn, tc.roomIDs, int64(tc.highestNID))
return err
})
assertNoError(t, err)
want := make(map[string]int64) // map event IDs to nids
for roomID, eventID := range tc.wantMap {

View File

@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS id;
DROP SEQUENCE IF EXISTS syncv3_device_data_seq;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
CREATE SEQUENCE IF NOT EXISTS syncv3_device_data_seq;
ALTER TABLE syncv3_device_data ADD COLUMN IF NOT EXISTS id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_device_data_seq') ;
-- +goose StatementEnd

View File

@ -0,0 +1,97 @@
package migrations
import (
"context"
"database/sql"
"errors"
"strings"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(upJSONB, downJSONB)
}
func upJSONB(ctx context.Context, tx *sql.Tx) error {
// check if we even need to do anything
var dataType string
err := tx.QueryRow("select data_type from information_schema.columns where table_name = 'syncv3_device_data' AND column_name = 'data'").Scan(&dataType)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// The table/column doesn't exist and is likely going to be created soon with the
// correct schema
return nil
}
return err
}
if strings.ToLower(dataType) == "jsonb" {
return nil
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS dataj JSONB;")
if err != nil {
return err
}
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
if err != nil {
return err
}
defer rows.Close()
// abusing PollerID here
var deviceData sync2.PollerID
var data []byte
// map from PollerID -> deviceData
deviceDatas := make(map[sync2.PollerID][]byte)
for rows.Next() {
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
return err
}
deviceDatas[deviceData] = data
}
for dd, d := range deviceDatas {
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET dataj = $1 WHERE user_id = $2 AND device_id = $3;", d, dd.UserID, dd.DeviceID)
if err != nil {
return err
}
}
if rows.Err() != nil {
return rows.Err()
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN dataj TO data;")
if err != nil {
return err
}
return nil
}
func downJSONB(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS datab BYTEA;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET datab = (data::TEXT)::BYTEA;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN datab TO data;")
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,80 @@
package migrations
import (
"context"
"encoding/json"
"testing"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/testutils"
)
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
postgresConnectionString = testutils.PrepareDBConnectionString()
db, err := sqlx.Open("postgres", postgresConnectionString)
if err != nil {
t.Fatalf("failed to open SQL db: %s", err)
}
return db, func() {
db.Close()
}
}
func TestJSONBMigration(t *testing.T) {
ctx := context.Background()
db, close := connectToDB(t)
defer close()
// Create the table in the old format (data = BYTEA instead of JSONB)
_, err := db.Exec(`CREATE TABLE syncv3_device_data (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
data BYTEA NOT NULL,
UNIQUE(user_id, device_id)
);`)
if err != nil {
t.Fatal(err)
}
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
// insert some "invalid" data
dd := internal.DeviceData{
DeviceLists: internal.DeviceLists{
New: map[string]int{"@💣:localhost": 1},
Sent: map[string]int{},
},
OTKCounts: map[string]int{},
FallbackKeyTypes: []string{},
}
data, err := json.Marshal(dd)
if err != nil {
t.Fatal(err)
}
_, err = tx.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, "bob", "bobDev", data)
if err != nil {
t.Fatal(err)
}
// validate that invalid data can be migrated upwards
err = upJSONB(ctx, tx)
if err != nil {
t.Fatal(err)
}
// and downgrade again
err = downJSONB(ctx, tx)
if err != nil {
t.Fatal(err)
}
}

View File

@ -0,0 +1,142 @@
package migrations
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/fxamacker/cbor/v2"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(upCborDeviceData, downCborDeviceData)
}
func upCborDeviceData(ctx context.Context, tx *sql.Tx) error {
// check if we even need to do anything
var dataType string
err := tx.QueryRow("select data_type from information_schema.columns where table_name = 'syncv3_device_data' AND column_name = 'data'").Scan(&dataType)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// The table/column doesn't exist and is likely going to be created soon with the
// correct schema
return nil
}
return err
}
if strings.ToLower(dataType) == "bytea" {
return nil
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS datab BYTEA;")
if err != nil {
return err
}
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
if err != nil {
return err
}
defer rows.Close()
// abusing PollerID here
var deviceData sync2.PollerID
var data []byte
// map from PollerID -> deviceData
deviceDatas := make(map[sync2.PollerID][]byte)
for rows.Next() {
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
return err
}
deviceDatas[deviceData] = data
}
for dd, jsonBytes := range deviceDatas {
var data internal.DeviceData
if err := json.Unmarshal(jsonBytes, &data); err != nil {
return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err)
}
cborBytes, err := cbor.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal as CBOR: %v", err)
}
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET datab = $1 WHERE user_id = $2 AND device_id = $3;", cborBytes, dd.UserID, dd.DeviceID)
if err != nil {
return err
}
}
if rows.Err() != nil {
return rows.Err()
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN datab TO data;")
if err != nil {
return err
}
return nil
}
func downCborDeviceData(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS dataj JSONB;")
if err != nil {
return err
}
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
if err != nil {
return err
}
defer rows.Close()
// abusing PollerID here
var deviceData sync2.PollerID
var data []byte
// map from PollerID -> deviceData
deviceDatas := make(map[sync2.PollerID][]byte)
for rows.Next() {
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
return err
}
deviceDatas[deviceData] = data
}
for dd, cborBytes := range deviceDatas {
var data internal.DeviceData
if err := cbor.Unmarshal(cborBytes, &data); err != nil {
return fmt.Errorf("failed to unmarshal CBOR: %v", err)
}
jsonBytes, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal as JSON: %v", err)
}
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET dataj = $1 WHERE user_id = $2 AND device_id = $3;", jsonBytes, dd.UserID, dd.DeviceID)
if err != nil {
return err
}
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
if err != nil {
return err
}
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN dataj TO data;")
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,121 @@
package migrations
import (
"context"
"encoding/json"
"reflect"
"testing"
_ "github.com/lib/pq"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/state"
)
func TestCBORBMigration(t *testing.T) {
ctx := context.Background()
db, close := connectToDB(t)
defer close()
// Create the table in the old format (data = JSONB instead of BYTEA)
// and insert some data: we'll make sure that this data is preserved
// after migrating.
_, err := db.Exec(`CREATE TABLE syncv3_device_data (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
data JSONB NOT NULL,
UNIQUE(user_id, device_id)
);`)
if err != nil {
t.Fatal(err)
}
rowData := []internal.DeviceData{
{
DeviceLists: internal.DeviceLists{
New: map[string]int{"@bob:localhost": 2},
Sent: map[string]int{},
},
ChangedBits: 2,
OTKCounts: map[string]int{"bar": 42},
FallbackKeyTypes: []string{"narp"},
DeviceID: "ALICE",
UserID: "@alice:localhost",
},
{
DeviceLists: internal.DeviceLists{
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
Sent: map[string]int{"@sent:localhost": 1},
},
OTKCounts: map[string]int{"foo": 100},
FallbackKeyTypes: []string{"yep"},
DeviceID: "BOB",
UserID: "@bob:localhost",
},
}
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
for _, dd := range rowData {
data, err := json.Marshal(dd)
if err != nil {
t.Fatal(err)
}
_, err = tx.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, dd.UserID, dd.DeviceID, data)
if err != nil {
t.Fatal(err)
}
}
// validate that the data can be migrated upwards
err = upCborDeviceData(ctx, tx)
if err != nil {
t.Fatal(err)
}
tx.Commit()
// ensure we can now select it
table := state.NewDeviceDataTable(db)
for _, want := range rowData {
got, err := table.Select(want.UserID, want.DeviceID, false)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*got, want) {
t.Fatalf("got %+v\nwant %+v", *got, want)
}
}
// and downgrade again
tx, err = db.Begin()
if err != nil {
t.Fatal(err)
}
err = downCborDeviceData(ctx, tx)
if err != nil {
t.Fatal(err)
}
// ensure it is what we originally inserted
for _, want := range rowData {
var got internal.DeviceData
var gotBytes []byte
err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes)
if err != nil {
t.Fatal(err)
}
if err = json.Unmarshal(gotBytes, &got); err != nil {
t.Fatal(err)
}
got.DeviceID = want.DeviceID
got.UserID = want.UserID
if !reflect.DeepEqual(got, want) {
t.Fatalf("got %+v\nwant %+v", got, want)
}
}
tx.Commit()
}

View File

@ -0,0 +1,56 @@
# Database migrations in the Sliding Sync Proxy
Database migrations use https://github.com/pressly/goose, with an integrated `migrate` command in the `syncv3` binary.
All commands below require the `SYNCV3_DB` environment variable to determine which database to use:
```bash
$ export SYNCV3_DB="user=postgres dbname=syncv3 sslmode=disable password=yourpassword"
```
## Upgrading
It is sufficient to run the proxy itself; upgrading is done automatically. If you still need to upgrade manually, you can
use one of the following commands:
```bash
# Check which versions have been applied
$ ./syncv3 migrate status
# Execute all existing migrations
$ ./syncv3 migrate up
# Upgrade to 20230802121023 (which is 20230802121023_device_data_jsonb.go - migrating from bytea to jsonb)
$ ./syncv3 migrate up-to 20230802121023
# Upgrade by one
$ ./syncv3 migrate up-by-one
```
## Downgrading
If you wish to downgrade, execute one of the following commands. If you downgrade, make sure you also start
the required older version of the proxy, as otherwise the schema will automatically be upgraded again:
```bash
# Check which versions have been applied
$ ./syncv3 migrate status
# Undo the latest migration
$ ./syncv3 migrate down
# Downgrade to 20230728114555 (which is 20230728114555_device_data_drop_id.sql - dropping the id column)
$ ./syncv3 migrate down-to 20230728114555
```
## Creating new migrations
Migrations can either be created as plain SQL or as Go functions.
```bash
# Create a new SQL migration with the name "mymigration"
$ ./syncv3 migrate create mymigration sql
# Same as above, but as Go functions
$ ./syncv3 migrate create mymigration go
```
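
For reference, a Go migration created this way is registered through goose's context API and follows the same shape as the device-data migrations above; the function names and the example column are placeholders:

```go
package migrations

import (
	"context"
	"database/sql"

	"github.com/pressly/goose/v3"
)

func init() {
	// Register the up/down functions; the surrounding file name
	// (e.g. 20230814120000_mymigration.go) determines the version.
	goose.AddMigrationContext(upMyMigration, downMyMigration)
}

func upMyMigration(ctx context.Context, tx *sql.Tx) error {
	// Placeholder DDL to illustrate the shape of a migration.
	_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS example TEXT;")
	return err
}

func downMyMigration(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS example;")
	return err
}
```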

View File

@ -6,7 +6,6 @@ import (
"fmt"
"os"
"strings"
"time"
"github.com/getsentry/sentry-go"
@ -38,6 +37,22 @@ type LatestEvents struct {
LatestNID int64
}
// DiscardIgnoredMessages modifies the struct in-place, replacing the Timeline with
// a copy that has all ignored events omitted. The order of timeline events is preserved.
func (e *LatestEvents) DiscardIgnoredMessages(shouldIgnore func(sender string) bool) {
// A little bit sad to be effectively doing a copy here---most of the time there
// won't be any messages to ignore (and the timeline is likely short). But that copy
// is unlikely to be a bottleneck.
newTimeline := make([]json.RawMessage, 0, len(e.Timeline))
for _, ev := range e.Timeline {
parsed := gjson.ParseBytes(ev)
if parsed.Get("state_key").Exists() || !shouldIgnore(parsed.Get("sender").Str) {
newTimeline = append(newTimeline, ev)
}
}
e.Timeline = newTimeline
}
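
As a hedged usage sketch (the wrapper and the ignored-user set are illustrative; only DiscardIgnoredMessages itself comes from the code), a caller with a set of ignored user IDs might drive this as follows:

```go
// dropIgnored is a hypothetical helper in the same package. The ignored set is
// assumed to be derived from the user's m.ignored_user_list account data.
func dropIgnored(le *LatestEvents, ignored map[string]struct{}) {
	le.DiscardIgnoredMessages(func(sender string) bool {
		_, ok := ignored[sender]
		return ok
	})
}
```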
type Storage struct {
Accumulator *Accumulator
EventsTable *EventTable
@ -58,9 +73,10 @@ func NewStorage(postgresURI string) *Storage {
// TODO: if we panic(), will sentry have a chance to flush the event?
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
db.SetMaxOpenConns(100)
db.SetMaxIdleConns(80)
db.SetConnMaxLifetime(time.Hour)
return NewStorageWithDB(db)
}
func NewStorageWithDB(db *sqlx.DB) *Storage {
acc := &Accumulator{
db: db,
roomsTable: NewRoomsTable(db),
@ -178,26 +194,12 @@ func (s *Storage) GlobalSnapshot() (ss StartupSnapshot, err error) {
// Extract hero info for all rooms. Requires a prepared snapshot in order to be called.
func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result map[string]internal.RoomMetadata) error {
// Select the invited member counts
rows, err := txn.Query(`
SELECT room_id, count(state_key) FROM syncv3_events INNER JOIN ` + tempTableName + ` ON membership_nid=event_nid
WHERE (membership='_invite' OR membership = 'invite') AND event_type='m.room.member' GROUP BY room_id`)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var roomID string
var inviteCount int
if err := rows.Scan(&roomID, &inviteCount); err != nil {
return err
}
loadMetadata := func(roomID string) internal.RoomMetadata {
metadata, ok := result[roomID]
if !ok {
metadata = *internal.NewRoomMetadata(roomID)
}
metadata.InviteCount = inviteCount
result[roomID] = metadata
return metadata
}
// work out latest timestamps
@ -206,10 +208,8 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
return err
}
for _, ev := range events {
metadata, ok := result[ev.RoomID]
if !ok {
metadata = *internal.NewRoomMetadata(ev.RoomID)
}
metadata := loadMetadata(ev.RoomID)
// For a given room, we'll see many events (one for each event type in the
// room's state). We need to pick the largest of these events' timestamps here.
ts := gjson.ParseBytes(ev.JSON).Get("origin_server_ts").Uint()
@ -232,68 +232,32 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
// Select the name / canonical alias for all rooms
roomIDToStateEvents, err := s.currentNotMembershipStateEventsInAllRooms(txn, []string{
"m.room.name", "m.room.canonical_alias",
"m.room.name", "m.room.canonical_alias", "m.room.avatar",
})
if err != nil {
return fmt.Errorf("failed to load state events for all rooms: %s", err)
}
for roomID, stateEvents := range roomIDToStateEvents {
metadata := result[roomID]
metadata := loadMetadata(roomID)
for _, ev := range stateEvents {
if ev.Type == "m.room.name" && ev.StateKey == "" {
metadata.NameEvent = gjson.ParseBytes(ev.JSON).Get("content.name").Str
} else if ev.Type == "m.room.canonical_alias" && ev.StateKey == "" {
metadata.CanonicalAlias = gjson.ParseBytes(ev.JSON).Get("content.alias").Str
} else if ev.Type == "m.room.avatar" && ev.StateKey == "" {
metadata.AvatarEvent = gjson.ParseBytes(ev.JSON).Get("content.url").Str
}
}
result[roomID] = metadata
}
// Select the most recent members for each room to serve as Heroes. The spec is ambiguous here:
// "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited."
// Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent
// ones, and select 6 of them so we can always use 5 no matter who is requesting the room name.
rows, err = txn.Query(`
SELECT rf.* FROM (
SELECT room_id, event, rank() OVER (
PARTITION BY room_id ORDER BY event_nid DESC
) FROM syncv3_events INNER JOIN ` + tempTableName + ` ON membership_nid=event_nid WHERE (
membership='join' OR membership='invite' OR membership='_join'
) AND event_type='m.room.member'
) rf WHERE rank <= 6;`)
if err != nil {
return fmt.Errorf("failed to query heroes: %s", err)
}
defer rows.Close()
seen := map[string]bool{}
for rows.Next() {
var roomID string
var event json.RawMessage
var rank int
if err := rows.Scan(&roomID, &event, &rank); err != nil {
return err
}
ev := gjson.ParseBytes(event)
targetUser := ev.Get("state_key").Str
key := roomID + " " + targetUser
if seen[key] {
continue
}
seen[key] = true
metadata := result[roomID]
metadata.Heroes = append(metadata.Heroes, internal.Hero{
ID: targetUser,
Name: ev.Get("content.displayname").Str,
})
result[roomID] = metadata
}
roomInfos, err := s.Accumulator.roomsTable.SelectRoomInfos(txn)
if err != nil {
return fmt.Errorf("failed to select room infos: %s", err)
}
var spaceRoomIDs []string
for _, info := range roomInfos {
metadata := result[info.ID]
metadata := loadMetadata(info.ID)
metadata.Encrypted = info.IsEncrypted
metadata.UpgradedRoomID = info.UpgradedRoomID
metadata.PredecessorRoomID = info.PredecessorRoomID
@ -310,7 +274,13 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
return fmt.Errorf("failed to select space children: %s", err)
}
for roomID, relations := range spaceRoomToRelations {
metadata := result[roomID]
if _, exists := result[roomID]; !exists {
// this can happen when you join a space (which populates the spaces table) and then leave it:
// there are no joined members left in the space, so result doesn't include the room. In this case,
// we don't want a stub metadata entry with just the space children, so skip it.
continue
}
metadata := loadMetadata(roomID)
metadata.ChildSpaceRooms = make(map[string]struct{}, len(relations))
for _, r := range relations {
// For now we only honour child state events, but we store all the mappings just in case.
@ -353,12 +323,12 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
return result, nil
}
func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, roomID, prevBatch, timeline)
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
return err
})
return
@ -545,7 +515,7 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str
if err != nil {
return fmt.Errorf("failed to form sql query: %s", err)
}
rows, err := s.Accumulator.db.Query(s.Accumulator.db.Rebind(query), args...)
rows, err := txn.Query(txn.Rebind(query), args...)
if err != nil {
return fmt.Errorf("failed to execute query: %s", err)
}
@ -630,7 +600,7 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
}
if earliestEventNID != 0 {
// the oldest event needs a prev batch token, so find one now
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(roomID, earliestEventNID)
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(txn, roomID, earliestEventNID)
if err != nil {
return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err)
}
@ -810,32 +780,121 @@ func (s *Storage) RoomMembershipDelta(roomID string, from, to int64, limit int)
}
// Extract all rooms with joined members, and include the joined user list. Requires a prepared snapshot in order to be called.
func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (result map[string][]string, metadata map[string]internal.RoomMetadata, err error) {
// Populates the join/invite count and heroes for the returned metadata.
func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMembers map[string][]string, metadata map[string]internal.RoomMetadata, err error) {
// Select the most recent members for each room to serve as Heroes. The spec is ambiguous here:
// "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited."
// Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent
// ones, and select 6 of them so we can always use 5 no matter who is requesting the room name.
rows, err := txn.Query(
`SELECT room_id, state_key from ` + tempTableName + ` INNER JOIN syncv3_events on membership_nid = event_nid WHERE membership='join' OR membership='_join' ORDER BY event_nid ASC`,
`SELECT membership_nid, room_id, state_key, membership from ` + tempTableName + ` INNER JOIN syncv3_events
on membership_nid = event_nid WHERE membership='join' OR membership='_join' OR membership='invite' OR membership='_invite' ORDER BY event_nid ASC`,
)
if err != nil {
return nil, nil, err
}
defer rows.Close()
result = make(map[string][]string)
joinedMembers = make(map[string][]string)
inviteCounts := make(map[string]int)
heroNIDs := make(map[string]*circularSlice)
var stateKey string
var membership string
var roomID string
var joinedUserID string
var nid int64
for rows.Next() {
if err := rows.Scan(&roomID, &joinedUserID); err != nil {
if err := rows.Scan(&nid, &roomID, &stateKey, &membership); err != nil {
return nil, nil, err
}
users := result[roomID]
users = append(users, joinedUserID)
result[roomID] = users
heroes := heroNIDs[roomID]
if heroes == nil {
heroes = &circularSlice{max: 6}
heroNIDs[roomID] = heroes
}
switch membership {
case "join":
fallthrough
case "_join":
users := joinedMembers[roomID]
users = append(users, stateKey)
joinedMembers[roomID] = users
heroes.append(nid)
case "invite":
fallthrough
case "_invite":
inviteCounts[roomID] = inviteCounts[roomID] + 1
heroes.append(nid)
}
}
// now select the membership events for the heroes
var allHeroNIDs []int64
for _, nids := range heroNIDs {
allHeroNIDs = append(allHeroNIDs, nids.vals...)
}
heroEvents, err := s.EventsTable.SelectByNIDs(txn, true, allHeroNIDs)
if err != nil {
return nil, nil, err
}
heroes := make(map[string][]internal.Hero)
// loop backwards so the most recent hero is first in the hero list
for i := len(heroEvents) - 1; i >= 0; i-- {
ev := heroEvents[i]
evJSON := gjson.ParseBytes(ev.JSON)
roomHeroes := heroes[ev.RoomID]
roomHeroes = append(roomHeroes, internal.Hero{
ID: ev.StateKey,
Name: evJSON.Get("content.displayname").Str,
Avatar: evJSON.Get("content.avatar_url").Str,
})
heroes[ev.RoomID] = roomHeroes
}
metadata = make(map[string]internal.RoomMetadata)
for roomID, joinedMembers := range result {
for roomID, members := range joinedMembers {
m := internal.NewRoomMetadata(roomID)
m.JoinCount = len(joinedMembers)
m.JoinCount = len(members)
m.InviteCount = inviteCounts[roomID]
m.Heroes = heroes[roomID]
metadata[roomID] = *m
}
return result, metadata, nil
return joinedMembers, metadata, nil
}
func (s *Storage) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
roomToNID = make(map[string]int64)
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
// Pull out the latest nids for all the rooms. If they are < highestNID then use them, else we need to query the
// events table (slow) for the latest nid in this room which is < highestNID.
fastRoomToLatestNIDs, err := s.Accumulator.roomsTable.LatestNIDs(txn, roomIDs)
if err != nil {
return err
}
var slowRooms []string
for _, roomID := range roomIDs {
nid := fastRoomToLatestNIDs[roomID]
if nid > 0 && nid <= highestNID {
roomToNID[roomID] = nid
} else {
// we need to do a slow query for this
slowRooms = append(slowRooms, roomID)
}
}
if len(slowRooms) == 0 {
return nil // no work to do
}
logger.Warn().Int("slow_rooms", len(slowRooms)).Msg("LatestEventNIDInRooms: pos value provided is far behind the database copy, performance degraded")
slowRoomToLatestNIDs, err := s.EventsTable.LatestEventNIDInRooms(txn, slowRooms, highestNID)
if err != nil {
return err
}
for roomID, nid := range slowRoomToLatestNIDs {
roomToNID[roomID] = nid
}
return nil
})
return roomToNID, err
}
// Returns a map from joined room IDs to EventMetadata, which is nil iff a non-nil error
@ -895,3 +954,27 @@ func (s *Storage) Teardown() {
panic("Storage.Teardown: " + err.Error())
}
}
// circularSlice is a slice which can be appended to and which wraps around once it reaches `max` entries.
// Mostly useful for lazily calculating heroes. The values returned aren't sorted.
type circularSlice struct {
i int
vals []int64
max int
}
func (s *circularSlice) append(val int64) {
if len(s.vals) < s.max {
// populate up to max
s.vals = append(s.vals, val)
s.i++
return
}
// wraparound
if s.i == s.max {
s.i = 0
}
// replace this entry
s.vals[s.i] = val
s.i++
}
View File
@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"
@ -30,7 +31,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate returned error: %s", err)
}
@ -160,7 +161,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
var latestNIDs []int64
var err error
for roomID, eventMap := range roomIDToEventMap {
_, latestNIDs, err = store.Accumulate(roomID, "", eventMap)
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
}
@ -210,18 +211,19 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
}
}
newMetadata := func(roomID string, joinCount int) internal.RoomMetadata {
newMetadata := func(roomID string, joinCount, inviteCount int) internal.RoomMetadata {
m := internal.NewRoomMetadata(roomID)
m.JoinCount = joinCount
m.InviteCount = inviteCount
return *m
}
// also test MetadataForAllRooms
roomIDToMetadata := map[string]internal.RoomMetadata{
joinedRoomID: newMetadata(joinedRoomID, 1),
invitedRoomID: newMetadata(invitedRoomID, 1),
banRoomID: newMetadata(banRoomID, 1),
bobJoinedRoomID: newMetadata(bobJoinedRoomID, 2),
joinedRoomID: newMetadata(joinedRoomID, 1, 0),
invitedRoomID: newMetadata(invitedRoomID, 1, 1),
banRoomID: newMetadata(banRoomID, 1, 0),
bobJoinedRoomID: newMetadata(bobJoinedRoomID, 2, 0),
}
tempTableName, err := store.PrepareSnapshot(txn)
@ -349,7 +351,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
},
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -452,7 +454,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
t.Fatalf("LatestEventNID: %s", err)
}
for _, tl := range timelineInjections {
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
if err != nil {
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
}
@ -532,7 +534,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
}
eventIDs := []string{}
for _, timeline := range timelines {
_, _, err = store.Accumulate(roomID, timeline.prevBatch, timeline.timeline)
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
if err != nil {
t.Fatalf("failed to accumulate: %s", err)
}
@ -566,10 +568,15 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
wantPrevBatch := wantPrevBatches[i]
eventNID := idsToNIDs[eventIDs[i]]
// closest batch to the last event in the chunk (latest nid) is always the next prev batch token
pb, err := store.EventsTable.SelectClosestPrevBatch(roomID, eventNID)
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
var pb string
_ = sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) {
pb, err = store.EventsTable.SelectClosestPrevBatch(txn, roomID, eventNID)
if err != nil {
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
}
return nil
})
if pb != wantPrevBatch {
t.Fatalf("SelectClosestPrevBatch: got %v want %v", pb, wantPrevBatch)
}
@ -682,6 +689,189 @@ func TestGlobalSnapshot(t *testing.T) {
}
}
func TestAllJoinedMembers(t *testing.T) {
assertNoError(t, cleanDB(t))
store := NewStorage(postgresConnectionString)
defer store.Teardown()
alice := "@alice:localhost"
bob := "@bob:localhost"
charlie := "@charlie:localhost"
doris := "@doris:localhost"
eve := "@eve:localhost"
frank := "@frank:localhost"
// Alice is always the creator and the inviter for simplicity's sake
testCases := []struct {
Name string
InitMemberships [][2]string
AccumulateMemberships [][2]string
RoomID string // tests set this dynamically
WantJoined []string
WantInvited []string
}{
{
Name: "basic joined users",
InitMemberships: [][2]string{{alice, "join"}},
AccumulateMemberships: [][2]string{{bob, "join"}},
WantJoined: []string{alice, bob},
},
{
Name: "basic invited users",
InitMemberships: [][2]string{{alice, "join"}, {charlie, "invite"}},
AccumulateMemberships: [][2]string{{bob, "invite"}},
WantJoined: []string{alice},
WantInvited: []string{bob, charlie},
},
{
Name: "many join/leaves, use latest",
InitMemberships: [][2]string{{alice, "join"}, {charlie, "join"}, {frank, "join"}},
AccumulateMemberships: [][2]string{{bob, "join"}, {charlie, "leave"}, {frank, "leave"}, {charlie, "join"}, {eve, "join"}},
WantJoined: []string{alice, bob, charlie, eve},
},
{
Name: "many invites, use latest",
InitMemberships: [][2]string{{alice, "join"}, {doris, "join"}},
AccumulateMemberships: [][2]string{{doris, "leave"}, {charlie, "invite"}, {doris, "invite"}},
WantJoined: []string{alice},
WantInvited: []string{charlie, doris},
},
{
Name: "invite and rejection in accumulate",
InitMemberships: [][2]string{{alice, "join"}},
AccumulateMemberships: [][2]string{{frank, "invite"}, {frank, "leave"}},
WantJoined: []string{alice},
},
{
Name: "invite in initial, rejection in accumulate",
InitMemberships: [][2]string{{alice, "join"}, {frank, "invite"}},
AccumulateMemberships: [][2]string{{frank, "leave"}},
WantJoined: []string{alice},
},
}
serialise := func(memberships [][2]string) []json.RawMessage {
var result []json.RawMessage
for _, userWithMembership := range memberships {
target := userWithMembership[0]
sender := userWithMembership[0]
membership := userWithMembership[1]
if membership == "invite" {
// Alice is always the inviter
sender = alice
}
result = append(result, testutils.NewStateEvent(t, "m.room.member", target, sender, map[string]interface{}{
"membership": membership,
}))
}
return result
}
for i, tc := range testCases {
roomID := fmt.Sprintf("!TestAllJoinedMembers_%d:localhost", i)
_, err := store.Initialise(roomID, append([]json.RawMessage{
testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{
"creator": alice, // alice is always the creator
}),
}, serialise(tc.InitMemberships)...))
assertNoError(t, err)
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
assertNoError(t, err)
testCases[i].RoomID = roomID // remember this for later
}
// should get all joined members correctly
var joinedMembers map[string][]string
// should set join/invite counts correctly
var roomMetadatas map[string]internal.RoomMetadata
err := sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) error {
tableName, err := store.PrepareSnapshot(txn)
if err != nil {
return err
}
joinedMembers, roomMetadatas, err = store.AllJoinedMembers(txn, tableName)
return err
})
assertNoError(t, err)
for _, tc := range testCases {
roomID := tc.RoomID
if roomID == "" {
t.Fatalf("test case has no room id set: %+v", tc)
}
// make sure joined members match
sort.Strings(joinedMembers[roomID])
sort.Strings(tc.WantJoined)
if !reflect.DeepEqual(joinedMembers[roomID], tc.WantJoined) {
t.Errorf("%v: got joined members %v want %v", tc.Name, joinedMembers[roomID], tc.WantJoined)
}
// make sure join/invite counts match
wantJoined := len(tc.WantJoined)
wantInvited := len(tc.WantInvited)
metadata, ok := roomMetadatas[roomID]
if !ok {
t.Fatalf("no room metadata for room %v", roomID)
}
if metadata.InviteCount != wantInvited {
t.Errorf("%v: got invite count %d want %d", tc.Name, metadata.InviteCount, wantInvited)
}
if metadata.JoinCount != wantJoined {
t.Errorf("%v: got join count %d want %d", tc.Name, metadata.JoinCount, wantJoined)
}
}
}
func TestCircularSlice(t *testing.T) {
testCases := []struct {
name string
max int
appends []int64
want []int64 // these get sorted in the test
}{
{
name: "wraparound",
max: 5,
appends: []int64{9, 8, 7, 6, 5, 4, 3, 2},
want: []int64{2, 3, 4, 5, 6},
},
{
name: "exact",
max: 5,
appends: []int64{9, 8, 7, 6, 5},
want: []int64{5, 6, 7, 8, 9},
},
{
name: "unfilled",
max: 5,
appends: []int64{9, 8, 7},
want: []int64{7, 8, 9},
},
{
name: "wraparound x2",
max: 5,
appends: []int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10},
want: []int64{0, 1, 2, 3, 10},
},
}
for _, tc := range testCases {
cs := &circularSlice{
max: tc.max,
}
for _, val := range tc.appends {
cs.append(val)
}
sort.Slice(cs.vals, func(i, j int) bool {
return cs.vals[i] < cs.vals[j]
})
if !reflect.DeepEqual(cs.vals, tc.want) {
t.Errorf("%s: got %v want %v", tc.name, cs.vals, tc.want)
}
}
}
func cleanDB(t *testing.T) error {
// make a fresh DB which is unpolluted from other tests
db, close := connectToDB(t)
View File
@ -7,6 +7,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
@ -69,7 +70,15 @@ func (v *HTTPClient) DoSyncV2(ctx context.Context, accessToken, since string, is
if err != nil {
return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
}
res, err := v.Client.Do(req)
var res *http.Response
if isFirst {
longTimeoutClient := &http.Client{
Timeout: 30 * time.Minute,
}
res, err = longTimeoutClient.Do(req)
} else {
res, err = v.Client.Do(req)
}
if err != nil {
return nil, 0, fmt.Errorf("DoSyncV2: request failed: %w", err)
}
View File
@ -0,0 +1,90 @@
package sync2
import (
"sync"
"time"
"github.com/matrix-org/sliding-sync/pubsub"
)
// This struct remembers user+device IDs to notify for, then periodically
// emits them all to the caller. Used to rate limit the frequency of device list
// updates.
type DeviceDataTicker struct {
// data structures to periodically notify downstream about device data updates
// The ticker controls the frequency of updates. The done channel is used to stop ticking
// and clean up the goroutine. The notify map contains the values to notify for.
ticker *time.Ticker
done chan struct{}
notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
fn func(payload *pubsub.V2DeviceData)
}
// Create a new device data ticker, which batches calls to Remember and invokes a callback every
// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
// is useful for testing.
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
ddt := &DeviceDataTicker{
done: make(chan struct{}),
notifyMap: &sync.Map{},
}
if d != 0 {
ddt.ticker = time.NewTicker(d)
}
return ddt
}
// Stop ticking.
func (t *DeviceDataTicker) Stop() {
if t.ticker != nil {
t.ticker.Stop()
}
close(t.done)
}
// Set the function which should be called when the tick happens.
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
t.fn = fn
}
// Remember this user/device ID, and emit it later on.
func (t *DeviceDataTicker) Remember(pid PollerID) {
t.notifyMap.Store(pid, true)
if t.ticker == nil {
t.emitUpdate()
}
}
func (t *DeviceDataTicker) emitUpdate() {
var p pubsub.V2DeviceData
p.UserIDToDeviceIDs = make(map[string][]string)
// populate the pubsub payload
t.notifyMap.Range(func(key, value any) bool {
pid := key.(PollerID)
devices := p.UserIDToDeviceIDs[pid.UserID]
devices = append(devices, pid.DeviceID)
p.UserIDToDeviceIDs[pid.UserID] = devices
// clear the map of this value
t.notifyMap.Delete(key)
return true // keep enumerating
})
// notify if we have entries
if len(p.UserIDToDeviceIDs) > 0 {
t.fn(&p)
}
}
// Blocks forever, ticking until Stop() is called.
func (t *DeviceDataTicker) Run() {
if t.ticker == nil {
return
}
for {
select {
case <-t.done:
return
case <-t.ticker.C:
t.emitUpdate()
}
}
}
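// Illustrative sketch, not part of this change: typical wiring of the ticker. The
// interval, channel and PollerID values are invented for illustration. Callers batch
// device-data updates by calling Remember and receive them coalesced into one
// callback per tick.
func exampleDeviceDataTickerUsage(stop <-chan struct{}) {
	ticker := NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		// payload.UserIDToDeviceIDs holds every user/device remembered since the last tick
	})
	go ticker.Run()
	// Both calls below are coalesced into a single callback on the next tick.
	ticker.Remember(PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE_ONE"})
	ticker.Remember(PollerID{UserID: "@alice:localhost", DeviceID: "DEVICE_TWO"})
	<-stop
	ticker.Stop() // stops ticking and ends the Run goroutine
}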
View File
@ -0,0 +1,125 @@
package sync2
import (
"reflect"
"sort"
"sync"
"testing"
"time"
"github.com/matrix-org/sliding-sync/pubsub"
)
func TestDeviceTickerBasic(t *testing.T) {
duration := time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
t.Log("starting the ticker")
ticker.Run()
wg.Done()
}()
time.Sleep(duration * 2) // wait until the ticker is consuming
t.Log("remembering a poller")
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "b",
})
time.Sleep(duration * 2)
if len(payloads) != 1 {
t.Fatalf("expected 1 callback, got %d", len(payloads))
}
want := map[string][]string{
"a": {"b"},
}
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
// check stopping works
payloads = []*pubsub.V2DeviceData{}
ticker.Stop()
wg.Wait()
time.Sleep(duration * 2)
if len(payloads) != 0 {
t.Fatalf("got extra payloads: %+v", payloads)
}
}
func TestDeviceTickerBatchesCorrectly(t *testing.T) {
duration := 100 * time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
})
go ticker.Run()
defer ticker.Stop()
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "b",
})
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "bb", // different device, same user
})
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "b", // dupe poller ID
})
ticker.Remember(PollerID{
UserID: "x",
DeviceID: "y", // new device and user
})
time.Sleep(duration * 2)
if len(payloads) != 1 {
t.Fatalf("expected 1 callback, got %d", len(payloads))
}
want := map[string][]string{
"a": {"b", "bb"},
"x": {"y"},
}
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
}
func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
duration := time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
})
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "b",
})
go ticker.Run()
defer ticker.Stop()
ticker.Remember(PollerID{
UserID: "a",
DeviceID: "b",
})
time.Sleep(10 * duration)
if len(payloads) != 1 {
t.Fatalf("got %d payloads, want 1", len(payloads))
}
}
func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
t.Helper()
if len(got) != len(want) {
t.Fatalf("got %+v\nwant %+v\n", got, want)
}
for userID, wantDeviceIDs := range want {
gotDeviceIDs := got[userID]
sort.Strings(wantDeviceIDs)
sort.Strings(gotDeviceIDs)
if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
}
}
}
View File
@ -32,8 +32,8 @@ func NewDevicesTable(db *sqlx.DB) *DevicesTable {
// InsertDevice creates a new devices row with a blank since token if no such row
// exists. Otherwise, it does nothing.
func (t *DevicesTable) InsertDevice(userID, deviceID string) error {
_, err := t.db.Exec(
func (t *DevicesTable) InsertDevice(txn *sqlx.Tx, userID, deviceID string) error {
_, err := txn.Exec(
` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO NOTHING`,
userID, deviceID, "",
View File
@ -2,6 +2,7 @@ package sync2
import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"os"
"sort"
"testing"
@ -41,18 +42,25 @@ func TestDevicesTableSinceColumn(t *testing.T) {
aliceSecret1 := "mysecret1"
aliceSecret2 := "mysecret2"
t.Log("Insert two tokens for Alice.")
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken, aliceToken2 *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
t.Log("Insert two tokens for Alice.")
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now())
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
t.Log("Add a devices row for Alice")
err = devices.InsertDevice(alice, aliceDevice)
t.Log("Add a devices row for Alice")
err = devices.InsertDevice(txn, alice, aliceDevice)
if err != nil {
t.Fatalf("Failed to Insert device: %s", err)
}
return nil
})
t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
@ -104,39 +112,49 @@ func TestTokenForEachDevice(t *testing.T) {
chris := "chris"
chrisDevice := "chris_desktop"
t.Log("Add a device for Alice, Bob and Chris.")
err := devices.InsertDevice(alice, aliceDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(bob, bobDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(chris, chrisDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
t.Log("Add a device for Alice, Bob and Chris.")
err := devices.InsertDevice(txn, alice, aliceDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(txn, bob, bobDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
err = devices.InsertDevice(txn, chris, chrisDevice)
if err != nil {
t.Fatalf("InsertDevice returned error: %s", err)
}
return nil
})
t.Log("Mark Alice's device with a since token.")
sinceValue := "s-1-2-3-4"
devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
if err != nil {
t.Fatalf("UpdateDeviceSince returned error: %s", err)
}
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
aliceLastSeen1 := time.Now()
_, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken2, bobToken *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
aliceLastSeen1 := time.Now()
_, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})
t.Log("Fetch a token for every device")
gotTokens, err := tokens.TokenForEachDevice(nil)
View File
@ -4,11 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"hash/fnv"
"os"
"sync"
"time"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/getsentry/sentry-go"
@ -30,38 +32,48 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C
// processing v2 data (as a sync2.V2DataReceiver) and publishing updates (pubsub.Payload to V2Listeners);
// and receiving and processing EnsurePolling events.
type Handler struct {
pMap *sync2.PollerMap
v2Store *sync2.Storage
Store *state.Storage
v2Pub pubsub.Notifier
v3Sub *pubsub.V3Sub
client sync2.Client
unreadMap map[string]struct {
pMap sync2.IPollerMap
v2Store *sync2.Storage
Store *state.Storage
v2Pub pubsub.Notifier
v3Sub *pubsub.V3Sub
// user_id|room_id|event_type => fnv_hash(last_event_bytes)
accountDataMap *sync.Map
unreadMap map[string]struct {
Highlight int
Notif int
}
// room_id => fnv_hash([typing user ids])
typingMap map[string]uint64
// room_id -> PollerID, stores which Poller is allowed to update typing notifications
typingHandler map[string]sync2.PollerID
typingMu *sync.Mutex
PendingTxnIDs *sync2.PendingTransactionIDs
deviceDataTicker *sync2.DeviceDataTicker
e2eeWorkerPool *internal.WorkerPool
numPollers prometheus.Gauge
subSystem string
}
func NewHandler(
connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
pMap sync2.IPollerMap, v2Store *sync2.Storage, store *state.Storage,
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
) (*Handler, error) {
h := &Handler{
pMap: pMap,
v2Store: v2Store,
client: client,
Store: store,
subSystem: "poller",
unreadMap: make(map[string]struct {
Highlight int
Notif int
}),
typingMap: make(map[string]uint64),
accountDataMap: &sync.Map{},
typingMu: &sync.Mutex{},
typingHandler: make(map[string]sync2.PollerID),
PendingTxnIDs: sync2.NewPendingTransactionIDs(pMap.DeviceIDs),
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
}
if enablePrometheus {
@ -86,6 +98,9 @@ func (h *Handler) Listen() {
sentry.CaptureException(err)
}
}()
h.e2eeWorkerPool.Start()
h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
go h.deviceDataTicker.Run()
}
func (h *Handler) Teardown() {
@ -95,6 +110,7 @@ func (h *Handler) Teardown() {
h.Store.Teardown()
h.v2Store.Teardown()
h.pMap.Terminate()
h.deviceDataTicker.Stop()
if h.numPollers != nil {
prometheus.Unregister(h.numPollers)
}
@ -156,7 +172,15 @@ func (h *Handler) updateMetrics() {
h.numPollers.Set(float64(h.pMap.NumPollers()))
}
func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) {
func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) {
// Check if this device is handling any typing notifications; if so, remove it
h.typingMu.Lock()
defer h.typingMu.Unlock()
for roomID, devID := range h.typingHandler {
if devID == pollerID {
delete(h.typingHandler, roomID)
}
}
h.updateMetrics()
}
@ -193,42 +217,62 @@ func (h *Handler) UpdateDeviceSince(ctx context.Context, userID, deviceID, since
}
func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
// some of these fields may be set
partialDD := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: otkCounts,
FallbackKeyTypes: fallbackKeyTypes,
DeviceLists: internal.DeviceLists{
New: deviceListChanges,
},
}
nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
UserID: userID,
DeviceID: deviceID,
Pos: nextPos,
var wg sync.WaitGroup
wg.Add(1)
h.e2eeWorkerPool.Queue(func() {
defer wg.Done()
// some of these fields may be set
partialDD := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: otkCounts,
FallbackKeyTypes: fallbackKeyTypes,
DeviceLists: internal.DeviceLists{
New: deviceListChanges,
},
}
err := h.Store.DeviceDataTable.Upsert(&partialDD)
if err != nil {
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
// remember this to notify on pubsub later
h.deviceDataTicker.Remember(sync2.PollerID{
UserID: userID,
DeviceID: deviceID,
})
})
wg.Wait()
}
// Called periodically by deviceDataTicker, contains many updates
func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
h.v2Pub.Notify(pubsub.ChanV2, payload)
}
func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
// Remember any transaction IDs that may be unique to this user
eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
// Also remember events which were sent by this user but lack a transaction ID.
eventIDsLackingTxns := make([]string, 0, len(timeline))
for _, e := range timeline {
txnID := gjson.GetBytes(e, "unsigned.transaction_id")
if !txnID.Exists() {
parsed := gjson.ParseBytes(e)
eventID := parsed.Get("event_id").Str
if txnID := parsed.Get("unsigned.transaction_id"); txnID.Exists() {
eventIDsWithTxns = append(eventIDsWithTxns, eventID)
eventIDToTxnID[eventID] = txnID.Str
continue
}
eventID := gjson.GetBytes(e, "event_id").Str
eventIDsWithTxns = append(eventIDsWithTxns, eventID)
eventIDToTxnID[eventID] = txnID.Str
if sender := parsed.Get("sender"); sender.Str == userID {
eventIDsLackingTxns = append(eventIDsLackingTxns, eventID)
}
}
if len(eventIDToTxnID) > 0 {
// persist the txn IDs
err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
@ -239,62 +283,69 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
}
// Insert new events
numNew, latestNIDs, err := h.Store.Accumulate(roomID, prevBatch, timeline)
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
if err != nil {
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
if numNew == 0 {
// no new events
return
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
RoomID: roomID,
PrevBatch: prevBatch,
EventNIDs: latestNIDs,
})
if len(eventIDToTxnID) > 0 {
// We've updated the database. Now tell any pubsub listeners what we learned.
if numNew != 0 {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
RoomID: roomID,
PrevBatch: prevBatch,
EventNIDs: latestNIDs,
})
}
if len(eventIDToTxnID) > 0 || len(eventIDsLackingTxns) > 0 {
// The call to h.Store.Accumulate above only tells us about new events' NIDS;
// for existing events we need to requery the database to fetch them.
// Rather than try to reuse work, keep things simple and just fetch NIDs for
// all events with txnIDs.
var nidsByIDs map[string]int64
eventIDsToFetch := append(eventIDsWithTxns, eventIDsLackingTxns...)
err = sqlutil.WithTransaction(h.Store.DB, func(txn *sqlx.Tx) error {
nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsWithTxns)
nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsToFetch)
return err
})
if err != nil {
logger.Err(err).
Int("timeline", len(timeline)).
Int("num_transaction_ids", len(eventIDsWithTxns)).
Int("num_missing_transaction_ids", len(eventIDsLackingTxns)).
Str("room", roomID).
Msg("V2: failed to fetch nids for events with transaction_ids")
Msg("V2: failed to fetch nids for event transaction_id handling")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
return
}
for _, eventID := range eventIDsWithTxns {
for eventID, nid := range nidsByIDs {
txnID, ok := eventIDToTxnID[eventID]
if !ok {
continue
if ok {
h.PendingTxnIDs.SeenTxnID(eventID)
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
EventID: eventID,
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
TransactionID: txnID,
NID: nid,
})
} else {
allClear, _ := h.PendingTxnIDs.MissingTxnID(eventID, userID, deviceID)
if allClear {
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
EventID: eventID,
RoomID: roomID,
UserID: userID,
DeviceID: deviceID,
TransactionID: "",
NID: nid,
})
}
}
nid, ok := nidsByIDs[eventID]
if !ok {
errMsg := "V2: failed to fetch NID for txnID"
logger.Error().Str("user", userID).Str("device", deviceID).Msg(errMsg)
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf("errMsg"))
continue
}
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
EventID: eventID,
UserID: userID,
DeviceID: deviceID,
TransactionID: txnID,
NID: nid,
})
}
}
}
@ -315,13 +366,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra
return res.PrependTimelineEvents
}
func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
next := typingHash(ephEvent)
existing := h.typingMap[roomID]
if existing == next {
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {
h.typingMu.Lock()
defer h.typingMu.Unlock()
existingDevice := h.typingHandler[roomID]
isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != ""
if isPollerAssigned && existingDevice != pollerID {
// A different device is already handling typing notifications for this room
return
} else if !isPollerAssigned {
// We're the first to call SetTyping, assign our pollerID
h.typingHandler[roomID] = pollerID
}
h.typingMap[roomID] = next
// we don't persist this for long term storage as typing notifs are inherently ephemeral.
// So rather than maintaining them forever, they will naturally expire when we terminate.
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{
@ -398,7 +456,28 @@ func (h *Handler) UpdateUnreadCounts(ctx context.Context, roomID, userID string,
}
func (h *Handler) OnAccountData(ctx context.Context, userID, roomID string, events []json.RawMessage) {
data, err := h.Store.InsertAccountData(userID, roomID, events)
// duplicate suppression for multiple devices on the same account.
// We suppress by remembering a hash of the last event bytes for each (user, room, event type),
// and ignoring the event if the hash matches.
dedupedEvents := make([]json.RawMessage, 0, len(events))
for i := range events {
evType := gjson.GetBytes(events[i], "type").Str
key := fmt.Sprintf("%s|%s|%s", userID, roomID, evType)
thisHash := fnvHash(events[i])
last, _ := h.accountDataMap.Load(key)
if last != nil {
lastHash := last.(uint64)
if lastHash == thisHash {
continue // skip this event
}
}
dedupedEvents = append(dedupedEvents, events[i])
h.accountDataMap.Store(key, thisHash)
}
if len(dedupedEvents) == 0 {
return
}
data, err := h.Store.InsertAccountData(userID, roomID, dedupedEvents)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update account data")
sentry.CaptureException(err)
@ -428,16 +507,24 @@ func (h *Handler) OnInvite(ctx context.Context, userID, roomID string, inviteSta
})
}
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv json.RawMessage) {
// remove any invites for this user if they are rejecting an invite
err := h.Store.InvitesTable.RemoveInvite(userID, roomID)
if err != nil {
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
}
// Remove the room from the typingHandler map; this ensures another device can take over
// handling typing notifications for this room.
h.typingMu.Lock()
defer h.typingMu.Unlock()
delete(h.typingHandler, roomID)
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
UserID: userID,
RoomID: roomID,
UserID: userID,
RoomID: roomID,
LeaveEvent: leaveEv,
})
}
@ -471,10 +558,8 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
}()
}
func typingHash(ephEvent json.RawMessage) uint64 {
func fnvHash(event json.RawMessage) uint64 {
h := fnv.New64a()
for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() {
h.Write([]byte(userID.Str))
}
h.Write(event)
return h.Sum64()
}
View File
@ -0,0 +1,224 @@
package handler2_test
import (
"context"
"encoding/json"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync2/handler2"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/rs/zerolog"
)
var postgresURI string
func TestMain(m *testing.M) {
postgresURI = testutils.PrepareDBConnectionString()
exitCode := m.Run()
os.Exit(exitCode)
}
type pollInfo struct {
pid sync2.PollerID
accessToken string
v2since string
isStartup bool
}
type mockPollerMap struct {
calls []pollInfo
}
func (p *mockPollerMap) NumPollers() int {
return 0
}
func (p *mockPollerMap) Terminate() {}
func (p *mockPollerMap) DeviceIDs(userID string) []string {
return nil
}
func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
p.calls = append(p.calls, pollInfo{
pid: pid,
accessToken: accessToken,
v2since: v2since,
isStartup: isStartup,
})
}
func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) {
for _, c := range p.calls {
if reflect.DeepEqual(pi, c) {
return
}
}
t.Fatalf("assertCallExists: did not find %+v", pi)
}
type mockPub struct {
calls []pubsub.Payload
mu *sync.Mutex
waiters map[string][]chan struct{}
}
func newMockPub() *mockPub {
return &mockPub{
mu: &sync.Mutex{},
waiters: make(map[string][]chan struct{}),
}
}
// Notify chanName that there is a new payload p. Return an error if we failed to send the notification.
func (p *mockPub) Notify(chanName string, payload pubsub.Payload) error {
p.calls = append(p.calls, payload)
p.mu.Lock()
for _, ch := range p.waiters[payload.Type()] {
close(ch)
}
p.waiters[payload.Type()] = nil // don't re-notify for 2nd+ payload
p.mu.Unlock()
return nil
}
func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
ch := make(chan struct{})
p.mu.Lock()
p.waiters[t] = append(p.waiters[t], ch)
p.mu.Unlock()
return ch
}
func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) {
select {
case <-ch:
if wantTimeOut {
t.Fatalf("expected to timeout, but received on channel")
}
return
case <-time.After(time.Second):
if !wantTimeOut {
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
}
}
}
// Close is called when we should stop listening.
func (p *mockPub) Close() error { return nil }
type mockSub struct{}
// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called.
func (s *mockSub) Listen(chanName string, fn func(p pubsub.Payload)) error { return nil }
// Close the listener. No more callbacks should fire.
func (s *mockSub) Close() error { return nil }
func assertNoError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("assertNoError: %v", err)
}
// Test that if you call EnsurePolling you get back V2InitialSyncComplete down pubsub and the poller
// map is called correctly
func TestHandlerFreshEnsurePolling(t *testing.T) {
store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
assertNoError(t, err)
alice := "@alice:localhost"
deviceID := "ALICE"
token := "aliceToken"
var tok *sync2.Token
sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) error {
// the device and token need to already exist prior to EnsurePolling
err = v2Store.DevicesTable.InsertDevice(txn, alice, deviceID)
assertNoError(t, err)
tok, err = v2Store.TokensTable.Insert(txn, token, alice, deviceID, time.Now())
assertNoError(t, err)
return nil
})
payloadInitialSyncComplete := pubsub.V2InitialSyncComplete{
UserID: alice,
DeviceID: deviceID,
}
ch := pub.WaitForPayloadType(payloadInitialSyncComplete.Type())
// ask the handler to start polling
h.EnsurePolling(&pubsub.V3EnsurePolling{
UserID: alice,
DeviceID: deviceID,
AccessTokenHash: tok.AccessTokenHash,
})
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false)
// make sure we polled with the token i.e it did a db hit
pMap.assertCallExists(t, pollInfo{
pid: sync2.PollerID{
UserID: alice,
DeviceID: deviceID,
},
accessToken: token,
v2since: "",
isStartup: false,
})
}
func TestSetTypingConcurrently(t *testing.T) {
store := state.NewStorage(postgresURI)
v2Store := sync2.NewStore(postgresURI, "secret")
pMap := &mockPollerMap{}
pub := newMockPub()
sub := &mockSub{}
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
assertNoError(t, err)
ctx := context.Background()
roomID := "!typing:localhost"
typingType := pubsub.V2Typing{}
// startSignal is used to synchronize calling SetTyping
startSignal := make(chan struct{})
// Call SetTyping twice, this may happen with pollers for the same user
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()
go func() {
<-startSignal
h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
}()
close(startSignal)
// Wait for the event to be published
ch := pub.WaitForPayloadType(typingType.Type())
pub.DoWait(t, "didn't see V2Typing", ch, false)
ch = pub.WaitForPayloadType(typingType.Type())
// Wait again, but this time we expect to timeout.
pub.DoWait(t, "saw unexpected V2Typing", ch, true)
// We expect only one call to Notify, as the hashes should match
if gotCalls := len(pub.calls); gotCalls != 1 {
t.Fatalf("expected only one call to notify, got %d", gotCalls)
}
}
View File
@ -4,12 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/getsentry/sentry-go"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/sliding-sync/internal"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog"
@ -21,8 +22,12 @@ type PollerID struct {
DeviceID string
}
// alias time.Sleep so tests can monkey patch it out
// alias time.Sleep/time.Since so tests can monkey patch it out
var timeSleep = time.Sleep
var timeSince = time.Since
// log at most once every duration. Always logs before terminating.
var logInterval = 30 * time.Second
// V2DataReceiver is the receiver for all the v2 sync data the poller gets
type V2DataReceiver interface {
@ -34,7 +39,7 @@ type V2DataReceiver interface {
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
// SetTyping indicates which users are typing.
SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage)
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
// Sent when there is a new receipt
OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage)
// AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering.
@ -46,25 +51,35 @@ type V2DataReceiver interface {
// Sent when there is a room in the `invite` section of the v2 response.
OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) // invitestate in db
// Sent when there is a room in the `leave` section of the v2 response.
OnLeftRoom(ctx context.Context, userID, roomID string)
OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage)
// Sent when there is a _change_ in E2EE data, not all the time
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
// Sent when the poll loop terminates
OnTerminated(ctx context.Context, userID, deviceID string)
OnTerminated(ctx context.Context, pollerID PollerID)
// Sent when the token gets a 401 response
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
}
type IPollerMap interface {
EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger)
NumPollers() int
Terminate()
DeviceIDs(userID string) []string
}
// PollerMap is a map of device ID to Poller
type PollerMap struct {
v2Client Client
callbacks V2DataReceiver
pollerMu *sync.Mutex
Pollers map[PollerID]*poller
executor chan func()
executorRunning bool
processHistogramVec *prometheus.HistogramVec
timelineSizeHistogramVec *prometheus.HistogramVec
v2Client Client
callbacks V2DataReceiver
pollerMu *sync.Mutex
Pollers map[PollerID]*poller
executor chan func()
executorRunning bool
processHistogramVec *prometheus.HistogramVec
timelineSizeHistogramVec *prometheus.HistogramVec
gappyStateSizeVec *prometheus.HistogramVec
numOutstandingSyncReqsGauge prometheus.Gauge
totalNumPollsCounter prometheus.Counter
}
// NewPollerMap makes a new PollerMap. Guarantees that the V2DataReceiver will be called on the same
@ -115,7 +130,28 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap {
Buckets: []float64{0.0, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0},
}, []string{"limited"})
prometheus.MustRegister(pm.timelineSizeHistogramVec)
pm.gappyStateSizeVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "sliding_sync",
Subsystem: "poller",
Name: "gappy_state_size",
Help: "Number of events in a state block during a sync v2 gappy sync",
Buckets: []float64{1.0, 10.0, 100.0, 1000.0, 10000.0},
}, nil)
prometheus.MustRegister(pm.gappyStateSizeVec)
pm.totalNumPollsCounter = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "poller",
Name: "total_num_polls",
Help: "Total number of poll loops iterated.",
})
prometheus.MustRegister(pm.totalNumPollsCounter)
pm.numOutstandingSyncReqsGauge = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "sliding_sync",
Subsystem: "poller",
Name: "num_outstanding_sync_v2_reqs",
Help: "Number of sync v2 requests that have yet to return a response.",
})
prometheus.MustRegister(pm.numOutstandingSyncReqsGauge)
}
return pm
}
@ -137,6 +173,15 @@ func (h *PollerMap) Terminate() {
if h.timelineSizeHistogramVec != nil {
prometheus.Unregister(h.timelineSizeHistogramVec)
}
if h.gappyStateSizeVec != nil {
prometheus.Unregister(h.gappyStateSizeVec)
}
if h.totalNumPollsCounter != nil {
prometheus.Unregister(h.totalNumPollsCounter)
}
if h.numOutstandingSyncReqsGauge != nil {
prometheus.Unregister(h.numOutstandingSyncReqsGauge)
}
close(h.executor)
}
@ -151,6 +196,20 @@ func (h *PollerMap) NumPollers() (count int) {
return
}
// DeviceIDs returns the slice of all devices currently being polled for by this user.
// The return value is brand-new and is fully owned by the caller.
func (h *PollerMap) DeviceIDs(userID string) []string {
h.pollerMu.Lock()
defer h.pollerMu.Unlock()
var devices []string
for _, p := range h.Pollers {
if !p.terminated.Load() && p.userID == userID {
devices = append(devices, p.deviceID)
}
}
return devices
}
// EnsurePolling makes sure there is a poller for this device, making one if need be.
// Blocks until at least 1 sync is done if and only if the poller was just created.
// This ensures that calls to the database will return data.
@ -196,6 +255,9 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
poller = newPoller(pid, accessToken, h.v2Client, h, logger, !needToWait && !isStartup)
poller.processHistogramVec = h.processHistogramVec
poller.timelineSizeVec = h.timelineSizeHistogramVec
poller.gappyStateSizeVec = h.gappyStateSizeVec
poller.numOutstandingSyncReqs = h.numOutstandingSyncReqsGauge
poller.totalNumPolls = h.totalNumPollsCounter
go poller.Poll(v2since)
h.Pollers[pid] = poller
@ -235,11 +297,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.
wg.Wait()
return
}
func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.SetTyping(ctx, roomID, ephEvent)
h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent)
wg.Done()
}
wg.Wait()
@ -254,11 +316,11 @@ func (h *PollerMap) OnInvite(ctx context.Context, userID, roomID string, inviteS
wg.Wait()
}
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string) {
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
var wg sync.WaitGroup
wg.Add(1)
h.executor <- func() {
h.callbacks.OnLeftRoom(ctx, userID, roomID)
h.callbacks.OnLeftRoom(ctx, userID, roomID, leaveEvent)
wg.Done()
}
wg.Wait()
@ -270,8 +332,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st
h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs)
}
func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) {
h.callbacks.OnTerminated(ctx, userID, deviceID)
func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) {
h.callbacks.OnTerminated(ctx, pollerID)
}
func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
@ -335,9 +397,24 @@ type poller struct {
terminated *atomic.Bool
wg *sync.WaitGroup
pollHistogramVec *prometheus.HistogramVec
processHistogramVec *prometheus.HistogramVec
timelineSizeVec *prometheus.HistogramVec
// stats about poll response data, for logging purposes
lastLogged time.Time
totalStateCalls int
totalTimelineCalls int
totalReceipts int
totalTyping int
totalInvites int
totalDeviceEvents int
totalAccountData int
totalChangedDeviceLists int
totalLeftDeviceLists int
pollHistogramVec *prometheus.HistogramVec
processHistogramVec *prometheus.HistogramVec
timelineSizeVec *prometheus.HistogramVec
gappyStateSizeVec *prometheus.HistogramVec
numOutstandingSyncReqs prometheus.Gauge
totalNumPolls prometheus.Counter
}
func newPoller(pid PollerID, accessToken string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
@ -366,9 +443,10 @@ func (p *poller) Terminate() {
}
type pollLoopState struct {
firstTime bool
failCount int
since string
firstTime bool
failCount int
since string
lastStoredSince time.Time // The time we last stored the since token in the database
}
// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
@ -392,16 +470,21 @@ func (p *poller) Poll(since string) {
defer func() {
panicErr := recover()
if panicErr != nil {
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msg(string(debug.Stack()))
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack())
internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr)
}
p.receiver.OnTerminated(ctx, p.userID, p.deviceID)
p.receiver.OnTerminated(ctx, PollerID{
UserID: p.userID,
DeviceID: p.deviceID,
})
}()
state := pollLoopState{
firstTime: true,
failCount: 0,
since: since,
// Setting time.Time{} causes the first iteration of the poll loop to immediately store the since token.
lastStoredSince: time.Time{},
}
for !p.terminated.Load() {
ctx, task := internal.StartTask(ctx, "Poll")
@ -411,6 +494,7 @@ func (p *poller) Poll(since string) {
break
}
}
p.maybeLogStats(true)
// always unblock EnsurePolling else we can end up head-of-line blocking other pollers!
if state.firstTime {
state.firstTime = false
@ -422,6 +506,9 @@ func (p *poller) Poll(since string) {
// s (which is assumed to be non-nil). Returns a non-nil error iff the poller loop
// should halt.
func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
if p.totalNumPolls != nil {
p.totalNumPolls.Inc()
}
if s.failCount > 0 {
// don't backoff when doing v2 syncs because the response is only in the cache for a short
// period of time (on massive accounts on matrix.org) such that if you wait 2,4,8min between
@ -435,15 +522,22 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
}
start := time.Now()
spanCtx, region := internal.StartSpan(ctx, "DoSyncV2")
if p.numOutstandingSyncReqs != nil {
p.numOutstandingSyncReqs.Inc()
}
resp, statusCode, err := p.client.DoSyncV2(spanCtx, p.accessToken, s.since, s.firstTime, p.initialToDeviceOnly)
if p.numOutstandingSyncReqs != nil {
p.numOutstandingSyncReqs.Dec()
}
region.End()
p.trackRequestDuration(time.Since(start), s.since == "", s.firstTime)
p.trackRequestDuration(timeSince(start), s.since == "", s.firstTime)
if p.terminated.Load() {
return fmt.Errorf("poller terminated")
}
if err != nil {
// check if temporary
if statusCode != 401 {
isFatal := statusCode == 401 || statusCode == 403
if !isFatal {
p.logger.Warn().Int("code", statusCode).Err(err).Msg("Poller: sync v2 poll returned temporary error")
s.failCount += 1
return nil
@ -472,14 +566,19 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
wasFirst := s.firstTime
s.since = resp.NextBatch
// persist the since token (TODO: this could get slow if we hammer the DB too much)
p.receiver.UpdateDeviceSince(ctx, p.userID, p.deviceID, s.since)
// Persist the since token if it has been more than one minute since we last
// stored it, OR if the response contains to-device messages.
if timeSince(s.lastStoredSince) > time.Minute || len(resp.ToDevice.Events) > 0 {
p.receiver.UpdateDeviceSince(ctx, p.userID, p.deviceID, s.since)
s.lastStoredSince = time.Now()
}
if s.firstTime {
s.firstTime = false
p.wg.Done()
}
p.trackProcessDuration(time.Since(start), wasInitial, wasFirst)
p.trackProcessDuration(timeSince(start), wasInitial, wasFirst)
p.maybeLogStats(false)
return nil
}
@ -518,6 +617,7 @@ func (p *poller) parseToDeviceMessages(ctx context.Context, res *SyncResponse) {
if len(res.ToDevice.Events) == 0 {
return
}
p.totalDeviceEvents += len(res.ToDevice.Events)
p.receiver.AddToDeviceMessages(ctx, p.userID, p.deviceID, res.ToDevice.Events)
}
@ -556,6 +656,8 @@ func (p *poller) parseE2EEData(ctx context.Context, res *SyncResponse) {
deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil {
p.totalChangedDeviceLists += len(res.DeviceLists.Changed)
p.totalLeftDeviceLists += len(res.DeviceLists.Left)
p.receiver.OnE2EEData(ctx, p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges)
}
}
@ -566,6 +668,7 @@ func (p *poller) parseGlobalAccountData(ctx context.Context, res *SyncResponse)
if len(res.AccountData.Events) == 0 {
return
}
p.totalAccountData += len(res.AccountData.Events)
p.receiver.OnAccountData(ctx, p.userID, AccountDataGlobalRoom, res.AccountData.Events)
}
@ -596,6 +699,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
})
hub.CaptureMessage(warnMsg)
})
p.trackGappyStateSize(len(prependStateEvents))
roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...)
}
}
@ -605,7 +709,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
switch ephEventType {
case "m.typing":
typingCalls++
p.receiver.SetTyping(ctx, roomID, ephEvent)
p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent)
case "m.receipt":
receiptCalls++
p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent)
@ -630,27 +734,60 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
}
}
for roomID, roomData := range res.Rooms.Leave {
// TODO: do we care about state?
if len(roomData.Timeline.Events) > 0 {
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
}
p.receiver.OnLeftRoom(ctx, p.userID, roomID)
// Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle
// the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event).
var leaveEvent json.RawMessage
for _, ev := range roomData.Timeline.Events {
leaveEv := gjson.ParseBytes(ev)
if leaveEv.Get("content.membership").Str == "leave" && leaveEv.Get("state_key").Str == p.userID {
leaveEvent = ev
break
}
}
if leaveEvent != nil {
p.receiver.OnLeftRoom(ctx, p.userID, roomID, leaveEvent)
}
}
for roomID, roomData := range res.Rooms.Invite {
p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events)
}
var l *zerolog.Event
if len(res.Rooms.Invite) > 0 || len(res.Rooms.Join) > 0 {
l = p.logger.Info()
} else {
l = p.logger.Debug()
p.totalReceipts += receiptCalls
p.totalStateCalls += stateCalls
p.totalTimelineCalls += timelineCalls
p.totalTyping += typingCalls
p.totalInvites += len(res.Rooms.Invite)
}
func (p *poller) maybeLogStats(force bool) {
if !force && timeSince(p.lastLogged) < logInterval {
// only log at most once every logInterval
return
}
l.Ints(
"rooms [invite,join,leave]", []int{len(res.Rooms.Invite), len(res.Rooms.Join), len(res.Rooms.Leave)},
p.lastLogged = time.Now()
p.logger.Info().Ints(
"rooms [timeline,state,typing,receipts,invites]", []int{
p.totalTimelineCalls, p.totalStateCalls, p.totalTyping, p.totalReceipts, p.totalInvites,
},
).Ints(
"storage [states,timelines,typing,receipts]", []int{stateCalls, timelineCalls, typingCalls, receiptCalls},
).Int("to_device", len(res.ToDevice.Events)).Msg("Poller: accumulated data")
"device [events,changed,left,account]", []int{
p.totalDeviceEvents, p.totalChangedDeviceLists, p.totalLeftDeviceLists, p.totalAccountData,
},
).Msg("Poller: accumulated data")
p.totalAccountData = 0
p.totalChangedDeviceLists = 0
p.totalDeviceEvents = 0
p.totalInvites = 0
p.totalLeftDeviceLists = 0
p.totalReceipts = 0
p.totalStateCalls = 0
p.totalTimelineCalls = 0
p.totalTyping = 0
}
func (p *poller) trackTimelineSize(size int, limited bool) {
@ -663,3 +800,10 @@ func (p *poller) trackTimelineSize(size int, limited bool) {
}
p.timelineSizeVec.WithLabelValues(label).Observe(float64(size))
}
func (p *poller) trackGappyStateSize(size int) {
if p.gappyStateSizeVec == nil {
return
}
p.gappyStateSizeVec.WithLabelValues().Observe(float64(size))
}
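// Illustrative sketch (not part of this change): how counters and gauges like the ones
// referenced above (totalNumPolls, numOutstandingSyncReqs, gappyStateSizeVec) are typically
// constructed and registered with Prometheus. The namespace, subsystem and metric names here
// are assumptions for illustration, not the proxy's actual wiring.
func newPollerMetricsSketch() (prometheus.Counter, prometheus.Gauge, *prometheus.HistogramVec) {
	totalNumPolls := prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "sliding_sync",
		Subsystem: "poller",
		Name:      "total_num_polls",
		Help:      "Total number of v2 sync requests made by all pollers.",
	})
	numOutstandingSyncReqs := prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: "sliding_sync",
		Subsystem: "poller",
		Name:      "num_outstanding_sync_reqs",
		Help:      "Number of v2 sync requests currently in flight.",
	})
	gappyStateSize := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: "sliding_sync",
		Subsystem: "poller",
		Name:      "gappy_state_size",
		Help:      "Number of state events prepended to the timeline in a gappy sync.",
		Buckets:   []float64{1, 10, 100, 1000, 10000},
	}, nil)
	prometheus.MustRegister(totalNumPolls, numOutstandingSyncReqs, gappyStateSize)
	return totalNumPolls, numOutstandingSyncReqs, gappyStateSize
}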

View File

@ -277,6 +277,9 @@ func TestPollerPollFromExisting(t *testing.T) {
json.RawMessage(`{"event":10}`),
},
}
toDeviceResponses := [][]json.RawMessage{
{}, {}, {}, {json.RawMessage(`{}`)},
}
hasPolledSuccessfully := make(chan struct{})
accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
if since == "" {
@ -295,6 +298,10 @@ func TestPollerPollFromExisting(t *testing.T) {
var joinResp SyncV2JoinResponse
joinResp.Timeline.Events = roomTimelineResponses[sinceInt]
return &SyncResponse{
// Add dummy to-device messages so the poller actually persists the since token. (It only
// does so on the first poll, after 1min has passed (longer than this test runs for), OR
// when the response contains to-device messages.)
ToDevice: EventsResponse{Events: toDeviceResponses[sinceInt]},
NextBatch: fmt.Sprintf("%d", sinceInt+1),
Rooms: struct {
Join map[string]SyncV2JoinResponse `json:"join"`
@ -336,6 +343,121 @@ func TestPollerPollFromExisting(t *testing.T) {
}
}
// Check that the since token in the database
// 1. is updated on the first iteration of the poll loop,
// 2. is NOT updated for ordinary responses (no to-device messages),
// 3. is updated if the syncV2 response contains to-device messages,
// 4. is updated if at least 1min has passed since we last stored a token.
func TestPollerPollUpdateDeviceSincePeriodically(t *testing.T) {
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
syncResponses := make(chan *SyncResponse, 1)
syncCalledWithSince := make(chan string)
accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
if since != "" {
syncCalledWithSince <- since
}
return <-syncResponses, 200, nil
})
accumulator.updateSinceCalled = make(chan struct{}, 1)
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
defer poller.Terminate()
go func() {
poller.Poll("0")
}()
hasPolledSuccessfully := make(chan struct{})
go func() {
poller.WaitUntilInitialSync()
close(hasPolledSuccessfully)
}()
// 1. Initial poll updates the database
next := "1"
syncResponses <- &SyncResponse{NextBatch: next}
mustEqualSince(t, <-syncCalledWithSince, "0")
select {
case <-hasPolledSuccessfully:
break
case <-time.After(time.Second):
t.Errorf("WaitUntilInitialSync failed to fire")
}
// Also check that UpdateDeviceSince was called
select {
case <-accumulator.updateSinceCalled:
case <-time.After(time.Millisecond * 100): // give the Poller some time to process the response
t.Fatalf("did not receive call to UpdateDeviceSince in time")
}
if got := accumulator.pollerIDToSince[pid]; got != next {
t.Fatalf("expected since to be updated to %s, but got %s", next, got)
}
// The since token used by calls to doSyncV2
wantSinceFromSync := next
// 2. Second request updates the state but NOT the database
syncResponses <- &SyncResponse{NextBatch: "2"}
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
select {
case <-accumulator.updateSinceCalled:
t.Fatalf("unexpected call to UpdateDeviceSince")
case <-time.After(time.Millisecond * 100):
}
if got := accumulator.pollerIDToSince[pid]; got != next {
t.Fatalf("expected since to be updated to %s, but got %s", next, got)
}
// 3. The sync response contains a to-device message, so the since token should be stored in the database
wantSinceFromSync = "2"
next = "3"
syncResponses <- &SyncResponse{
NextBatch: next,
ToDevice: EventsResponse{Events: []json.RawMessage{{}}},
}
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
select {
case <-accumulator.updateSinceCalled:
case <-time.After(time.Millisecond * 100):
t.Fatalf("did not receive call to UpdateDeviceSince in time")
}
if got := accumulator.pollerIDToSince[pid]; got != next {
t.Fatalf("expected since to be updated to %s, but got %s", wantSinceFromSync, got)
}
wantSinceFromSync = next
// 4. ... some time has passed, this triggers the 1min limit
timeSince = func(d time.Time) time.Duration {
return time.Minute * 2
}
next = "10"
syncResponses <- &SyncResponse{NextBatch: next}
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
select {
case <-accumulator.updateSinceCalled:
case <-time.After(time.Millisecond * 100):
t.Fatalf("did not receive call to UpdateDeviceSince in time")
}
if got := accumulator.pollerIDToSince[pid]; got != next {
t.Fatalf("expected since to be updated to %s, but got %s", wantSinceFromSync, got)
}
}
func mustEqualSince(t *testing.T, gotSince, expectedSince string) {
t.Helper()
if gotSince != expectedSince {
t.Fatalf("client.DoSyncV2 using unexpected since token: %s, want %s", gotSince, expectedSince)
}
}
// Tests that the poller backs off in 2,4,8,... second increments in response to a variety of errors
func TestPollerBackoff(t *testing.T) {
deviceID := "FOOBAR"
@ -453,11 +575,12 @@ func (c *mockClient) WhoAmI(authHeader string) (string, string, error) {
}
type mockDataReceiver struct {
states map[string][]json.RawMessage
timelines map[string][]json.RawMessage
pollerIDToSince map[PollerID]string
incomingProcess chan struct{}
unblockProcess chan struct{}
states map[string][]json.RawMessage
timelines map[string][]json.RawMessage
pollerIDToSince map[PollerID]string
incomingProcess chan struct{}
unblockProcess chan struct{}
updateSinceCalled chan struct{}
}
func (a *mockDataReceiver) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
@ -475,10 +598,13 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
// timeline. Untested here---return nil for now.
return nil
}
func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
}
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
if s.updateSinceCalled != nil {
s.updateSinceCalled <- struct{}{}
}
}
func (s *mockDataReceiver) AddToDeviceMessages(ctx context.Context, userID, deviceID string, msgs []json.RawMessage) {
}
@ -491,10 +617,11 @@ func (s *mockDataReceiver) OnReceipt(ctx context.Context, userID, roomID, ephEve
}
func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) {
}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
}
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {}
func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
}

View File

@ -2,7 +2,6 @@ package sync2
import (
"os"
"time"
"github.com/getsentry/sentry-go"
"github.com/jmoiron/sqlx"
@ -27,9 +26,10 @@ func NewStore(postgresURI, secret string) *Storage {
// TODO: if we panic(), will sentry have a chance to flush the event?
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
db.SetMaxOpenConns(100)
db.SetMaxIdleConns(80)
db.SetConnMaxLifetime(time.Hour)
return NewStoreWithDB(db, secret)
}
func NewStoreWithDB(db *sqlx.DB, secret string) *Storage {
return &Storage{
DevicesTable: NewDevicesTable(db),
TokensTable: NewTokensTable(db, secret),

View File

@ -171,10 +171,10 @@ func (t *TokensTable) TokenForEachDevice(txn *sqlx.Tx) (tokens []TokenForPoller,
}
// Insert a new token into the table.
func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
func (t *TokensTable) Insert(txn *sqlx.Tx, plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
hashedToken := hashToken(plaintextToken)
encToken := t.encrypt(plaintextToken)
_, err := t.db.Exec(
_, err := txn.Exec(
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (token_hash) DO NOTHING;`,

View File

@ -1,6 +1,8 @@
package sync2
import (
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"testing"
"time"
)
@ -26,27 +28,31 @@ func TestTokensTable(t *testing.T) {
aliceSecret1 := "mysecret1"
aliceToken1FirstSeen := time.Now()
// Test a single token
t.Log("Insert a new token from Alice.")
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var aliceToken, reinsertedToken *Token
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
// Test a single token
t.Log("Insert a new token from Alice.")
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
t.Log("Reinsert the same token.")
reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
t.Log("Reinsert the same token.")
reinsertedToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})
t.Log("This should yield an equal Token struct.")
assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
t.Log("Try to mark Alice's token as being used after an hour.")
err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
err := tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
if err != nil {
t.Fatalf("Failed to update last seen: %s", err)
}
@ -74,17 +80,20 @@ func TestTokensTable(t *testing.T) {
}
assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen)
// Test a second token for Alice
t.Log("Insert a second token for Alice.")
aliceSecret2 := "mysecret2"
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
// Test a second token for Alice
t.Log("Insert a second token for Alice.")
aliceSecret2 := "mysecret2"
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
aliceToken2, err := tokens.Insert(txn, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
t.Log("The returned Token struct should have been populated correctly.")
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
return nil
})
}
func TestDeletingTokens(t *testing.T) {
@ -94,11 +103,15 @@ func TestDeletingTokens(t *testing.T) {
t.Log("Insert a new token from Alice.")
accessToken := "mytoken"
token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
var token *Token
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
token, err = tokens.Insert(txn, accessToken, "@bob:builders.com", "device", time.Time{})
if err != nil {
t.Fatalf("Failed to Insert token: %s", err)
}
return nil
})
t.Log("We should be able to fetch this token without error.")
_, err = tokens.Token(accessToken)
if err != nil {

View File

@ -1,38 +1,121 @@
package sync2
import (
"fmt"
"sync"
"time"
"github.com/ReneKroon/ttlcache/v2"
)
type TransactionIDCache struct {
cache *ttlcache.Cache
type loaderFunc func(userID string) (deviceIDs []string)
// PendingTransactionIDs is (conceptually) a map from event IDs to a list of device IDs.
// Its keys are the IDs of events we've seen which a) lack a transaction ID, and b) were
// sent by one of the users we are polling for. The values are the list of the sender's
// devices whose pollers are yet to see a transaction ID.
//
// If another poller sees the same event
//
// - with a transaction ID, it emits a V2TransactionID payload with that ID and
// removes the event ID from this map.
//
// - without a transaction ID, it removes the polling device ID from the values
// list. If the device ID list is now empty, the poller emits an "all clear"
// V2TransactionID payload.
//
// This is a best-effort affair to ensure that the rest of the proxy can wait for
// transaction IDs to appear before transmitting an event down /sync to its sender.
//
// It's possible that we add an entry to this map and then the list of remaining
// device IDs becomes out of date, either due to a new device creation or an
// existing device expiring. We choose not to handle this case, because it is relatively
// rare.
//
// To avoid the map growing without bound, we use a ttlcache and drop entries
// after a short period of time.
type PendingTransactionIDs struct {
// mu guards the pending field. See MissingTxnID for rationale.
mu sync.Mutex
pending *ttlcache.Cache
// loader should provide the list of device IDs we are polling with for the given user
loader loaderFunc
}
func NewTransactionIDCache() *TransactionIDCache {
func NewPendingTransactionIDs(loader loaderFunc) *PendingTransactionIDs {
c := ttlcache.NewCache()
c.SetTTL(5 * time.Minute) // keep pending entries for 5 minutes before forgetting about them
c.SkipTTLExtensionOnHit(true) // we don't care how many times they ask for the item, 5min is the limit.
return &TransactionIDCache{
cache: c,
return &PendingTransactionIDs{
mu: sync.Mutex{},
pending: c,
loader: loader,
}
}
// Store a new transaction ID received via v2 /sync
func (c *TransactionIDCache) Store(userID, eventID, txnID string) {
c.cache.Set(cacheKey(userID, eventID), txnID)
}
// MissingTxnID should be called to report that this device ID did not see a
// transaction ID for this event ID. Returns true if this is the first time we know
// for sure that we'll never see a txn ID for this event.
func (c *PendingTransactionIDs) MissingTxnID(eventID, userID, myDeviceID string) (bool, error) {
// While ttlcache is threadsafe, it does not provide a way to atomically update
// (get+set) a value, which means we are still open to races. For example:
//
// - We have three pollers A, B, C.
// - Poller A sees an event without txn id and calls MissingTxnID.
// - `c.pending.Get()` fails, so we load up all device IDs: [A, B, C].
// - Then `c.pending.Set()` with [B, C].
// - Poller B sees the same event, also missing txn ID and calls MissingTxnID.
// - Poller C does the same concurrently.
//
// If the Get+Set isn't atomic, then we might do e.g.
// - B gets [B, C] and prepares to write [C].
// - C gets [B, C] and prepares to write [B].
// - Last writer wins. Either way, we never write [] and so never return true
// (the all-clear signal.)
//
// This wouldn't be the end of the world (the API process has a maximum delay, and
// the ttlcache will expire the entry), but it would still be nice to avoid it.
c.mu.Lock()
defer c.mu.Unlock()
// Get a transaction ID previously stored.
func (c *TransactionIDCache) Get(userID, eventID string) string {
val, _ := c.cache.Get(cacheKey(userID, eventID))
if val != nil {
return val.(string)
data, err := c.pending.Get(eventID)
if err == ttlcache.ErrNotFound {
data = c.loader(userID)
} else if err != nil {
return false, fmt.Errorf("PendingTransactionIDs: failed to get device ids: %w", err)
}
return ""
deviceIDs, ok := data.([]string)
if !ok {
return false, fmt.Errorf("PendingTransactionIDs: failed to cast device IDs")
}
deviceIDs, changed := removeDevice(myDeviceID, deviceIDs)
if changed {
err = c.pending.Set(eventID, deviceIDs)
if err != nil {
return false, fmt.Errorf("PendingTransactionIDs: failed to set device IDs: %w", err)
}
}
return changed && len(deviceIDs) == 0, nil
}
func cacheKey(userID, eventID string) string {
return userID + " " + eventID
// SeenTxnID should be called to report that this device saw a transaction ID
// for this event.
func (c *PendingTransactionIDs) SeenTxnID(eventID string) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.pending.Set(eventID, []string{})
}
// removeDevice takes a device ID slice and returns a device ID slice with one
// particular string removed. Assumes that the given slice has no duplicates.
// Does not modify the given slice in situ.
func removeDevice(device string, devices []string) ([]string, bool) {
for i, otherDevice := range devices {
if otherDevice == device {
// Copy rather than appending in place, so we never mutate the caller's slice.
remaining := make([]string, 0, len(devices)-1)
remaining = append(remaining, devices[:i]...)
remaining = append(remaining, devices[i+1:]...)
return remaining, true
}
}
return devices, false
}
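// Illustrative usage sketch (not part of this change): how a poller might drive
// PendingTransactionIDs when it sees one of its user's events. The function and
// parameter names here are assumptions for illustration only.
func examplePendingTxnIDUsage(loadDeviceIDs loaderFunc, eventID, userID, deviceID, txnID string) bool {
	pending := NewPendingTransactionIDs(loadDeviceIDs)
	if txnID != "" {
		// This poller saw the transaction ID, so no other poller needs to wait for it.
		_ = pending.SeenTxnID(eventID)
		return false
	}
	// This poller did not see a transaction ID. The result is true only when this was
	// the last tracked device, i.e. no poller will ever supply a txn ID for the event.
	allClear, err := pending.MissingTxnID(eventID, userID, deviceID)
	return err == nil && allClear
}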

View File

@ -2,54 +2,107 @@ package sync2
import "testing"
func TestTransactionIDCache(t *testing.T) {
alice := "@alice:localhost"
bob := "@bob:localhost"
eventA := "$a:localhost"
eventB := "$b:localhost"
eventC := "$c:localhost"
txn1 := "1"
txn2 := "2"
cache := NewTransactionIDCache()
cache.Store(alice, eventA, txn1)
cache.Store(bob, eventB, txn1) // different users can use same txn ID
cache.Store(alice, eventC, txn2)
testCases := []struct {
eventID string
userID string
want string
}{
{
eventID: eventA,
userID: alice,
want: txn1,
},
{
eventID: eventB,
userID: bob,
want: txn1,
},
{
eventID: eventC,
userID: alice,
want: txn2,
},
{
eventID: "$invalid",
userID: alice,
want: "",
},
{
eventID: eventA,
userID: "@invalid",
want: "",
},
func TestPendingTransactionIDs(t *testing.T) {
pollingDevicesByUser := map[string][]string{
"alice": {"A1", "A2"},
"bob": {"B1"},
"chris": {},
"delia": {"D1", "D2", "D3", "D4"},
"enid": {"E1", "E2"},
}
for _, tc := range testCases {
txnID := cache.Get(tc.userID, tc.eventID)
if txnID != tc.want {
t.Errorf("%+v: got %v want %v", tc, txnID, tc.want)
mockLoad := func(userID string) (deviceIDs []string) {
devices, ok := pollingDevicesByUser[userID]
if !ok {
t.Fatalf("Mock didn't have devices for %s", userID)
}
newDevices := make([]string, len(devices))
copy(newDevices, devices)
return newDevices
}
pending := NewPendingTransactionIDs(mockLoad)
// Alice.
// We're tracking two of Alice's devices.
allClear, err := pending.MissingTxnID("event1", "alice", "A1")
assertNoError(t, err)
assertAllClear(t, allClear, false) // waiting on A2
// If for some reason the poller sees the same event for the same device, we should
// still be waiting for A2.
allClear, err = pending.MissingTxnID("event1", "alice", "A1")
assertNoError(t, err)
assertAllClear(t, allClear, false)
// If for some reason Alice spun up a new device, we are still going to be waiting
// for A2.
allClear, err = pending.MissingTxnID("event1", "alice", "A_unknown_device")
assertNoError(t, err)
assertAllClear(t, allClear, false)
// If A2 sees the event without a txnID, we should emit the all clear signal.
allClear, err = pending.MissingTxnID("event1", "alice", "A2")
assertNoError(t, err)
assertAllClear(t, allClear, true)
// If for some reason A2 sees the event a second time, we shouldn't re-emit the
// all clear signal.
allClear, err = pending.MissingTxnID("event1", "alice", "A2")
assertNoError(t, err)
assertAllClear(t, allClear, false)
// Bob.
// We're only tracking one device for Bob
allClear, err = pending.MissingTxnID("event2", "bob", "B1")
assertNoError(t, err)
assertAllClear(t, allClear, true) // not waiting on any devices
// Chris.
// We're not tracking any devices for Chris. A MissingTxnID call for him shouldn't
// cause anything to explode.
allClear, err = pending.MissingTxnID("event3", "chris", "C_unknown_device")
assertNoError(t, err)
// Delia.
// Delia is tracking four devices.
allClear, err = pending.MissingTxnID("event4", "delia", "D1")
assertNoError(t, err)
assertAllClear(t, allClear, false) // waiting on D2, D3 and D4
// One of Delia's devices, say D2, sees a txn ID for event 4.
err = pending.SeenTxnID("event4")
assertNoError(t, err)
// The other devices see the event. Neither should emit all clear.
allClear, err = pending.MissingTxnID("event4", "delia", "D3")
assertNoError(t, err)
assertAllClear(t, allClear, false)
allClear, err = pending.MissingTxnID("event4", "delia", "D4")
assertNoError(t, err)
assertAllClear(t, allClear, false)
// Enid.
// Enid has two devices. Her first poller (E1) is lucky and sees the transaction ID.
err = pending.SeenTxnID("event5")
assertNoError(t, err)
// Her second poller misses the transaction ID, but this shouldn't cause an all clear.
allClear, err = pending.MissingTxnID("event4", "delia", "E2")
assertNoError(t, err)
assertAllClear(t, allClear, false)
}
func assertAllClear(t *testing.T, got bool, want bool) {
t.Helper()
if got != want {
t.Errorf("Expected allClear=%t, got %t", want, got)
}
}
func assertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("got error: %s", err)
}
}

42
sync3/avatar.go Normal file
View File

@ -0,0 +1,42 @@
package sync3
import (
"bytes"
"encoding/json"
)
// An AvatarChange represents a change to a room's avatar. There are three cases:
// - an empty string represents no change, and should be omitted when JSON-serialised;
// - the sentinel `<no-avatar>` represents a room that has never had an avatar,
// or a room whose avatar has been removed. It is JSON-serialised as null.
// - All other strings represent the current avatar of the room and JSON-serialise as
// normal.
type AvatarChange string
const DeletedAvatar = AvatarChange("<no-avatar>")
const UnchangedAvatar AvatarChange = ""
// NewAvatarChange interprets an optional avatar string as an AvatarChange.
func NewAvatarChange(avatar string) AvatarChange {
if avatar == "" {
return DeletedAvatar
}
return AvatarChange(avatar)
}
func (a AvatarChange) MarshalJSON() ([]byte, error) {
if a == DeletedAvatar {
return []byte(`null`), nil
} else {
return json.Marshal(string(a))
}
}
// Note: the unmarshalling is only used in tests.
func (a *AvatarChange) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, []byte("null")) {
*a = DeletedAvatar
return nil
}
return json.Unmarshal(data, (*string)(a))
}
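// Illustrative only (not part of this change): how the three AvatarChange cases
// serialise when embedded in a struct. The field name, omitempty tag and mxc URL
// are assumptions for illustration.
func exampleAvatarChangeJSON() [3]string {
	type room struct {
		Avatar AvatarChange `json:"avatar_change,omitempty"`
	}
	unchanged, _ := json.Marshal(room{Avatar: UnchangedAvatar}) // {}
	deleted, _ := json.Marshal(room{Avatar: DeletedAvatar})     // {"avatar_change":null}
	set, _ := json.Marshal(room{Avatar: NewAvatarChange("mxc://example.org/abc")}) // {"avatar_change":"mxc://example.org/abc"}
	return [3]string{string(unchanged), string(deleted), string(set)}
}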

View File

@ -21,6 +21,15 @@ type EventData struct {
Content gjson.Result
Timestamp uint64
Sender string
// TransactionID is the unsigned.transaction_id field in the event as stored in the
// syncv3_events table, or the empty string if there is no such field.
//
// We may see the event on poller A without a transaction_id, and then later on
// poller B with a transaction_id. If this happens, we make a temporary note of the
// transaction_id in the syncv3_txns table, but do not edit the persisted event.
// This means that this field is not authoritative; we only include it here as a
// hint to avoid unnecessary waits for V2TransactionID payloads.
TransactionID string
// the number of joined users in this room. Use this value rather than working it out yourself, as
// you may get it wrong due to Synapse sending duplicate join events(!). This value has them de-duped
@ -118,13 +127,7 @@ func (c *GlobalCache) copyRoom(roomID string) *internal.RoomMetadata {
logger.Warn().Str("room", roomID).Msg("GlobalCache.LoadRoom: no metadata for this room, returning stub")
return internal.NewRoomMetadata(roomID)
}
srCopy := *sr
// copy the heroes or else we may modify the same slice which would be bad :(
srCopy.Heroes = make([]internal.Hero, len(sr.Heroes))
for i := range sr.Heroes {
srCopy.Heroes[i] = sr.Heroes[i]
}
return &srCopy
return sr.CopyHeroes()
}
// LoadJoinedRooms loads all current joined room metadata for the user given, together
@ -158,7 +161,7 @@ func (c *GlobalCache) LoadJoinedRooms(ctx context.Context, userID string) (
i++
}
latestNIDs, err = c.store.EventsTable.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
latestNIDs, err = c.store.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
if err != nil {
return 0, nil, nil, nil, err
}
@ -239,8 +242,13 @@ func (c *GlobalCache) Startup(roomIDToMetadata map[string]internal.RoomMetadata)
sort.Strings(roomIDs)
for _, roomID := range roomIDs {
metadata := roomIDToMetadata[roomID]
internal.Assert("room ID is set", metadata.RoomID != "")
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1)
debugContext := map[string]interface{}{
"room_id": roomID,
"metadata.RoomID": metadata.RoomID,
"metadata.LastMessageTimeStamp": metadata.LastMessageTimestamp,
}
internal.Assert("room ID is set", metadata.RoomID != "", debugContext)
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1, debugContext)
c.roomIDToMetadata[roomID] = &metadata
}
return nil
@ -285,6 +293,10 @@ func (c *GlobalCache) OnNewEvent(
if ed.StateKey != nil && *ed.StateKey == "" {
metadata.NameEvent = ed.Content.Get("name").Str
}
case "m.room.avatar":
if ed.StateKey != nil && *ed.StateKey == "" {
metadata.AvatarEvent = ed.Content.Get("url").Str
}
case "m.room.encryption":
if ed.StateKey != nil && *ed.StateKey == "" {
metadata.Encrypted = true
@ -349,14 +361,16 @@ func (c *GlobalCache) OnNewEvent(
for i := range metadata.Heroes {
if metadata.Heroes[i].ID == *ed.StateKey {
metadata.Heroes[i].Name = ed.Content.Get("displayname").Str
metadata.Heroes[i].Avatar = ed.Content.Get("avatar_url").Str
found = true
break
}
}
if !found {
metadata.Heroes = append(metadata.Heroes, internal.Hero{
ID: *ed.StateKey,
Name: ed.Content.Get("displayname").Str,
ID: *ed.StateKey,
Name: ed.Content.Get("displayname").Str,
Avatar: ed.Content.Get("avatar_url").Str,
})
}
}

View File

@ -38,12 +38,12 @@ func TestGlobalCacheLoadState(t *testing.T) {
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
}
_, _, err := store.Accumulate(roomID2, "", eventsRoom2)
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}
_, latestNIDs, err := store.Accumulate(roomID, "", events)
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
if err != nil {
t.Fatalf("Accumulate: %s", err)
}

View File

@ -38,15 +38,6 @@ func (u *InviteUpdate) Type() string {
return fmt.Sprintf("InviteUpdate[%s]", u.RoomID())
}
// LeftRoomUpdate corresponds to a key-value pair from a v2 sync's `leave` section.
type LeftRoomUpdate struct {
RoomUpdate
}
func (u *LeftRoomUpdate) Type() string {
return fmt.Sprintf("LeftRoomUpdate[%s]", u.RoomID())
}
// TypingUpdate corresponds to a typing EDU in the `ephemeral` section of a joined room's v2 sync response.
type TypingUpdate struct {
RoomUpdate

View File

@ -46,10 +46,18 @@ type UserRoomData struct {
// The zero value of this safe to use (0 latest nid, no prev batch, no timeline).
RequestedLatestEvents state.LatestEvents
// TODO: should Canonicalised really be in RoomConMetadata? It's only set in SetRoom AFAICS
// TODO: should CanonicalisedName really be in RoomConMetadata? It's only set in SetRoom AFAICS
CanonicalisedName string // stripped leading symbols like #, all in lower case
// Set of spaces this room is a part of, from the perspective of this user. This is NOT global room data
// as the set of spaces may be different for different users.
// ResolvedAvatarURL is the avatar that should be displayed to this user to
// represent this room. The empty string means that this room has no avatar.
// Avatars set in m.room.avatar take precedence; if this is missing and the room is
// a DM with one other user joined or invited, we fall back to that user's
// avatar (if any) as specified in their membership event in that room.
ResolvedAvatarURL string
Spaces map[string]struct{}
// Map of tag to order float.
// See https://spec.matrix.org/latest/client-server-api/#room-tagging
@ -73,6 +81,7 @@ type InviteData struct {
Heroes []internal.Hero
InviteEvent *EventData
NameEvent string // the content of m.room.name, NOT the calculated name
AvatarEvent string // the content of m.room.avatar, NOT the calculated avatar
CanonicalAlias string
LastMessageTimestamp uint64
Encrypted bool
@ -108,12 +117,15 @@ func NewInviteData(ctx context.Context, userID, roomID string, inviteState []jso
id.IsDM = j.Get("is_direct").Bool()
} else if target == j.Get("sender").Str {
id.Heroes = append(id.Heroes, internal.Hero{
ID: target,
Name: j.Get("content.displayname").Str,
ID: target,
Name: j.Get("content.displayname").Str,
Avatar: j.Get("content.avatar_url").Str,
})
}
case "m.room.name":
id.NameEvent = j.Get("content.name").Str
case "m.room.avatar":
id.AvatarEvent = j.Get("content.url").Str
case "m.room.canonical_alias":
id.CanonicalAlias = j.Get("content.alias").Str
case "m.room.encryption":
@ -147,6 +159,7 @@ func (i *InviteData) RoomMetadata() *internal.RoomMetadata {
metadata := internal.NewRoomMetadata(i.roomID)
metadata.Heroes = i.Heroes
metadata.NameEvent = i.NameEvent
metadata.AvatarEvent = i.AvatarEvent
metadata.CanonicalAlias = i.CanonicalAlias
metadata.InviteCount = 1
metadata.JoinCount = 1
@ -178,18 +191,22 @@ type UserCache struct {
store *state.Storage
globalCache *GlobalCache
txnIDs TransactionIDFetcher
ignoredUsers map[string]struct{}
ignoredUsersMu *sync.RWMutex
}
func NewUserCache(userID string, globalCache *GlobalCache, store *state.Storage, txnIDs TransactionIDFetcher) *UserCache {
uc := &UserCache{
UserID: userID,
roomToDataMu: &sync.RWMutex{},
roomToData: make(map[string]UserRoomData),
listeners: make(map[int]UserCacheListener),
listenersMu: &sync.RWMutex{},
store: store,
globalCache: globalCache,
txnIDs: txnIDs,
UserID: userID,
roomToDataMu: &sync.RWMutex{},
roomToData: make(map[string]UserRoomData),
listeners: make(map[int]UserCacheListener),
listenersMu: &sync.RWMutex{},
store: store,
globalCache: globalCache,
txnIDs: txnIDs,
ignoredUsers: make(map[string]struct{}),
ignoredUsersMu: &sync.RWMutex{},
}
return uc
}
@ -212,7 +229,7 @@ func (c *UserCache) Unsubscribe(id int) {
// OnRegistered is called after the sync3.Dispatcher has successfully registered this
// cache to receive updates. We use this to run some final initialisation logic that
// is sensitive to race conditions; confusingly, most of the initialisation is driven
// externally by sync3.SyncLiveHandler.userCache. It's importatn that we don't spend too
// externally by sync3.SyncLiveHandler.userCaches. It's important that we don't spend too
// long inside this function, because it is called within a global lock on the
// sync3.Dispatcher (see sync3.Dispatcher.Register).
func (c *UserCache) OnRegistered(ctx context.Context) error {
@ -309,6 +326,7 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID
urd = NewUserRoomData()
}
if latestEvents != nil {
latestEvents.DiscardIgnoredMessages(c.ShouldIgnore)
urd.RequestedLatestEvents = *latestEvents
}
result[requestedRoomID] = urd
@ -328,7 +346,10 @@ func (c *UserCache) LoadRoomData(roomID string) UserRoomData {
}
type roomUpdateCache struct {
roomID string
roomID string
// globalRoomData is a snapshot of the global metadata for this room immediately
// after this update. It is a copy, specific to the given user whose Heroes
// field can be freely modified.
globalRoomData *internal.RoomMetadata
userRoomData *UserRoomData
}
@ -393,8 +414,16 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin
i int
})
for roomID, events := range roomIDToEvents {
for i, ev := range events {
evID := gjson.GetBytes(ev, "event_id").Str
for i, evJSON := range events {
ev := gjson.ParseBytes(evJSON)
evID := ev.Get("event_id").Str
sender := ev.Get("sender").Str
if sender != userID {
// don't ask for txn IDs for events which weren't sent by us.
// If we do, we'll needlessly hit the database, increasing latencies when
// catching up from the live buffer.
continue
}
eventIDs = append(eventIDs, evID)
eventIDToEvent[evID] = struct {
roomID string
@ -405,6 +434,10 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin
}
}
}
if len(eventIDs) == 0 {
// don't do any work if we have no events
return roomIDToEvents
}
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs)
for eventID, txnID := range eventIDToTxnID {
data, ok := eventIDToEvent[eventID]
@ -578,7 +611,7 @@ func (c *UserCache) OnInvite(ctx context.Context, roomID string, inviteStateEven
c.emitOnRoomUpdate(ctx, up)
}
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string, leaveEvent json.RawMessage) {
urd := c.LoadRoomData(roomID)
urd.IsInvite = false
urd.HasLeft = true
@ -588,7 +621,10 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
c.roomToData[roomID] = urd
c.roomToDataMu.Unlock()
up := &LeftRoomUpdate{
ev := gjson.ParseBytes(leaveEvent)
stateKey := ev.Get("state_key").Str
up := &RoomEventUpdate{
RoomUpdate: &roomUpdateCache{
roomID: roomID,
// do NOT pull from the global cache as it is a snapshot of the room at the point of
@ -596,6 +632,18 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
globalRoomData: internal.NewRoomMetadata(roomID),
userRoomData: &urd,
},
EventData: &EventData{
Event: leaveEvent,
RoomID: roomID,
EventType: ev.Get("type").Str,
StateKey: &stateKey,
Content: ev.Get("content"),
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
// if this is an invite rejection we need to make sure we tell the client, and not
// skip it because of the lack of a NID (this event may not be in the events table)
AlwaysProcess: true,
},
}
c.emitOnRoomUpdate(ctx, up)
}
@ -608,7 +656,8 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
up := roomUpdates[d.RoomID]
up = append(up, d)
roomUpdates[d.RoomID] = up
if d.Type == "m.direct" {
switch d.Type {
case "m.direct":
dmRoomSet := make(map[string]struct{})
// pull out rooms and mark them as DMs
content := gjson.ParseBytes(d.Data).Get("content")
@ -633,7 +682,7 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
c.roomToData[dmRoomID] = u
}
c.roomToDataMu.Unlock()
} else if d.Type == "m.tag" {
case "m.tag":
content := gjson.ParseBytes(d.Data).Get("content.tags")
if tagUpdates[d.RoomID] == nil {
tagUpdates[d.RoomID] = make(map[string]float64)
@ -642,6 +691,22 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
tagUpdates[d.RoomID][k.Str] = v.Get("order").Float()
return true
})
case "m.ignored_user_list":
if d.RoomID != state.AccountDataGlobalRoom {
continue
}
content := gjson.ParseBytes(d.Data).Get("content.ignored_users")
if !content.IsObject() {
continue
}
ignoredUsers := make(map[string]struct{})
content.ForEach(func(k, v gjson.Result) bool {
ignoredUsers[k.Str] = struct{}{}
return true
})
c.ignoredUsersMu.Lock()
c.ignoredUsers = ignoredUsers
c.ignoredUsersMu.Unlock()
}
}
if len(tagUpdates) > 0 {
@ -674,3 +739,10 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
}
}
func (u *UserCache) ShouldIgnore(userID string) bool {
u.ignoredUsersMu.RLock()
defer u.ignoredUsersMu.RUnlock()
_, ignored := u.ignoredUsers[userID]
return ignored
}
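// Illustrative only (not part of this change): the shape of the m.ignored_user_list
// global account data event that the handling above expects. The user IDs are
// assumptions for illustration.
const exampleIgnoredUserListAccountData = `{
	"type": "m.ignored_user_list",
	"content": {
		"ignored_users": {
			"@spammer:example.org": {},
			"@troll:example.org": {}
		}
	}
}`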

View File

@ -83,8 +83,8 @@ func TestAnnotateWithTransactionIDs(t *testing.T) {
data: tc.eventIDToTxnIDs,
}
uc := caches.NewUserCache(userID, nil, nil, fetcher)
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
want := convertIDTxnToEventStub(tc.wantRoomIDToEvents)
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(userID, tc.roomIDToEvents))
want := convertIDTxnToEventStub(userID, tc.wantRoomIDToEvents)
if !reflect.DeepEqual(got, want) {
t.Errorf("%s : got %v want %v", tc.name, js(got), js(want))
}
@ -96,27 +96,27 @@ func js(in interface{}) string {
return string(b)
}
func convertIDToEventStub(roomToEventIDs map[string][]string) map[string][]json.RawMessage {
func convertIDToEventStub(sender string, roomToEventIDs map[string][]string) map[string][]json.RawMessage {
result := make(map[string][]json.RawMessage)
for roomID, eventIDs := range roomToEventIDs {
events := make([]json.RawMessage, len(eventIDs))
for i := range eventIDs {
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i]))
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i], sender))
}
result[roomID] = events
}
return result
}
func convertIDTxnToEventStub(roomToEventIDs map[string][][2]string) map[string][]json.RawMessage {
func convertIDTxnToEventStub(sender string, roomToEventIDs map[string][][2]string) map[string][]json.RawMessage {
result := make(map[string][]json.RawMessage)
for roomID, eventIDs := range roomToEventIDs {
events := make([]json.RawMessage, len(eventIDs))
for i := range eventIDs {
if eventIDs[i][1] == "" {
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i][0]))
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i][0], sender))
} else {
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], eventIDs[i][1]))
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], sender, eventIDs[i][1]))
}
}
result[roomID] = events

View File

@ -14,7 +14,7 @@ import (
// The amount of time to artificially wait if the server detects spamming clients. This time will
// be added to responses when the server detects the same request being sent repeatedly, e.g
// /sync?pos=5 followed by another /sync?pos=5. Likewise /sync without a ?pos=.
var SpamProtectionInterval = time.Second
var SpamProtectionInterval = 10 * time.Millisecond
type ConnID struct {
UserID string
@ -30,8 +30,9 @@ type ConnHandler interface {
// Callback which is allowed to block as long as the context is active. Return the response
// to send back or an error. Errors of type *internal.HandlerError are inspected for the correct
// status code to send back.
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error)
OnUpdate(ctx context.Context, update caches.Update)
PublishEventsUpTo(roomID string, nid int64)
Destroy()
Alive() bool
}
@ -88,7 +89,7 @@ func (c *Conn) OnUpdate(ctx context.Context, update caches.Update) {
// upwards but will NOT be logged to Sentry (neither here nor by the caller). Errors
// should be reported to Sentry as close as possible to the point of creating the error,
// to provide the best possible Sentry traceback.
func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) {
func (c *Conn) tryRequest(ctx context.Context, req *Request, start time.Time) (res *Response, err error) {
// TODO: include useful information from the request in the sentry hub/context
// Might be better done in the caller though?
defer func() {
@ -116,7 +117,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
ctx, task := internal.StartTask(ctx, taskType)
defer task.End()
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0, start)
}
func (c *Conn) isOutstanding(pos int64) bool {
@ -132,7 +133,7 @@ func (c *Conn) isOutstanding(pos int64) bool {
// If an error is returned, it will be logged by the caller and transmitted to the
// client. It will NOT be reported to Sentry---this should happen as close as possible
// to the creation of the error (or else Sentry cannot provide a meaningful traceback.)
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.Time) (resp *Response, herr *internal.HandlerError) {
c.cancelOutstandingRequestMu.Lock()
if c.cancelOutstandingRequest != nil {
c.cancelOutstandingRequest()
@ -217,7 +218,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
req.SetTimeoutMSecs(1)
}
resp, err := c.tryRequest(ctx, req)
resp, err := c.tryRequest(ctx, req, start)
if err != nil {
herr, ok := err.(*internal.HandlerError)
if !ok {

View File

@ -16,7 +16,7 @@ type connHandlerMock struct {
fn func(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
}
func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool) (*Response, error) {
func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool, start time.Time) (*Response, error) {
return c.fn(ctx, cid, req, init)
}
func (c *connHandlerMock) UserID() string {
@ -25,6 +25,7 @@ func (c *connHandlerMock) UserID() string {
func (c *connHandlerMock) Destroy() {}
func (c *connHandlerMock) Alive() bool { return true }
func (c *connHandlerMock) OnUpdate(ctx context.Context, update caches.Update) {}
func (c *connHandlerMock) PublishEventsUpTo(roomID string, nid int64) {}
// Test that Conn can send and receive requests based on positions
func TestConn(t *testing.T) {
@ -47,7 +48,7 @@ func TestConn(t *testing.T) {
// initial request
resp, err := c.OnIncomingRequest(ctx, &Request{
pos: 0,
})
}, time.Now())
assertNoError(t, err)
assertPos(t, resp.Pos, 1)
assertInt(t, resp.Lists["a"].Count, 101)
@ -55,14 +56,14 @@ func TestConn(t *testing.T) {
// happy case, pos=1
resp, err = c.OnIncomingRequest(ctx, &Request{
pos: 1,
})
}, time.Now())
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 102)
assertNoError(t, err)
// bogus position returns a 400
_, err = c.OnIncomingRequest(ctx, &Request{
pos: 31415,
})
}, time.Now())
if err == nil {
t.Fatalf("expected error, got none")
}
@ -106,7 +107,7 @@ func TestConnBlocking(t *testing.T) {
Sort: []string{"hi"},
},
},
})
}, time.Now())
}()
go func() {
defer wg.Done()
@ -118,7 +119,7 @@ func TestConnBlocking(t *testing.T) {
Sort: []string{"hi2"},
},
},
})
}, time.Now())
}()
go func() {
wg.Wait()
@ -148,18 +149,18 @@ func TestConnRetries(t *testing.T) {
},
}}, nil
}})
resp, err := c.OnIncomingRequest(ctx, &Request{})
resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now())
assertPos(t, resp.Pos, 1)
assertInt(t, resp.Lists["a"].Count, 20)
assertInt(t, callCount, 1)
assertNoError(t, err)
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 20)
assertInt(t, callCount, 2)
assertNoError(t, err)
// retry! Shouldn't invoke handler again
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 20)
assertInt(t, callCount, 2) // this doesn't increment
@ -170,7 +171,7 @@ func TestConnRetries(t *testing.T) {
"a": {
Sort: []string{SortByName},
},
}})
}}, time.Now())
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 20)
assertInt(t, callCount, 3) // this doesn't increment
@ -191,25 +192,25 @@ func TestConnBufferRes(t *testing.T) {
},
}}, nil
}})
resp, err := c.OnIncomingRequest(ctx, &Request{})
resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now())
assertNoError(t, err)
assertPos(t, resp.Pos, 1)
assertInt(t, resp.Lists["a"].Count, 1)
assertInt(t, callCount, 1)
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
assertNoError(t, err)
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 2)
assertInt(t, callCount, 2)
// retry with modified request data that shouldn't prompt data to be returned.
// should invoke handler again!
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}})
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}}, time.Now())
assertNoError(t, err)
assertPos(t, resp.Pos, 2)
assertInt(t, resp.Lists["a"].Count, 2)
assertInt(t, callCount, 3) // this DOES increment, the response is buffered and not returned yet.
// retry with same request body, so should NOT invoke handler again and return buffered response
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}})
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}}, time.Now())
assertNoError(t, err)
assertPos(t, resp.Pos, 3)
assertInt(t, resp.Lists["a"].Count, 3)
@ -228,7 +229,7 @@ func TestConnErrors(t *testing.T) {
// random errors = 500
errCh <- errors.New("oops")
_, herr := c.OnIncomingRequest(ctx, &Request{})
_, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now())
if herr.StatusCode != 500 {
t.Fatalf("random errors should be status 500, got %d", herr.StatusCode)
}
@ -237,7 +238,7 @@ func TestConnErrors(t *testing.T) {
StatusCode: 400,
Err: errors.New("no way!"),
}
_, herr = c.OnIncomingRequest(ctx, &Request{})
_, herr = c.OnIncomingRequest(ctx, &Request{}, time.Now())
if herr.StatusCode != 400 {
t.Fatalf("expected status 400, got %d", herr.StatusCode)
}
@ -258,7 +259,7 @@ func TestConnErrorsNoCache(t *testing.T) {
}
}})
// errors should not be cached
resp, herr := c.OnIncomingRequest(ctx, &Request{})
resp, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now())
if herr != nil {
t.Fatalf("expected no error, got %+v", herr)
}
@ -267,12 +268,12 @@ func TestConnErrorsNoCache(t *testing.T) {
StatusCode: 400,
Err: errors.New("no way!"),
}
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()})
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now())
if herr.StatusCode != 400 {
t.Fatalf("expected status 400, got %d", herr.StatusCode)
}
// but doing the exact same request should now work
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()})
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now())
if herr != nil {
t.Fatalf("expected no error, got %+v", herr)
}
@ -361,7 +362,7 @@ func TestConnBufferRememberInflight(t *testing.T) {
var err *internal.HandlerError
for i, step := range steps {
t.Logf("Executing step %d", i)
resp, err = c.OnIncomingRequest(ctx, step.req)
resp, err = c.OnIncomingRequest(ctx, step.req, time.Now())
if !step.wantErr {
assertNoError(t, err)
}

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/ReneKroon/ttlcache/v2"
"github.com/prometheus/client_golang/prometheus"
)
// ConnMap stores a collection of Conns.
@ -15,10 +16,15 @@ type ConnMap struct {
userIDToConn map[string][]*Conn
connIDToConn map[string]*Conn
numConns prometheus.Gauge
// counters for reasons why connections have expired
expiryTimedOutCounter prometheus.Counter
expiryBufferFullCounter prometheus.Counter
mu *sync.Mutex
}
func NewConnMap() *ConnMap {
func NewConnMap(enablePrometheus bool) *ConnMap {
cm := &ConnMap{
userIDToConn: make(map[string][]*Conn),
connIDToConn: make(map[string]*Conn),
@ -27,17 +33,61 @@ func NewConnMap() *ConnMap {
}
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
cm.cache.SetExpirationCallback(cm.closeConnExpires)
if enablePrometheus {
cm.expiryTimedOutCounter = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "expiry_conn_timed_out",
Help: "Counter of expired API connections due to reaching TTL limit",
})
prometheus.MustRegister(cm.expiryTimedOutCounter)
cm.expiryBufferFullCounter = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "expiry_conn_buffer_full",
Help: "Counter of expired API connections due to reaching buffer update limit",
})
prometheus.MustRegister(cm.expiryBufferFullCounter)
cm.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "num_active_conns",
Help: "Number of active sliding sync connections.",
})
prometheus.MustRegister(cm.numConns)
}
return cm
}
func (m *ConnMap) Teardown() {
m.cache.Close()
if m.numConns != nil {
prometheus.Unregister(m.numConns)
}
if m.expiryBufferFullCounter != nil {
prometheus.Unregister(m.expiryBufferFullCounter)
}
if m.expiryTimedOutCounter != nil {
prometheus.Unregister(m.expiryTimedOutCounter)
}
}
func (m *ConnMap) Len() int {
// UpdateMetrics recalculates the number of active connections. Do this when you think there is a change.
func (m *ConnMap) UpdateMetrics() {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.connIDToConn)
m.updateMetrics(len(m.connIDToConn))
}
// updateMetrics is like UpdateMetrics but doesn't touch connIDToConn and hence doesn't need to lock. We use this internally
// when we need to update the metric and already have the lock held, as calling UpdateMetrics would deadlock.
func (m *ConnMap) updateMetrics(numConns int) {
if m.numConns == nil {
return
}
m.numConns.Set(float64(numConns))
}
// Conns return all connections for this user|device
@ -55,6 +105,14 @@ func (m *ConnMap) Conns(userID, deviceID string) []*Conn {
// Conn returns a connection with this ConnID. Returns nil if no connection exists.
func (m *ConnMap) Conn(cid ConnID) *Conn {
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(cid)
}
// getConn returns a connection with this ConnID. Returns nil if no connection exists. Expires connections if the buffer is full.
// Must hold mu.
func (m *ConnMap) getConn(cid ConnID) *Conn {
cint, _ := m.cache.Get(cid.String())
if cint == nil {
return nil
@ -64,8 +122,11 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
return conn
}
// e.g buffer exceeded, close it and remove it from the cache
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
logger.Info().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
m.closeConn(conn)
if m.expiryBufferFullCounter != nil {
m.expiryBufferFullCounter.Inc()
}
return nil
}
@ -74,7 +135,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
// atomically check if a conn exists already and nuke it if it exists
m.mu.Lock()
defer m.mu.Unlock()
conn := m.Conn(cid)
conn := m.getConn(cid)
if conn != nil {
// tear down this connection and fallthrough
isSpamming := conn.lastPos <= 1
@ -92,6 +153,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
m.updateMetrics(len(m.connIDToConn))
return conn, true
}
@ -121,7 +183,10 @@ func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
conn := value.(*Conn)
logger.Trace().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
logger.Info().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
if m.expiryTimedOutCounter != nil {
m.expiryTimedOutCounter.Inc()
}
m.closeConn(conn)
}
@ -147,4 +212,14 @@ func (m *ConnMap) closeConn(conn *Conn) {
m.userIDToConn[conn.UserID] = conns
// remove user cache listeners etc
h.Destroy()
m.updateMetrics(len(m.connIDToConn))
}
func (m *ConnMap) ClearUpdateQueues(userID, roomID string, nid int64) {
m.mu.Lock()
defer m.mu.Unlock()
for _, conn := range m.userIDToConn[userID] {
conn.handler.PublishEventsUpTo(roomID, nid)
}
}
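The metrics wiring above follows a register-only-when-enabled pattern: the counters and gauge stay nil unless enablePrometheus is set, every use is nil-guarded, and Teardown unregisters them so repeated construction (as happens in tests) cannot trip prometheus' duplicate-registration panic. A minimal, self-contained sketch of that pattern, using illustrative names rather than the repository's own types:

package metricsdemo

import "github.com/prometheus/client_golang/prometheus"

// ConnTracker sketches the optional-metrics pattern: the counter is nil when
// metrics are disabled, every use is nil-guarded, and Teardown unregisters it.
type ConnTracker struct {
	expiredTimedOut prometheus.Counter // nil when metrics are disabled
}

func NewConnTracker(enablePrometheus bool) *ConnTracker {
	t := &ConnTracker{}
	if enablePrometheus {
		t.expiredTimedOut = prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "sliding_sync",
			Subsystem: "api",
			Name:      "expiry_conn_timed_out",
			Help:      "Counter of expired API connections due to reaching TTL limit",
		})
		prometheus.MustRegister(t.expiredTimedOut)
	}
	return t
}

// OnExpired increments the counter, if metrics are enabled.
func (t *ConnTracker) OnExpired() {
	if t.expiredTimedOut != nil {
		t.expiredTimedOut.Inc()
	}
}

// Teardown unregisters the collector so a fresh ConnTracker can be built
// (e.g. between test cases) without a duplicate-registration panic.
func (t *ConnTracker) Teardown() {
	if t.expiredTimedOut != nil {
		prometheus.Unregister(t.expiredTimedOut)
	}
}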

View File

@ -87,14 +87,15 @@ func (d *Dispatcher) newEventData(event json.RawMessage, roomID string, latestPo
eventType := ev.Get("type").Str
return &caches.EventData{
Event: event,
RoomID: roomID,
EventType: eventType,
StateKey: stateKey,
Content: ev.Get("content"),
NID: latestPos,
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
Event: event,
RoomID: roomID,
EventType: eventType,
StateKey: stateKey,
Content: ev.Get("content"),
NID: latestPos,
Timestamp: ev.Get("origin_server_ts").Uint(),
Sender: ev.Get("sender").Str,
TransactionID: ev.Get("unsigned.transaction_id").Str,
}
}

View File

@ -42,7 +42,8 @@ type ConnState struct {
// roomID -> latest load pos
loadPositions map[string]int64
live *connStateLive
txnIDWaiter *TxnIDWaiter
live *connStateLive
globalCache *caches.GlobalCache
userCache *caches.UserCache
@ -52,13 +53,14 @@ type ConnState struct {
joinChecker JoinChecker
extensionsHandler extensions.HandlerInterface
setupHistogramVec *prometheus.HistogramVec
processHistogramVec *prometheus.HistogramVec
}
func NewConnState(
userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache,
ex extensions.HandlerInterface, joinChecker JoinChecker, histVec *prometheus.HistogramVec,
maxPendingEventUpdates int,
ex extensions.HandlerInterface, joinChecker JoinChecker, setupHistVec *prometheus.HistogramVec, histVec *prometheus.HistogramVec,
maxPendingEventUpdates int, maxTransactionIDDelay time.Duration,
) *ConnState {
cs := &ConnState{
globalCache: globalCache,
@ -72,12 +74,20 @@ func NewConnState(
extensionsHandler: ex,
joinChecker: joinChecker,
lazyCache: NewLazyCache(),
setupHistogramVec: setupHistVec,
processHistogramVec: histVec,
}
cs.live = &connStateLive{
ConnState: cs,
updates: make(chan caches.Update, maxPendingEventUpdates),
}
cs.txnIDWaiter = NewTxnIDWaiter(
userID,
maxTransactionIDDelay,
func(delayed bool, update caches.Update) {
cs.live.onUpdate(update)
},
)
// subscribe for updates before loading. We risk seeing dupes but that's fine as load positions
// will stop us double-processing.
cs.userCacheID = cs.userCache.Subsribe(cs)
@ -160,7 +170,7 @@ func (s *ConnState) load(ctx context.Context, req *sync3.Request) error {
}
// OnIncomingRequest is guaranteed to be called sequentially (it's protected by a mutex in conn.go)
func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool, start time.Time) (*sync3.Response, error) {
if s.anchorLoadPosition <= 0 {
// load() needs no ctx so drop it
_, region := internal.StartSpan(ctx, "load")
@ -172,45 +182,50 @@ func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req
}
region.End()
}
setupTime := time.Since(start)
s.trackSetupDuration(setupTime, isInitial)
return s.onIncomingRequest(ctx, req, isInitial)
}
// onIncomingRequest is a callback which fires when the client makes a request to the server. Whilst each request may
// be on their own goroutine, the requests are linearised for us by Conn so it is safe to modify ConnState without
// additional locking mechanisms.
func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
start := time.Now()
// ApplyDelta works fine if s.muxedReq is nil
var delta *sync3.RequestDelta
s.muxedReq, delta = s.muxedReq.ApplyDelta(req)
internal.Logf(ctx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
internal.Logf(reqCtx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
for key, l := range delta.Lists {
listData := ""
if l.Curr != nil {
listDataBytes, _ := json.Marshal(l.Curr)
listData = string(listDataBytes)
}
internal.Logf(ctx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
internal.Logf(reqCtx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
}
for roomID, sub := range s.muxedReq.RoomSubscriptions {
internal.Logf(reqCtx, "connstate", "room sub[%v] %v", roomID, sub)
}
// work out which rooms we'll return data for and add their relevant subscriptions to the builder
// for it to mix together
builder := NewRoomsBuilder()
// works out which rooms are subscribed to but doesn't pull room data
s.buildRoomSubscriptions(ctx, builder, delta.Subs, delta.Unsubs)
s.buildRoomSubscriptions(reqCtx, builder, delta.Subs, delta.Unsubs)
// works out how rooms get moved about but doesn't pull room data
respLists := s.buildListSubscriptions(ctx, builder, delta.Lists)
respLists := s.buildListSubscriptions(reqCtx, builder, delta.Lists)
// pull room data and set changes on the response
response := &sync3.Response{
Rooms: s.buildRooms(ctx, builder.BuildSubscriptions()), // pull room data
Rooms: s.buildRooms(reqCtx, builder.BuildSubscriptions()), // pull room data
Lists: respLists,
}
// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
// is being notified about (e.g. for room account data)
ctx, region := internal.StartSpan(ctx, "extensions")
response.Extensions = s.extensionsHandler.Handle(ctx, s.muxedReq.Extensions, extensions.Context{
extCtx, region := internal.StartSpan(reqCtx, "extensions")
response.Extensions = s.extensionsHandler.Handle(extCtx, s.muxedReq.Extensions, extensions.Context{
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
@ -231,8 +246,8 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
}
// do live tracking if we have nothing to tell the client yet
ctx, region = internal.StartSpan(ctx, "liveUpdate")
s.live.liveUpdate(ctx, req, s.muxedReq.Extensions, isInitial, response)
updateCtx, region := internal.StartSpan(reqCtx, "liveUpdate")
s.live.liveUpdate(updateCtx, req, s.muxedReq.Extensions, isInitial, response)
region.End()
// counts are AFTER events are applied, hence after liveUpdate
@ -245,7 +260,7 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
// Add membership events for users sending typing notifications. We do this after live update
// and initial room loading code so we LL room members in all cases.
if response.Extensions.Typing != nil && response.Extensions.Typing.HasData(isInitial) {
s.lazyLoadTypingMembers(ctx, response)
s.lazyLoadTypingMembers(reqCtx, response)
}
return response, nil
}
@ -453,6 +468,12 @@ func (s *ConnState) buildRooms(ctx context.Context, builtSubs []BuiltSubscriptio
ctx, span := internal.StartSpan(ctx, "buildRooms")
defer span.End()
result := make(map[string]sync3.Room)
var bumpEventTypes []string
for _, x := range s.muxedReq.Lists {
bumpEventTypes = append(bumpEventTypes, x.BumpEventTypes...)
}
for _, bs := range builtSubs {
roomIDs := bs.RoomIDs
if bs.RoomSubscription.IncludeOldRooms != nil {
@ -477,14 +498,23 @@ func (s *ConnState) buildRooms(ctx context.Context, builtSubs []BuiltSubscriptio
}
}
}
// old rooms use a different subscription
oldRooms := s.getInitialRoomData(ctx, *bs.RoomSubscription.IncludeOldRooms, oldRoomIDs...)
for oldRoomID, oldRoom := range oldRooms {
result[oldRoomID] = oldRoom
// If we have old rooms to fetch, do so.
if len(oldRoomIDs) > 0 {
// old rooms use a different subscription
oldRooms := s.getInitialRoomData(ctx, *bs.RoomSubscription.IncludeOldRooms, bumpEventTypes, oldRoomIDs...)
for oldRoomID, oldRoom := range oldRooms {
result[oldRoomID] = oldRoom
}
}
}
rooms := s.getInitialRoomData(ctx, bs.RoomSubscription, roomIDs...)
// There won't be anything to fetch, try the next subscription.
if len(roomIDs) == 0 {
continue
}
rooms := s.getInitialRoomData(ctx, bs.RoomSubscription, bumpEventTypes, roomIDs...)
for roomID, room := range rooms {
result[roomID] = room
}
@ -521,7 +551,7 @@ func (s *ConnState) lazyLoadTypingMembers(ctx context.Context, response *sync3.R
}
}
func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSubscription, roomIDs ...string) map[string]sync3.Room {
func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSubscription, bumpEventTypes []string, roomIDs ...string) map[string]sync3.Room {
ctx, span := internal.StartSpan(ctx, "getInitialRoomData")
defer span.End()
rooms := make(map[string]sync3.Room, len(roomIDs))
@ -590,8 +620,40 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
requiredState = make([]json.RawMessage, 0)
}
}
// Get the highest timestamp, determined by bumpEventTypes,
// for this room
roomListsMeta := s.lists.ReadOnlyRoom(roomID)
var maxTs uint64
for _, t := range bumpEventTypes {
if roomListsMeta == nil {
break
}
evMeta := roomListsMeta.LatestEventsByType[t]
if evMeta.Timestamp > maxTs {
maxTs = evMeta.Timestamp
}
}
// If we didn't find any events which would update the timestamp,
// use the join event timestamp instead. Also don't leak a
// timestamp from before we joined.
if roomListsMeta != nil && (maxTs == 0 || maxTs < roomListsMeta.JoinTiming.Timestamp) {
maxTs = roomListsMeta.JoinTiming.Timestamp
// If no bumpEventTypes are specified, use the
// LastMessageTimestamp so clients are still able
// to correctly sort on it.
if len(bumpEventTypes) == 0 {
maxTs = roomListsMeta.LastMessageTimestamp
}
}
rooms[roomID] = sync3.Room{
Name: internal.CalculateRoomName(metadata, 5), // TODO: customisable?
AvatarChange: sync3.NewAvatarChange(internal.CalculateAvatar(metadata)),
NotificationCount: int64(userRoomData.NotificationCount),
HighlightCount: int64(userRoomData.HighlightCount),
Timeline: roomToTimeline[roomID],
@ -600,8 +662,9 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
Initial: true,
IsDM: userRoomData.IsDM,
JoinedCount: metadata.JoinCount,
InvitedCount: metadata.InviteCount,
InvitedCount: &metadata.InviteCount,
PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch,
Timestamp: maxTs,
}
}
@ -613,6 +676,17 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
return rooms
}
func (s *ConnState) trackSetupDuration(dur time.Duration, isInitial bool) {
if s.setupHistogramVec == nil {
return
}
val := "0"
if isInitial {
val = "1"
}
s.setupHistogramVec.WithLabelValues(val).Observe(float64(dur.Seconds()))
}
func (s *ConnState) trackProcessDuration(dur time.Duration, isInitial bool) {
if s.processHistogramVec == nil {
return
@ -638,7 +712,8 @@ func (s *ConnState) UserID() string {
}
func (s *ConnState) OnUpdate(ctx context.Context, up caches.Update) {
s.live.onUpdate(up)
// will eventually call s.live.onUpdate
s.txnIDWaiter.Ingest(up)
}
// Called by the user cache when updates arrive
@ -654,15 +729,19 @@ func (s *ConnState) OnRoomUpdate(ctx context.Context, up caches.RoomUpdate) {
}
internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil)
internal.Logf(ctx, "connstate", "queued update %d", update.EventData.NID)
s.live.onUpdate(update)
s.OnUpdate(ctx, update)
case caches.RoomUpdate:
internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil)
s.live.onUpdate(update)
s.OnUpdate(ctx, update)
default:
logger.Warn().Str("room_id", up.RoomID()).Msg("OnRoomUpdate unknown update type")
}
}
func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) {
s.txnIDWaiter.PublishUpToNID(roomID, nid)
}
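The per-room timestamp computed in getInitialRoomData above (and recomputed for live updates in connstate_live.go below) follows a single rule that is easier to see as a pure function. A sketch under simplified, hypothetical types — roomMeta is a stand-in, not the repository's metadata struct:

package main

import "fmt"

// roomMeta stands in for the metadata fields the timestamp rule reads.
type roomMeta struct {
	latestByType         map[string]uint64 // event type -> origin_server_ts of latest such event
	joinTimestamp        uint64
	lastMessageTimestamp uint64
}

// bumpTimestamp sketches the rule in this diff: take the newest event whose
// type is in bumpEventTypes; if none qualify (or the best candidate predates
// our join), fall back to the join timestamp; and if the client supplied no
// bump_event_types at all, fall back to the last message timestamp so recency
// sorting still works.
func bumpTimestamp(m roomMeta, bumpEventTypes []string) uint64 {
	var maxTs uint64
	for _, t := range bumpEventTypes {
		if ts := m.latestByType[t]; ts > maxTs {
			maxTs = ts
		}
	}
	if maxTs == 0 || maxTs < m.joinTimestamp {
		maxTs = m.joinTimestamp
		if len(bumpEventTypes) == 0 {
			maxTs = m.lastMessageTimestamp
		}
	}
	return maxTs
}

func main() {
	m := roomMeta{
		latestByType:         map[string]uint64{"m.room.message": 300, "m.reaction": 500},
		joinTimestamp:        200,
		lastMessageTimestamp: 450,
	}
	fmt.Println(bumpTimestamp(m, []string{"m.room.message"})) // 300: newest bump event after join
	fmt.Println(bumpTimestamp(m, []string{"m.room.topic"}))   // 200: no matching bump events, use join time
	fmt.Println(bumpTimestamp(m, nil))                        // 450: no bump types, use last message
}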
// clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges.
//
// Suppose the client asks for a window on positions [10, 19]. If the list

View File

@ -37,7 +37,7 @@ func (s *connStateLive) onUpdate(up caches.Update) {
select {
case s.updates <- up:
case <-time.After(BufferWaitTime):
logger.Warn().Interface("update", up).Str("user", s.userID).Msg(
logger.Warn().Interface("update", up).Str("user", s.userID).Str("device", s.deviceID).Msg(
"cannot send update to connection, buffer exceeded. Destroying connection.",
)
s.bufferFull = true
@ -57,6 +57,7 @@ func (s *connStateLive) liveUpdate(
if req.TimeoutMSecs() < 100 {
req.SetTimeoutMSecs(100)
}
startBufferSize := len(s.updates)
// block until we get a new event, with appropriate timeout
startTime := time.Now()
hasLiveStreamed := false
@ -81,43 +82,75 @@ func (s *connStateLive) liveUpdate(
return
case update := <-s.updates:
internal.Logf(ctx, "liveUpdate", "process live update")
s.processLiveUpdate(ctx, update, response)
// pass event to extensions AFTER processing
extCtx := extensions.Context{
IsInitial: false,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDsToLists: s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists),
AllSubscribedRooms: s.muxedReq.SubscribedRoomIDs(),
AllLists: s.muxedReq.ListKeys(),
}
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extCtx)
s.processUpdate(ctx, update, response, ex)
// if there's more updates and we don't have lots stacked up already, go ahead and process another
for len(s.updates) > 0 && response.ListOps() < 50 {
update = <-s.updates
s.processLiveUpdate(ctx, update, response)
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extCtx)
s.processUpdate(ctx, update, response, ex)
}
}
}
// If a client constantly changes their request params in every request they make, we will never consume from
// the update channel as the response will always have data already. In an effort to prevent starvation of new
// data, we will process some updates even though we have data already, but only if A) we didn't live stream
// due to natural circumstances, B) it isn't an initial request and C) there is in fact some data there.
numQueuedUpdates := len(s.updates)
if !hasLiveStreamed && !isInitial && numQueuedUpdates > 0 {
for i := 0; i < numQueuedUpdates; i++ {
update := <-s.updates
s.processUpdate(ctx, update, response, ex)
}
log.Debug().Int("num_queued", numQueuedUpdates).Msg("liveUpdate: caught up")
internal.Logf(ctx, "connstate", "liveUpdate caught up %d updates", numQueuedUpdates)
}
log.Trace().Bool("live_streamed", hasLiveStreamed).Msg("liveUpdate: returning")
internal.SetConnBufferInfo(ctx, startBufferSize, len(s.updates), cap(s.updates))
// TODO: op consolidation
}
func (s *connStateLive) processUpdate(ctx context.Context, update caches.Update, response *sync3.Response, ex extensions.Request) {
internal.Logf(ctx, "liveUpdate", "process live update %s", update.Type())
s.processLiveUpdate(ctx, update, response)
// pass event to extensions AFTER processing
roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists)
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{
IsInitial: false,
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
UserID: s.userID,
DeviceID: s.deviceID,
RoomIDsToLists: roomIDsToLists,
AllSubscribedRooms: s.muxedReq.SubscribedRoomIDs(),
AllLists: s.muxedReq.ListKeys(),
})
}
func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, response *sync3.Response) bool {
internal.AssertWithContext(ctx, "processLiveUpdate: response list length != internal list length", s.lists.Len() == len(response.Lists))
internal.AssertWithContext(ctx, "processLiveUpdate: request list length != internal list length", s.lists.Len() == len(s.muxedReq.Lists))
roomUpdate, _ := up.(caches.RoomUpdate)
roomEventUpdate, _ := up.(*caches.RoomEventUpdate)
// if this is a room event update we may not want to process this if the event nid is < loadPos,
// as that means we have already taken it into account
if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess {
// check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event
// behind what we've processed before?
lp := s.loadPositions[roomEventUpdate.RoomID()]
if lp > 0 && roomEventUpdate.EventData.NID < lp {
if roomEventUpdate != nil {
// if this is a room event update we may not want to process this event, for a few reasons.
if !roomEventUpdate.EventData.AlwaysProcess {
// check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event
// behind what we've processed before?
lp := s.loadPositions[roomEventUpdate.RoomID()]
if lp > 0 && roomEventUpdate.EventData.NID < lp {
return false
}
}
// Skip message events from ignored users.
if roomEventUpdate.EventData.StateKey == nil && s.userCache.ShouldIgnore(roomEventUpdate.EventData.Sender) {
logger.Trace().
Str("user", s.userID).
Str("type", roomEventUpdate.EventData.EventType).
Str("sender", roomEventUpdate.EventData.Sender).
Msg("ignoring event update")
return false
}
}
@ -157,16 +190,42 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
// include this update in the rooms response TODO: filters on event type?
userRoomData := roomUpdate.UserRoomMetadata()
r := response.Rooms[roomUpdate.RoomID()]
// Get the highest timestamp, determined by bumpEventTypes,
// for this room
roomListsMeta := s.lists.ReadOnlyRoom(roomUpdate.RoomID())
var bumpEventTypes []string
for _, list := range s.muxedReq.Lists {
bumpEventTypes = append(bumpEventTypes, list.BumpEventTypes...)
}
for _, t := range bumpEventTypes {
evMeta := roomListsMeta.LatestEventsByType[t]
if evMeta.Timestamp > r.Timestamp {
r.Timestamp = evMeta.Timestamp
}
}
// If there are no bumpEventTypes defined, use the last message timestamp
if r.Timestamp == 0 && len(bumpEventTypes) == 0 {
r.Timestamp = roomUpdate.GlobalRoomMetadata().LastMessageTimestamp
}
// Make sure we don't leak a timestamp from before we joined
if r.Timestamp < roomListsMeta.JoinTiming.Timestamp {
r.Timestamp = roomListsMeta.JoinTiming.Timestamp
}
r.HighlightCount = int64(userRoomData.HighlightCount)
r.NotificationCount = int64(userRoomData.NotificationCount)
if roomEventUpdate != nil && roomEventUpdate.EventData.Event != nil {
r.NumLive++
advancedPastEvent := false
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
if !roomEventUpdate.EventData.AlwaysProcess {
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
// this update has been accounted for by the initial:true room snapshot
advancedPastEvent = true
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
}
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
// we only append to the timeline if we haven't already got this event. This can happen when:
// - 2 live events for a room mid-connection
// - next request bumps a room from outside to inside the window
@ -203,8 +262,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
metadata.RemoveHero(s.userID)
thisRoom.Name = internal.CalculateRoomName(metadata, 5) // TODO: customisable?
}
if delta.RoomAvatarChanged {
metadata := roomUpdate.GlobalRoomMetadata()
metadata.RemoveHero(s.userID)
thisRoom.AvatarChange = sync3.NewAvatarChange(internal.CalculateAvatar(metadata))
}
if delta.InviteCountChanged {
thisRoom.InvitedCount = roomUpdate.GlobalRoomMetadata().InviteCount
thisRoom.InvitedCount = &roomUpdate.GlobalRoomMetadata().InviteCount
}
if delta.JoinCountChanged {
thisRoom.JoinedCount = roomUpdate.GlobalRoomMetadata().JoinCount
@ -275,8 +339,17 @@ func (s *connStateLive) processGlobalUpdates(ctx context.Context, builder *Rooms
}
}
metadata := rup.GlobalRoomMetadata().CopyHeroes()
metadata.RemoveHero(s.userID)
// TODO: if we change a room from being a DM to not being a DM, we should call
// SetRoom and recalculate avatars. To do that we'd need to
// - listen to m.direct global account data events
// - compute the symmetric difference between old and new
// - call SetRooms for each room in the difference.
// I'm assuming this happens so rarely that we can ignore this for now. PRs
// welcome if you have a strong opinion to the contrary.
delta = s.lists.SetRoom(sync3.RoomConnMetadata{
RoomMetadata: *rup.GlobalRoomMetadata(),
RoomMetadata: *metadata,
UserRoomData: *rup.UserRoomMetadata(),
LastInterestedEventTimestamps: bumpTimestampInList,
})

View File

@ -107,7 +107,7 @@ func TestConnStateInitial(t *testing.T) {
}
return result
}
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
if userID != cs.UserID() {
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
}
@ -118,7 +118,7 @@ func TestConnStateInitial(t *testing.T) {
{0, 9},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -168,7 +168,7 @@ func TestConnStateInitial(t *testing.T) {
{0, 9},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -206,7 +206,7 @@ func TestConnStateInitial(t *testing.T) {
{0, 9},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -272,7 +272,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(context.Background(), userCache.UserID, userCache)
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
// request first page
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
@ -282,7 +282,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
{0, 2},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -308,7 +308,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
{0, 2}, {4, 6},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -343,7 +343,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
{0, 2}, {4, 6},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -383,7 +383,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
{0, 2}, {4, 6},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -451,7 +451,7 @@ func TestBumpToOutsideRange(t *testing.T) {
userCache.LazyRoomDataOverride = mockLazyRoomOverride
dispatcher.Register(context.Background(), userCache.UserID, userCache)
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
// Ask for A,B
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
@ -460,7 +460,7 @@ func TestBumpToOutsideRange(t *testing.T) {
{0, 1},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -495,7 +495,7 @@ func TestBumpToOutsideRange(t *testing.T) {
{0, 1},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -562,7 +562,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
}
dispatcher.Register(context.Background(), userCache.UserID, userCache)
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
// subscribe to room D
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
@ -576,7 +576,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
{0, 1},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -630,7 +630,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
{0, 1},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}
@ -664,7 +664,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
{0, 1},
}),
}},
}, false)
}, false, time.Now())
if err != nil {
t.Fatalf("OnIncomingRequest returned error : %s", err)
}

View File

@ -1,9 +1,13 @@
package handler
import (
"github.com/matrix-org/sliding-sync/sync2"
"context"
"sync"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/sliding-sync/pubsub"
)
@ -28,23 +32,38 @@ type EnsurePoller struct {
// pendingPolls tracks the status of pollers that we are waiting to start.
pendingPolls map[sync2.PollerID]pendingInfo
notifier pubsub.Notifier
// the total number of outstanding ensurepolling requests.
numPendingEnsurePolling prometheus.Gauge
}
func NewEnsurePoller(notifier pubsub.Notifier) *EnsurePoller {
return &EnsurePoller{
func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePoller {
p := &EnsurePoller{
chanName: pubsub.ChanV3,
mu: &sync.Mutex{},
pendingPolls: make(map[sync2.PollerID]pendingInfo),
notifier: notifier,
}
if enablePrometheus {
p.numPendingEnsurePolling = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "num_devices_pending_ensure_polling",
Help: "Number of devices blocked on EnsurePolling returning.",
})
prometheus.MustRegister(p.numPendingEnsurePolling)
}
return p
}
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
// the caller's responsibility to call OnInitialSyncComplete when new events arrive.
func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, tokenHash string) {
ctx, region := internal.StartSpan(ctx, "EnsurePolling")
defer region.End()
p.mu.Lock()
// do we need to wait?
if p.pendingPolls[pid].done {
internal.Logf(ctx, "EnsurePolling", "user %s device %s already done", pid.UserID, pid.DeviceID)
p.mu.Unlock()
return
}
@ -56,7 +75,10 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
// TODO: several times there have been problems getting the response back from the poller
// we should time out here after 100s and return an error or something to kick conns into
// trying again
internal.Logf(ctx, "EnsurePolling", "user %s device %s channel exits, listening for channel close", pid.UserID, pid.DeviceID)
_, r2 := internal.StartSpan(ctx, "waitForExistingChannelClose")
<-ch
r2.End()
return
}
// Make a channel to wait until we have done an initial sync
@ -65,6 +87,7 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
done: false,
ch: ch,
}
p.calculateNumOutstanding() // increment total
p.mu.Unlock()
// ask the pollers to poll for this device
p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{
@ -74,10 +97,15 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
})
// if by some miracle the notify AND sync completes before we receive on ch then this is
// still fine as recv on a closed channel will return immediately.
internal.Logf(ctx, "EnsurePolling", "user %s device %s just made channel, listening for channel close", pid.UserID, pid.DeviceID)
_, r2 := internal.StartSpan(ctx, "waitForNewChannelClose")
<-ch
r2.End()
}
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
log := logger.With().Str("user", payload.UserID).Str("device", payload.DeviceID).Logger()
log.Trace().Msg("OnInitialSyncComplete: got payload")
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
p.mu.Lock()
defer p.mu.Unlock()
@ -86,12 +114,14 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
if !ok {
// This can happen when the v2 poller spontaneously starts polling even without us asking it to
// e.g from the database
log.Trace().Msg("OnInitialSyncComplete: we weren't waiting for this")
p.pendingPolls[pid] = pendingInfo{
done: true,
}
return
}
if pending.done {
log.Trace().Msg("OnInitialSyncComplete: already done")
// nothing to do, we just got OnInitialSyncComplete called twice
return
}
@ -101,6 +131,8 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
pending.done = true
pending.ch = nil
p.pendingPolls[pid] = pending
p.calculateNumOutstanding() // decrement total
log.Trace().Msg("OnInitialSyncComplete: closing channel")
close(ch)
}
@ -121,4 +153,21 @@ func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) {
func (p *EnsurePoller) Teardown() {
p.notifier.Close()
if p.numPendingEnsurePolling != nil {
prometheus.Unregister(p.numPendingEnsurePolling)
}
}
// must hold p.mu
func (p *EnsurePoller) calculateNumOutstanding() {
if p.numPendingEnsurePolling == nil {
return
}
var total int
for _, pi := range p.pendingPolls {
if !pi.done {
total++
}
}
p.numPendingEnsurePolling.Set(float64(total))
}
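EnsurePolling above blocks each request on a per-device channel that OnInitialSyncComplete later closes; because receiving from a closed channel returns immediately, late callers fall straight through. A stripped-down, self-contained sketch of that blocking pattern, using simplified types rather than the repository's pendingInfo/pubsub wiring:

package pollwait

import "sync"

// waiter sketches EnsurePoller's blocking mechanism: each key gets a channel
// that Wait blocks on and Done closes. A closed channel unblocks every current
// and future waiter, which is why "done" entries keep returning immediately.
type waiter struct {
	mu      sync.Mutex
	pending map[string]chan struct{}
	done    map[string]bool
}

func newWaiter() *waiter {
	return &waiter{pending: make(map[string]chan struct{}), done: make(map[string]bool)}
}

// Wait blocks until Done(key) has been called at least once.
func (w *waiter) Wait(key string) {
	w.mu.Lock()
	if w.done[key] {
		w.mu.Unlock()
		return
	}
	ch, ok := w.pending[key]
	if !ok {
		ch = make(chan struct{})
		w.pending[key] = ch
		// In the real EnsurePoller this is also where the pubsub notify to
		// the poller goes, before releasing the lock.
	}
	w.mu.Unlock()
	<-ch // receiving on a closed channel returns immediately, so late waiters are fine
}

// Done marks the key complete and releases all waiters.
func (w *waiter) Done(key string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.done[key] {
		return
	}
	w.done[key] = true
	if ch, ok := w.pending[key]; ok {
		close(ch)
		delete(w.pending, key)
	}
}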

View File

@ -1,6 +1,5 @@
package handler
import "C"
import (
"context"
"database/sql"
@ -59,25 +58,29 @@ type SyncLiveHandler struct {
GlobalCache *caches.GlobalCache
maxPendingEventUpdates int
maxTransactionIDDelay time.Duration
numConns prometheus.Gauge
histVec *prometheus.HistogramVec
setupHistVec *prometheus.HistogramVec
histVec *prometheus.HistogramVec
slowReqs prometheus.Counter
}
func NewSync3Handler(
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string,
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
maxTransactionIDDelay time.Duration,
) (*SyncLiveHandler, error) {
logger.Info().Msg("creating handler")
sh := &SyncLiveHandler{
V2: v2Client,
Storage: store,
V2Store: storev2,
ConnMap: sync3.NewConnMap(),
ConnMap: sync3.NewConnMap(enablePrometheus),
userCaches: &sync.Map{},
Dispatcher: sync3.NewDispatcher(),
GlobalCache: caches.NewGlobalCache(store),
maxPendingEventUpdates: maxPendingEventUpdates,
maxTransactionIDDelay: maxTransactionIDDelay,
}
sh.Extensions = &extensions.Handler{
Store: store,
@ -91,7 +94,7 @@ func NewSync3Handler(
}
// set up pubsub mechanism to start from this point
sh.EnsurePoller = NewEnsurePoller(pub)
sh.EnsurePoller = NewEnsurePoller(pub, enablePrometheus)
sh.V2Sub = pubsub.NewV2Sub(sub, sh)
return sh, nil
@ -127,37 +130,41 @@ func (h *SyncLiveHandler) Teardown() {
h.V2Sub.Teardown()
h.EnsurePoller.Teardown()
h.ConnMap.Teardown()
if h.numConns != nil {
prometheus.Unregister(h.numConns)
if h.setupHistVec != nil {
prometheus.Unregister(h.setupHistVec)
}
if h.histVec != nil {
prometheus.Unregister(h.histVec)
}
}
func (h *SyncLiveHandler) updateMetrics() {
if h.numConns == nil {
return
if h.slowReqs != nil {
prometheus.Unregister(h.slowReqs)
}
h.numConns.Set(float64(h.ConnMap.Len()))
}
func (h *SyncLiveHandler) addPrometheusMetrics() {
h.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
h.setupHistVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "num_active_conns",
Help: "Number of active sliding sync connections.",
})
Name: "setup_duration_secs",
Help: "Time taken in seconds after receiving a request before we start calculating a sliding sync response.",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"initial"})
h.histVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "process_duration_secs",
Help: "Time taken in seconds for the sliding sync response to calculated, excludes long polling",
Help: "Time taken in seconds for the sliding sync response to be calculated, excludes long polling",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"initial"})
prometheus.MustRegister(h.numConns)
h.slowReqs = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "slow_requests",
Help: "Counter of slow (>=50s) requests, initial or otherwise.",
})
prometheus.MustRegister(h.setupHistVec)
prometheus.MustRegister(h.histVec)
prometheus.MustRegister(h.slowReqs)
}
func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@ -174,9 +181,13 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Err: err,
}
}
// artificially wait a bit before sending back the error
// this guards against tightlooping when the client hammers the server with invalid requests
time.Sleep(time.Second)
if herr.ErrCode != "M_UNKNOWN_POS" {
// artificially wait a bit before sending back the error
// this guards against tightlooping when the client hammers the server with invalid requests,
// but not for M_UNKNOWN_POS which we expect to send back after expiring a client's connection.
// We want to recover rapidly in that scenario, hence not sleeping.
time.Sleep(time.Second)
}
w.WriteHeader(herr.StatusCode)
w.Write(herr.JSON())
}
@ -184,6 +195,16 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Entry point for sync v3
func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error {
start := time.Now()
defer func() {
dur := time.Since(start)
if dur > 50*time.Second {
if h.slowReqs != nil {
h.slowReqs.Add(1.0)
}
internal.DecorateLogger(req.Context(), log.Warn()).Dur("duration", dur).Msg("slow request")
}
}()
var requestBody sync3.Request
if req.ContentLength != 0 {
defer req.Body.Close()
@ -233,7 +254,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
return herr
}
requestBody.SetPos(cpos)
internal.SetRequestContextUserID(req.Context(), conn.UserID)
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
var timeout int
@ -250,7 +270,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
requestBody.SetTimeoutMSecs(timeout)
log.Trace().Int("timeout", timeout).Msg("recv")
resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody)
resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody, start)
if herr != nil {
logErrorOrWarning("failed to OnIncomingRequest", herr)
return herr
@ -271,7 +291,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
internal.SetRequestContextResponseInfo(
req.Context(), cpos, resp.PosInt(), len(resp.Rooms), requestBody.TxnID, numToDeviceEvents, numGlobalAccountData,
numChangedDevices, numLeftDevices,
numChangedDevices, numLeftDevices, requestBody.ConnID, len(requestBody.Lists), len(requestBody.RoomSubscriptions), len(requestBody.UnsubscribeRooms),
)
w.Header().Set("Content-Type", "application/json")
@ -300,8 +320,9 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
defer task.End()
var conn *sync3.Conn
// Extract an access token
accessToken, err := internal.ExtractAccessToken(req)
@ -333,6 +354,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
}
}
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
internal.SetRequestContextUserID(req.Context(), token.UserID, token.DeviceID)
internal.Logf(taskCtx, "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID)
// Record the fact that we've received a request from this token
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
@ -358,9 +381,9 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
return nil, internal.ExpiredSessionError()
}
log.Trace().Msg("checking poller exists and is running")
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
h.EnsurePoller.EnsurePolling(pid, token.AccessTokenHash)
log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
log.Trace().Msg("poller exists and is running")
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
// We'll be quicker next time as the poller will already exist.
@ -382,7 +405,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
}
// once we have the conn, make sure our metrics are correct
defer h.updateMetrics()
defer h.ConnMap.UpdateMetrics()
// Now the v2 side of things are running, we can make a v3 live sync conn
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
@ -390,7 +413,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
})
if created {
log.Info().Msg("created new connection")
@ -408,6 +431,7 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
return nil, &internal.HandlerError{
StatusCode: 401,
Err: fmt.Errorf("/whoami returned HTTP 401"),
ErrCode: "M_UNKNOWN_TOKEN",
}
}
log.Warn().Err(err).Msg("failed to get user ID from device ID")
@ -420,14 +444,14 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
var token *sync2.Token
err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
// Create a brand-new row for this token.
token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now())
if err != nil {
logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
return err
}
// Ensure we have a device row for this token.
err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID)
if err != nil {
log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
return err
@ -484,6 +508,15 @@ func (h *SyncLiveHandler) userCache(userID string) (*caches.UserCache, error) {
uc.OnAccountData(context.Background(), []state.AccountData{directEvent[0]})
}
// select the ignored users account data event and set ignored user list
ignoreEvent, err := h.Storage.AccountData(userID, sync2.AccountDataGlobalRoom, []string{"m.ignored_user_list"})
if err != nil {
return nil, fmt.Errorf("failed to load ignored user list for user %s: %w", userID, err)
}
if len(ignoreEvent) == 1 {
uc.OnAccountData(context.Background(), []state.AccountData{ignoreEvent[0]})
}
// select all room tag account data and set it
tagEvents, err := h.Storage.RoomAccountDatasWithType(userID, "m.tag")
if err != nil {
@ -601,7 +634,11 @@ func (h *SyncLiveHandler) Accumulate(p *pubsub.V2Accumulate) {
func (h *SyncLiveHandler) OnTransactionID(p *pubsub.V2TransactionID) {
_, task := internal.StartTask(context.Background(), "TransactionID")
defer task.End()
// TODO implement me
// There is some event E for which we now have a transaction ID, or else now know
// that we will never get a transaction ID. In either case, tell the sender's
// connections to unblock that event in the transaction ID waiter.
h.ConnMap.ClearUpdateQueues(p.UserID, p.RoomID, p.NID)
}
// Called from the v2 poller, implements V2DataReceiver
@ -632,9 +669,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
defer task.End()
conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
for _, conn := range conns {
conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
for userID, deviceIDs := range p.UserIDToDeviceIDs {
for _, deviceID := range deviceIDs {
conns := h.ConnMap.Conns(userID, deviceID)
for _, conn := range conns {
conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
}
}
}
}
@ -671,7 +713,7 @@ func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
if !ok {
return
}
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID)
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
}
func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {

View File

@ -0,0 +1,94 @@
package handler
import (
"github.com/matrix-org/sliding-sync/sync3/caches"
"sync"
"time"
)
type TxnIDWaiter struct {
userID string
publish func(delayed bool, update caches.Update)
// mu guards the queues map.
mu sync.Mutex
queues map[string][]*caches.RoomEventUpdate
maxDelay time.Duration
}
func NewTxnIDWaiter(userID string, maxDelay time.Duration, publish func(bool, caches.Update)) *TxnIDWaiter {
return &TxnIDWaiter{
userID: userID,
publish: publish,
mu: sync.Mutex{},
queues: make(map[string][]*caches.RoomEventUpdate),
maxDelay: maxDelay,
// TODO: metric that tracks how long events were queued for.
}
}
func (t *TxnIDWaiter) Ingest(up caches.Update) {
if t.maxDelay <= 0 {
t.publish(false, up)
return
}
eventUpdate, isEventUpdate := up.(*caches.RoomEventUpdate)
if !isEventUpdate {
t.publish(false, up)
return
}
ed := eventUpdate.EventData
// An event should be queued if
// - it's a state event that our user sent, lacking a txn_id; OR
// - the room already has queued events.
t.mu.Lock()
defer t.mu.Unlock()
_, roomQueued := t.queues[ed.RoomID]
missingTxnID := ed.StateKey == nil && ed.Sender == t.userID && ed.TransactionID == ""
if !(missingTxnID || roomQueued) {
t.publish(false, up)
return
}
// We've decided to queue the event.
queue, exists := t.queues[ed.RoomID]
if !exists {
queue = make([]*caches.RoomEventUpdate, 0, 10)
}
// TODO: bound the queue size?
t.queues[ed.RoomID] = append(queue, eventUpdate)
time.AfterFunc(t.maxDelay, func() { t.PublishUpToNID(ed.RoomID, ed.NID) })
}
func (t *TxnIDWaiter) PublishUpToNID(roomID string, publishNID int64) {
t.mu.Lock()
defer t.mu.Unlock()
queue, exists := t.queues[roomID]
if !exists {
return
}
var i int
for i = 0; i < len(queue); i++ {
// Scan forwards through the queue until we find an event with nid > publishNID.
if queue[i].EventData.NID > publishNID {
break
}
}
// Now queue[:i] has events with nid <= publishNID, and queue[i:] has nids > publishNID.
// strip off the first i events from the slice and publish them.
toPublish, queue := queue[:i], queue[i:]
if len(queue) == 0 {
delete(t.queues, roomID)
} else {
t.queues[roomID] = queue
}
for _, eventUpdate := range toPublish {
t.publish(true, eventUpdate)
}
}
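A minimal usage sketch of the waiter above. The import paths are assumed from the package layout seen elsewhere in this diff, and the wiring is illustrative only — in ConnState the publish callback is hooked up to connStateLive.onUpdate rather than a print statement:

package main

import (
	"fmt"
	"time"

	"github.com/matrix-org/sliding-sync/sync3/caches"
	"github.com/matrix-org/sliding-sync/sync3/handler"
)

func main() {
	w := handler.NewTxnIDWaiter("@alice:example.com", 100*time.Millisecond,
		func(delayed bool, up caches.Update) {
			ev := up.(*caches.RoomEventUpdate).EventData
			fmt.Printf("published nid=%d delayed=%v\n", ev.NID, delayed)
		})

	// Our own event with no transaction_id: queued rather than published.
	w.Ingest(&caches.RoomEventUpdate{EventData: &caches.EventData{
		RoomID: "!room:example.com",
		Sender: "@alice:example.com",
		NID:    42,
	}})

	// Release path 1: the poller reports the txn ID (or that none will ever
	// come), which reaches the waiter via ConnMap.ClearUpdateQueues ->
	// ConnState.PublishEventsUpTo -> PublishUpToNID.
	w.PublishUpToNID("!room:example.com", 42)

	// Release path 2 (not exercised here): the time.AfterFunc timer set in
	// Ingest fires after maxDelay and calls PublishUpToNID itself.
}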

View File

@ -0,0 +1,390 @@
package handler
import (
"github.com/matrix-org/sliding-sync/sync3/caches"
"github.com/tidwall/gjson"
"testing"
"time"
)
type publishArg struct {
delayed bool
update caches.Update
}
// Test that
// - events are (reported as being) delayed when we expect them to be
// - delayed events are automatically published after the maximum delay period
func TestTxnIDWaiter_QueuingLogic(t *testing.T) {
const alice = "alice"
const bob = "bob"
const room1 = "!theroom"
const room2 = "!daszimmer"
testCases := []struct {
Name string
Ingest []caches.Update
WaitForUpdate int
ExpectDelayed bool
}{
{
Name: "empty queue, non-event update",
Ingest: []caches.Update{&caches.AccountDataUpdate{}},
WaitForUpdate: 0,
ExpectDelayed: false,
},
{
Name: "empty queue, event update, another sender",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: bob,
},
}},
WaitForUpdate: 0,
ExpectDelayed: false,
},
{
Name: "empty queue, event update, has txn_id",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "txntxntxn",
},
}},
WaitForUpdate: 0,
ExpectDelayed: false,
},
{
Name: "empty queue, event update, no txn_id",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
},
}},
WaitForUpdate: 0,
ExpectDelayed: true,
},
{
Name: "nonempty queue, non-event update",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
},
},
&caches.AccountDataUpdate{},
},
WaitForUpdate: 1,
ExpectDelayed: false, // not a room event, no need to queue behind alice's event
},
{
Name: "empty queue, join event for sender",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
EventType: "m.room.member",
StateKey: ptr(alice),
Content: gjson.Parse(`{"membership": "join"}`),
},
},
},
WaitForUpdate: 0,
ExpectDelayed: false,
},
{
Name: "nonempty queue, join event for sender",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
},
},
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 2,
EventType: "m.room.member",
StateKey: ptr(alice),
Content: gjson.Parse(`{"membership": "join"}`),
},
},
},
WaitForUpdate: 1,
ExpectDelayed: true,
},
{
Name: "nonempty queue, event update, different sender",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
},
},
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: bob,
NID: 2,
},
},
},
WaitForUpdate: 1,
ExpectDelayed: true, // should be queued behind alice's event
},
{
Name: "nonempty queue, event update, has txn_id",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
},
},
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
NID: 2,
TransactionID: "I have a txn",
},
},
},
WaitForUpdate: 1,
ExpectDelayed: true, // should still be queued behind alice's first event
},
{
Name: "existence of queue only matters per-room",
Ingest: []caches.Update{
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room1,
Sender: alice,
TransactionID: "",
NID: 1,
},
},
&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room2,
Sender: alice,
NID: 2,
TransactionID: "I have a txn",
},
},
},
WaitForUpdate: 1,
ExpectDelayed: false, // queue only tracks room1
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
updates := make(chan publishArg, 100)
publish := func(delayed bool, update caches.Update) {
updates <- publishArg{delayed, update}
}
w := NewTxnIDWaiter(alice, time.Millisecond, publish)
for _, up := range tc.Ingest {
w.Ingest(up)
}
wantedUpdate := tc.Ingest[tc.WaitForUpdate]
var got publishArg
WaitForSelectedUpdate:
for {
select {
case got = <-updates:
t.Logf("Got update %v", got.update)
if got.update == wantedUpdate {
break WaitForSelectedUpdate
}
case <-time.After(5 * time.Millisecond):
t.Fatalf("Did not see update %v published", wantedUpdate)
}
}
if got.delayed != tc.ExpectDelayed {
t.Errorf("Got delayed=%t want delayed=%t", got.delayed, tc.ExpectDelayed)
}
})
}
}
// Test that PublishUpToNID
// - correctly pops off the start of the queue
// - is idempotent
// - deletes map entry if queue is empty (so that roomQueued is set correctly)
func TestTxnIDWaiter_PublishUpToNID(t *testing.T) {
const alice = "@alice:example.com"
const room = "!unimportant"
var published []publishArg
publish := func(delayed bool, update caches.Update) {
published = append(published, publishArg{delayed, update})
}
// Use an hour's expiry to effectively disable expiry.
w := NewTxnIDWaiter(alice, time.Hour, publish)
// Ingest 5 events, each of which would be queued by themselves.
for i := int64(2); i <= 6; i++ {
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room,
Sender: alice,
TransactionID: "",
NID: i,
},
})
}
t.Log("Queue has nids [2,3,4,5,6]")
t.Log("Publishing up to 1 should do nothing")
w.PublishUpToNID(room, 1)
assertNIDs(t, published, nil)
t.Log("Publishing up to 3 should yield nids [2, 3] in that order")
w.PublishUpToNID(room, 3)
assertNIDs(t, published, []int64{2, 3})
assertDelayed(t, published[:2])
t.Log("Publishing up to 3 a second time should do nothing")
w.PublishUpToNID(room, 3)
assertNIDs(t, published, []int64{2, 3})
t.Log("Publishing up to 2 at this point should do nothing.")
w.PublishUpToNID(room, 2)
assertNIDs(t, published, []int64{2, 3})
t.Log("Publishing up to 6 should yield nids [4, 5, 6] in that order")
w.PublishUpToNID(room, 6)
assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
assertDelayed(t, published[2:5])
t.Log("Publishing up to 6 a second time should do nothing")
w.PublishUpToNID(room, 6)
assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
t.Log("Ingesting another event that doesn't need to be queueing should be published immediately")
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: room,
Sender: "@notalice:example.com",
TransactionID: "",
NID: 7,
},
})
assertNIDs(t, published, []int64{2, 3, 4, 5, 6, 7})
if published[len(published)-1].delayed {
t.Errorf("Final event was delayed, but should have been published immediately")
}
}
// Test that PublishUpToNID only publishes in the given room
func TestTxnIDWaiter_PublishUpToNID_MultipleRooms(t *testing.T) {
const alice = "@alice:example.com"
var published []publishArg
publish := func(delayed bool, update caches.Update) {
published = append(published, publishArg{delayed, update})
}
// Use an hour's expiry to effectively disable expiry.
w := NewTxnIDWaiter(alice, time.Hour, publish)
// Ingest four queueable events across two rooms.
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: "!room1",
Sender: alice,
TransactionID: "",
NID: 1,
},
})
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: "!room2",
Sender: alice,
TransactionID: "",
NID: 2,
},
})
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: "!room2",
Sender: alice,
TransactionID: "",
NID: 3,
},
})
w.Ingest(&caches.RoomEventUpdate{
EventData: &caches.EventData{
RoomID: "!room1",
Sender: alice,
TransactionID: "",
NID: 4,
},
})
t.Log("Queues are [1, 4] and [2, 3]")
t.Log("Publish up to NID 4 in room 1 should yield nids [1, 4]")
w.PublishUpToNID("!room1", 4)
assertNIDs(t, published, []int64{1, 4})
assertDelayed(t, published)
t.Log("Queues are [1, 4] and [2, 3]")
t.Log("Publish up to NID 3 in room 2 should yield nids [2, 3]")
w.PublishUpToNID("!room2", 3)
assertNIDs(t, published, []int64{1, 4, 2, 3})
assertDelayed(t, published)
}
func assertDelayed(t *testing.T, published []publishArg) {
t.Helper()
for _, p := range published {
if !p.delayed {
t.Errorf("published arg with NID %d was not delayed, but we expected it to be", p.update.(*caches.RoomEventUpdate).EventData.NID)
}
}
}
func assertNIDs(t *testing.T, published []publishArg, expectedNIDs []int64) {
t.Helper()
if len(published) != len(expectedNIDs) {
t.Errorf("Got %d nids, but expected %d", len(published), len(expectedNIDs))
}
for i := range published {
rup, ok := published[i].update.(*caches.RoomEventUpdate)
if !ok {
t.Errorf("Update %d (%v) was not a RoomEventUpdate", i, published[i].update)
}
if rup.EventData.NID != expectedNIDs[i] {
t.Errorf("Update %d (%v) got nid %d, expected %d", i, *rup, rup.EventData.NID, expectedNIDs[i])
}
}
}
func ptr(s string) *string {
return &s
}

View File

@ -33,6 +33,7 @@ type RoomListDelta struct {
type RoomDelta struct {
RoomNameChanged bool
RoomAvatarChanged bool
JoinCountChanged bool
InviteCountChanged bool
NotificationCountChanged bool
@ -73,8 +74,20 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
)
} else {
// XXX: during TestConnectionTimeoutNotReset there is some situation where
// r.CanonicalisedName is the empty string. Looking at the SetRoom
// call in connstate_live.go, this is because the UserRoomMetadata on
// the RoomUpdate has an empty CanonicalisedName. Either
// a) that is expected, in which case we should _always_ write to
// r.CanonicalisedName here; or
// b) that is not expected, in which case... erm, I don't know what
// to conclude.
r.CanonicalisedName = existing.CanonicalisedName
}
delta.RoomAvatarChanged = !existing.SameRoomAvatar(&r.RoomMetadata)
if delta.RoomAvatarChanged {
r.ResolvedAvatarURL = internal.CalculateAvatar(&r.RoomMetadata)
}
// Interpret the timestamp map on r as the changes we should apply atop the
// existing timestamps.
@ -99,6 +112,7 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
r.CanonicalisedName = strings.ToLower(
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
)
r.ResolvedAvatarURL = internal.CalculateAvatar(&r.RoomMetadata)
// We'll automatically use the LastInterestedEventTimestamps provided by the
// caller, so that recency sorts work.
}

View File

@ -21,9 +21,8 @@ type Response struct {
Rooms map[string]Room `json:"rooms"`
Extensions extensions.Response `json:"extensions"`
Pos string `json:"pos"`
TxnID string `json:"txn_id,omitempty"`
Session string `json:"session_id,omitempty"`
Pos string `json:"pos"`
TxnID string `json:"txn_id,omitempty"`
}
type ResponseList struct {
@ -68,9 +67,8 @@ func (r *Response) UnmarshalJSON(b []byte) error {
} `json:"lists"`
Extensions extensions.Response `json:"extensions"`
Pos string `json:"pos"`
TxnID string `json:"txn_id,omitempty"`
Session string `json:"session_id,omitempty"`
Pos string `json:"pos"`
TxnID string `json:"txn_id,omitempty"`
}{}
if err := json.Unmarshal(b, &temporary); err != nil {
return err
@ -78,7 +76,6 @@ func (r *Response) UnmarshalJSON(b []byte) error {
r.Rooms = temporary.Rooms
r.Pos = temporary.Pos
r.TxnID = temporary.TxnID
r.Session = temporary.Session
r.Extensions = temporary.Extensions
r.Lists = make(map[string]ResponseList, len(temporary.Lists))

View File

@ -2,6 +2,7 @@ package sync3
import (
"encoding/json"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sync3/caches"
@ -9,6 +10,7 @@ import (
type Room struct {
Name string `json:"name,omitempty"`
AvatarChange AvatarChange `json:"avatar,omitempty"`
RequiredState []json.RawMessage `json:"required_state,omitempty"`
Timeline []json.RawMessage `json:"timeline,omitempty"`
InviteState []json.RawMessage `json:"invite_state,omitempty"`
@ -17,9 +19,10 @@ type Room struct {
Initial bool `json:"initial,omitempty"`
IsDM bool `json:"is_dm,omitempty"`
JoinedCount int `json:"joined_count,omitempty"`
InvitedCount int `json:"invited_count,omitempty"`
InvitedCount *int `json:"invited_count,omitempty"`
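// Pointer + omitempty: nil means the count is omitted entirely, while an explicit zero is still serialised.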
PrevBatch string `json:"prev_batch,omitempty"`
NumLive int `json:"num_live,omitempty"`
Timestamp uint64 `json:"timestamp,omitempty"`
}
// RoomConnMetadata represents a room as seen by one specific connection (hence one

74
sync3/room_test.go Normal file
View File

@ -0,0 +1,74 @@
package sync3
import (
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"reflect"
"testing"
)
func TestAvatarChangeMarshalling(t *testing.T) {
var url = "mxc://..."
testCases := []struct {
Name string
AvatarChange AvatarChange
Check func(avatar gjson.Result) error
}{
{
Name: "Avatar exists",
AvatarChange: NewAvatarChange(url),
Check: func(avatar gjson.Result) error {
if !(avatar.Exists() && avatar.Type == gjson.String && avatar.Str == url) {
return fmt.Errorf("unexpected marshalled avatar: got %#v want %s", avatar, url)
}
return nil
},
},
{
Name: "Avatar doesn't exist",
AvatarChange: DeletedAvatar,
Check: func(avatar gjson.Result) error {
if !(avatar.Exists() && avatar.Type == gjson.Null) {
return fmt.Errorf("unexpected marshalled Avatar: got %#v want null", avatar)
}
return nil
},
},
{
Name: "Avatar unchanged",
AvatarChange: UnchangedAvatar,
Check: func(avatar gjson.Result) error {
if avatar.Exists() {
return fmt.Errorf("unexpected marshalled Avatar: got %#v want omitted", avatar)
}
return nil
},
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
room := Room{AvatarChange: tc.AvatarChange}
marshalled, err := json.Marshal(room)
t.Logf("Marshalled to %s", string(marshalled))
if err != nil {
t.Fatal(err)
}
avatar := gjson.GetBytes(marshalled, "avatar")
if err = tc.Check(avatar); err != nil {
t.Fatal(err)
}
var unmarshalled Room
err = json.Unmarshal(marshalled, &unmarshalled)
if err != nil {
t.Fatal(err)
}
t.Logf("Unmarshalled to %#v", unmarshalled.AvatarChange)
if !reflect.DeepEqual(unmarshalled, room) {
t.Fatalf("Unmarshalled struct is different from original")
}
})
}
}
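// The three cases above (new URL, JSON null, omitted) can be realised with a string-backed
// type and a custom marshaller along the following lines. This is a sketch under the
// assumption that AvatarChange is string-like with a deletion sentinel; it is not
// necessarily the proxy's real definition.
package sketch
import "encoding/json"
// AvatarChange models a tri-state avatar field: the zero value means "unchanged"
// (dropped by omitempty), a sentinel means "deleted" (marshalled as JSON null),
// and anything else is the new mxc:// URL. Names here are illustrative.
type AvatarChange string
const (
	UnchangedAvatar AvatarChange = ""
	DeletedAvatar   AvatarChange = "<deleted-avatar>" // hypothetical sentinel
)
func NewAvatarChange(url string) AvatarChange { return AvatarChange(url) }
// MarshalJSON emits null for a deleted avatar and a plain string otherwise.
// omitempty on the enclosing struct field drops the zero value before this runs.
func (a AvatarChange) MarshalJSON() ([]byte, error) {
	if a == DeletedAvatar {
		return []byte("null"), nil
	}
	return json.Marshal(string(a))
}
// UnmarshalJSON maps null back to DeletedAvatar so a marshal/unmarshal round trip
// preserves the tri-state, as the test above requires.
func (a *AvatarChange) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		*a = DeletedAvatar
		return nil
	}
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	*a = AvatarChange(s)
	return nil
}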

View File

@ -2,11 +2,13 @@ package syncv3_test
import (
"encoding/json"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/sync3/extensions"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
"testing"
)
func TestAccountDataRespectsExtensionScope(t *testing.T) {
@ -99,14 +101,14 @@ func TestAccountDataRespectsExtensionScope(t *testing.T) {
alice,
room1,
"com.example.room",
map[string]interface{}{"room": 1, "version": 1},
map[string]interface{}{"room": 1, "version": 11},
)
room2AccountDataEvent := putRoomAccountData(
t,
alice,
room2,
"com.example.room",
map[string]interface{}{"room": 2, "version": 2},
map[string]interface{}{"room": 2, "version": 22},
)
t.Log("Alice syncs until she sees the account data for room 2. She shouldn't see account data for room 1")
@ -122,7 +124,90 @@ func TestAccountDataRespectsExtensionScope(t *testing.T) {
return m.MatchAccountData(nil, map[string][]json.RawMessage{room2: {room2AccountDataEvent}})(response)
},
)
}
// Regression test for https://github.com/matrix-org/sliding-sync/issues/189
func TestAccountDataDoesntDupe(t *testing.T) {
alice := registerNewUser(t)
alice2 := *alice
alice2.Login(t, "password", "device2")
// send some initial account data
putGlobalAccountData(t, alice, "initial", map[string]interface{}{"foo": "bar"})
// no devices are polling.
// syncing with both devices => only shows 1 copy of this event per connection
for _, client := range []*CSAPI{alice, &alice2} {
res := client.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
{
Type: "m.push_rules",
},
{
Type: "initial",
Content: map[string]interface{}{"foo": "bar"},
},
}))
}
// now both devices are polling, we're going to do the same thing to make sure we still see only 1 copy.
putGlobalAccountData(t, alice, "initial2", map[string]interface{}{"foo2": "bar2"})
time.Sleep(time.Second) // TODO: we need to make sure the pollers have seen this and explicitly don't want to use SlidingSyncUntil...
var responses []*sync3.Response
for _, client := range []*CSAPI{alice, &alice2} {
res := client.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
{
Type: "m.push_rules",
},
{
Type: "initial",
Content: map[string]interface{}{"foo": "bar"},
},
{
Type: "initial2",
Content: map[string]interface{}{"foo2": "bar2"},
},
}))
responses = append(responses, res) // we need the pos values later
}
// now we're going to do an incremental sync with account data to make sure we don't see dupes either.
putGlobalAccountData(t, alice, "incremental", map[string]interface{}{"foo3": "bar3"})
time.Sleep(time.Second) // TODO: we need to make sure the pollers have seen this and explicitly don't want to use SlidingSyncUntil...
for i, client := range []*CSAPI{alice, &alice2} {
res := client.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
}, WithPos(responses[i].Pos))
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
{
Type: "incremental",
Content: map[string]interface{}{"foo3": "bar3"},
},
}))
}
}
// putAccountData is a wrapper around SetGlobalAccountData. It returns the account data

View File

@ -134,6 +134,7 @@ type CSAPI struct {
Localpart string
AccessToken string
DeviceID string
AvatarURL string
BaseURL string
Client *http.Client
// how long are we willing to wait for MustSyncUntil.... calls
@ -159,6 +160,16 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
return GetJSONFieldStr(t, body, "content_uri")
}
// SetAvatar sets the user's avatar_url via the profile API and remembers it on the CSAPI struct.
// Use an empty string to remove your avatar.
func (c *CSAPI) SetAvatar(t *testing.T, avatarURL string) {
t.Helper()
reqBody := map[string]interface{}{
"avatar_url": avatarURL,
}
c.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "profile", c.UserID, "avatar_url"}, WithJSONBody(t, reqBody))
c.AvatarURL = avatarURL
}
// DownloadContent downloads media from the server, returning the raw bytes and the Content-Type. Fails the test on error.
func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
t.Helper()
@ -678,16 +689,32 @@ func (c *CSAPI) SlidingSyncUntilMembership(t *testing.T, pos string, roomID stri
})
}
return c.SlidingSyncUntilEvent(t, pos, sync3.Request{
return c.SlidingSyncUntil(t, pos, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 10,
},
},
}, roomID, Event{
Type: "m.room.member",
StateKey: &target.UserID,
Content: content,
}, func(r *sync3.Response) error {
room, ok := r.Rooms[roomID]
if !ok {
return fmt.Errorf("missing room %s", roomID)
}
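// Scan the timeline for an m.room.member event for the target user carrying the wanted membership.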
for _, got := range room.Timeline {
wantEvent := Event{
Type: "m.room.member",
StateKey: &target.UserID,
}
if err := eventsEqual([]Event{wantEvent}, []json.RawMessage{got}); err == nil {
gotMembership := gjson.GetBytes(got, "content.membership")
if gotMembership.Exists() && gotMembership.Type == gjson.String && gotMembership.Str == membership {
return nil
}
} else {
t.Log(err)
}
}
return fmt.Errorf("found room %s but missing event", roomID)
})
}

View File

@ -1,13 +1,86 @@
package syncv3_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
)
func TestInvalidTokenReturnsMUnknownTokenError(t *testing.T) {
alice := registerNewUser(t)
roomID := alice.CreateRoom(t, map[string]interface{}{})
// normal sliding sync
alice.SlidingSync(t, sync3.Request{
ConnID: "A",
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
})
// invalidate the access token
alice.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "logout"})
// let the proxy realise the token is expired and tell downstream
time.Sleep(time.Second)
var invalidResponses []*http.Response
// using the same token now returns a 401 with M_UNKNOWN_TOKEN
httpRes := alice.DoFunc(t, "POST", []string{"_matrix", "client", "unstable", "org.matrix.msc3575", "sync"}, WithQueries(url.Values{
"timeout": []string{"500"},
}), WithJSONBody(t, sync3.Request{
ConnID: "A",
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
}))
invalidResponses = append(invalidResponses, httpRes)
// using a bogus access token returns a 401 with M_UNKNOWN_TOKEN
alice.AccessToken = "flibble_wibble"
httpRes = alice.DoFunc(t, "POST", []string{"_matrix", "client", "unstable", "org.matrix.msc3575", "sync"}, WithQueries(url.Values{
"timeout": []string{"500"},
}), WithJSONBody(t, sync3.Request{
ConnID: "A",
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
}))
invalidResponses = append(invalidResponses, httpRes)
for i, httpRes := range invalidResponses {
body, err := io.ReadAll(httpRes.Body)
if err != nil {
t.Fatalf("[%d] failed to read response body: %v", i, err)
}
if httpRes.StatusCode != 401 {
t.Errorf("[%d] got HTTP %v want 401: %v", i, httpRes.StatusCode, string(body))
}
var jsonError struct {
Err string `json:"error"`
ErrCode string `json:"errcode"`
}
if err := json.Unmarshal(body, &jsonError); err != nil {
t.Fatalf("[%d] failed to unmarshal error response into JSON: %v", i, string(body))
}
wantErrCode := "M_UNKNOWN_TOKEN"
if jsonError.ErrCode != wantErrCode {
t.Errorf("[%d] errcode: got %v want %v", i, jsonError.ErrCode, wantErrCode)
}
}
}
// Test that you can have multiple connections with the same device, and they work independently.
func TestMultipleConns(t *testing.T) {
alice := registerNewUser(t)

View File

@ -0,0 +1,92 @@
package syncv3_test
import (
"fmt"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
"testing"
)
func TestInvitesFromIgnoredUsersOmitted(t *testing.T) {
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")
nigel := registerNamedUser(t, "nigel")
t.Log("Nigel create two public rooms. Bob joins both.")
room1 := nigel.CreateRoom(t, map[string]any{"preset": "public_chat", "name": "room 1"})
room2 := nigel.CreateRoom(t, map[string]any{"preset": "public_chat", "name": "room 2"})
bob.JoinRoom(t, room1, nil)
bob.JoinRoom(t, room2, nil)
t.Log("Alice makes a room for dumping sentinel messages.")
aliceRoom := alice.CreateRoom(t, map[string]any{"preset": "private_chat"})
t.Log("Alice ignores Nigel.")
alice.SetGlobalAccountData(t, "m.ignored_user_list", map[string]any{
"ignored_users": map[string]any{
nigel.UserID: map[string]any{},
},
})
t.Log("Nigel invites Alice to room 1.")
nigel.InviteRoom(t, room1, alice.UserID)
t.Log("Bob sliding syncs until he sees that invite.")
bob.SlidingSyncUntilMembership(t, "", room1, alice, "invite")
t.Log("Alice sends a sentinel message in her private room.")
sentinel := alice.SendEventSynced(t, aliceRoom, Event{
Type: "m.room.message",
Content: map[string]any{
"body": "Hello, world!",
"msgtype": "m.text",
},
})
t.Log("Alice does an initial sliding sync.")
res := alice.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
Ranges: sync3.SliceRanges{{0, 20}},
},
},
})
t.Log("Alice should see her sentinel, but not Nigel's invite.")
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(
map[string][]m.RoomMatcher{
aliceRoom: {MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}})},
},
))
t.Log("Nigel invites Alice to room 2.")
nigel.InviteRoom(t, room2, alice.UserID)
t.Log("Bob sliding syncs until he sees that invite.")
bob.SlidingSyncUntilMembership(t, "", room1, alice, "invite")
t.Log("Alice sends a sentinel message in her private room.")
sentinel = alice.SendEventSynced(t, aliceRoom, Event{
Type: "m.room.message",
Content: map[string]any{
"body": "Hello, world, again",
"msgtype": "m.text",
},
})
t.Log("Alice does an incremental sliding sync.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, func(response *sync3.Response) error {
if m.MatchRoomSubscription(room2)(response) == nil {
err := fmt.Errorf("unexpectedly got subscription for room 2 (%s)", room2)
t.Error(err)
return err
}
gotSentinel := m.MatchRoomSubscription(aliceRoom, MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}}))
return gotSentinel(response)
})
}

View File

@ -1297,3 +1297,391 @@ func TestRangeOutsideTotalRooms(t *testing.T) {
),
)
}
// Nicked from Synapse's tests, see
// https://github.com/matrix-org/synapse/blob/2cacd0849a02d43f88b6c15ee862398159ab827c/tests/test_utils/__init__.py#L154-L161
// Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
var smallPNG = []byte(
"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82",
)
func TestAvatarFieldInRoomResponse(t *testing.T) {
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")
chris := registerNamedUser(t, "chris")
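// Track each uploaded avatar mxc URI so we can assert that every upload is fresh.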
avatarURLs := map[string]struct{}{}
uploadAvatar := func(client *CSAPI, filename string) string {
avatar := client.UploadContent(t, smallPNG, filename, "image/png")
if _, exists := avatarURLs[avatar]; exists {
t.Fatalf("New avatar %s has already been uploaded", avatar)
}
t.Logf("%s is uploaded as %s", filename, avatar)
avatarURLs[avatar] = struct{}{}
return avatar
}
t.Log("Alice, Bob and Chris upload and set an avatar.")
aliceAvatar := uploadAvatar(alice, "alice.png")
bobAvatar := uploadAvatar(bob, "bob.png")
chrisAvatar := uploadAvatar(chris, "chris.png")
alice.SetAvatar(t, aliceAvatar)
bob.SetAvatar(t, bobAvatar)
chris.SetAvatar(t, chrisAvatar)
t.Log("Alice makes a public room, a DM with herself, a DM with Bob, a DM with Chris, and a group-DM with Bob and Chris.")
public := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
// TODO: you can create a DM with yourself e.g. as below. It probably ought to have
// your own face as an avatar.
// dmAlice := alice.CreateRoom(t, map[string]interface{}{
// "preset": "trusted_private_chat",
// "is_direct": true,
// })
dmBob := alice.CreateRoom(t, map[string]interface{}{
"preset": "trusted_private_chat",
"is_direct": true,
"invite": []string{bob.UserID},
})
dmChris := alice.CreateRoom(t, map[string]interface{}{
"preset": "trusted_private_chat",
"is_direct": true,
"invite": []string{chris.UserID},
})
dmBobChris := alice.CreateRoom(t, map[string]interface{}{
"preset": "trusted_private_chat",
"is_direct": true,
"invite": []string{bob.UserID, chris.UserID},
})
t.Logf("Rooms:\npublic=%s\ndmBob=%s\ndmChris=%s\ndmBobChris=%s", public, dmBob, dmChris, dmBobChris)
t.Log("Bob accepts his invites. Chris accepts none.")
bob.JoinRoom(t, dmBob, nil)
bob.JoinRoom(t, dmBobChris, nil)
t.Log("Alice makes an initial sliding sync.")
res := alice.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"rooms": {
Ranges: sync3.SliceRanges{{0, 4}},
},
},
})
t.Log("Alice should see each room in the sync response with an appropriate avatar")
m.MatchResponse(
t,
res,
m.MatchRoomSubscription(public, m.MatchRoomUnsetAvatar()),
m.MatchRoomSubscription(dmBob, m.MatchRoomAvatar(bob.AvatarURL)),
m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)),
m.MatchRoomSubscription(dmBobChris, m.MatchRoomUnsetAvatar()),
)
t.Run("Avatar not resent on message", func(t *testing.T) {
t.Log("Bob sends a sentinel message.")
sentinel := bob.SendEventSynced(t, dmBob, Event{
Type: "m.room.message",
Content: map[string]interface{}{
"body": "Hello world",
"msgtype": "m.text",
},
})
t.Log("Alice syncs until she sees the sentinel. She should not see the DM avatar change.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, func(response *sync3.Response) error {
matchNoAvatarChange := m.MatchRoomSubscription(dmBob, m.MatchRoomUnchangedAvatar())
if err := matchNoAvatarChange(response); err != nil {
t.Fatalf("Saw DM avatar change: %s", err)
}
matchSentinel := m.MatchRoomSubscription(dmBob, MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}}))
return matchSentinel(response)
})
})
t.Run("DM declined", func(t *testing.T) {
t.Log("Chris leaves his DM with Alice.")
chris.LeaveRoom(t, dmChris)
t.Log("Alice syncs until she sees Chris's leave.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "leave")
t.Log("Alice sees Chris's avatar vanish.")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomUnsetAvatar()))
})
t.Run("Group DM declined", func(t *testing.T) {
t.Log("Chris leaves his group DM with Alice and Bob.")
chris.LeaveRoom(t, dmBobChris)
t.Log("Alice syncs until she sees Chris's leave.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmBobChris, chris, "leave")
t.Log("Alice sees the room's avatar change to Bob's avatar.")
// Because this is now a DM room with exactly one other (joined|invited) member.
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBobChris, m.MatchRoomAvatar(bob.AvatarURL)))
})
t.Run("Bob's avatar change propagates", func(t *testing.T) {
t.Log("Bob changes his avatar.")
bobAvatar2 := uploadAvatar(bob, "bob2.png")
bob.SetAvatar(t, bobAvatar2)
avatarChangeInDM := false
avatarChangeInGroupDM := false
t.Log("Alice syncs until she sees Bob's new avatar.")
res = alice.SlidingSyncUntil(
t,
res.Pos,
sync3.Request{},
func(response *sync3.Response) error {
if !avatarChangeInDM {
err := m.MatchRoomSubscription(dmBob, m.MatchRoomAvatar(bob.AvatarURL))(response)
if err == nil {
avatarChangeInDM = true
}
}
if !avatarChangeInGroupDM {
err := m.MatchRoomSubscription(dmBobChris, m.MatchRoomAvatar(bob.AvatarURL))(response)
if err == nil {
avatarChangeInGroupDM = true
}
}
if avatarChangeInDM && avatarChangeInGroupDM {
return nil
}
return fmt.Errorf("still waiting: avatarChangeInDM=%t avatarChangeInGroupDM=%t", avatarChangeInDM, avatarChangeInGroupDM)
},
)
t.Log("Bob removes his avatar.")
bob.SetAvatar(t, "")
avatarChangeInDM = false
avatarChangeInGroupDM = false
t.Log("Alice syncs until she sees Bob's avatars vanish.")
res = alice.SlidingSyncUntil(
t,
res.Pos,
sync3.Request{},
func(response *sync3.Response) error {
if !avatarChangeInDM {
err := m.MatchRoomSubscription(dmBob, m.MatchRoomUnsetAvatar())(response)
if err == nil {
avatarChangeInDM = true
} else {
t.Log(err)
}
}
if !avatarChangeInGroupDM {
err := m.MatchRoomSubscription(dmBobChris, m.MatchRoomUnsetAvatar())(response)
if err == nil {
avatarChangeInGroupDM = true
} else {
t.Log(err)
}
}
if avatarChangeInDM && avatarChangeInGroupDM {
return nil
}
return fmt.Errorf("still waiting: avatarChangeInDM=%t avatarChangeInGroupDM=%t", avatarChangeInDM, avatarChangeInGroupDM)
},
)
})
t.Run("Explicit avatar propagates in non-DM room", func(t *testing.T) {
t.Log("Alice sets an avatar for the public room.")
publicAvatar := uploadAvatar(alice, "public.png")
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{
"url": publicAvatar,
})
t.Log("Alice syncs until she sees that avatar.")
res = alice.SlidingSyncUntil(
t,
res.Pos,
sync3.Request{},
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
public: {m.MatchRoomAvatar(publicAvatar)},
}),
)
t.Log("Alice changes the avatar for the public room.")
publicAvatar2 := uploadAvatar(alice, "public2.png")
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{
"url": publicAvatar2,
})
t.Log("Alice syncs until she sees that avatar.")
res = alice.SlidingSyncUntil(
t,
res.Pos,
sync3.Request{},
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
public: {m.MatchRoomAvatar(publicAvatar2)},
}),
)
t.Log("Alice removes the avatar for the public room.")
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{})
t.Log("Alice syncs until she sees that avatar vanish.")
res = alice.SlidingSyncUntil(
t,
res.Pos,
sync3.Request{},
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
public: {m.MatchRoomUnsetAvatar()},
}),
)
})
t.Run("Explicit avatar propagates in DM room", func(t *testing.T) {
t.Log("Alice re-invites Chris to their DM.")
alice.InviteRoom(t, dmChris, chris.UserID)
t.Log("Alice syncs until she sees her invitation to Chris.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "invite")
t.Log("Alice should see the DM with Chris's avatar.")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)))
t.Log("Chris joins the room.")
chris.JoinRoom(t, dmChris, nil)
t.Log("Alice syncs until she sees Chris's join.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "join")
t.Log("Alice shouldn't see the DM's avatar change..")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomUnchangedAvatar()))
t.Log("Chris gives their DM a bespoke avatar.")
dmAvatar := uploadAvatar(chris, "dm.png")
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{
"url": dmAvatar,
})
t.Log("Alice syncs until she sees that avatar.")
alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(dmAvatar)))
t.Log("Chris changes his global avatar, which adds a join event to the room.")
chrisAvatar2 := uploadAvatar(chris, "chris2.png")
chris.SetAvatar(t, chrisAvatar2)
t.Log("Alice syncs until she sees that join event.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "join")
t.Log("Her response should have either no avatar change, or the same bespoke avatar.")
// No change, ideally, but repeating the same avatar isn't _wrong_
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, func(r sync3.Room) error {
noChangeErr := m.MatchRoomUnchangedAvatar()(r)
sameBespokeAvatarErr := m.MatchRoomAvatar(dmAvatar)(r)
if noChangeErr == nil || sameBespokeAvatarErr == nil {
return nil
}
return fmt.Errorf("expected no change or the same bespoke avatar (%s), got '%s'", dmAvatar, r.AvatarChange)
}))
t.Log("Chris updates the DM's avatar.")
dmAvatar2 := uploadAvatar(chris, "dm2.png")
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{
"url": dmAvatar2,
})
t.Log("Alice syncs until she sees that avatar.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(dmAvatar2)))
t.Log("Chris removes the DM's avatar.")
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{})
t.Log("Alice syncs until the DM avatar returns to Chris's most recent avatar.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)))
})
t.Run("Changing DM flag", func(t *testing.T) {
t.Skip("TODO: unimplemented")
t.Log("Alice clears the DM flag on Bob's room.")
alice.SetGlobalAccountData(t, "m.direct", map[string]interface{}{
"content": map[string][]string{
bob.UserID: {}, // no dmBob here
chris.UserID: {dmChris, dmBobChris},
},
})
t.Log("Alice syncs until she sees a new set of account data.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
}, func(response *sync3.Response) error {
if response.Extensions.AccountData == nil {
return fmt.Errorf("no account data yet")
}
if len(response.Extensions.AccountData.Global) == 0 {
return fmt.Errorf("no global account data yet")
}
return nil
})
t.Log("The DM with Bob should no longer be a DM and should no longer have an avatar.")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBob, func(r sync3.Room) error {
if r.IsDM {
return fmt.Errorf("dmBob is still a DM")
}
return m.MatchRoomUnsetAvatar()(r)
}))
t.Log("Alice sets the DM flag on Bob's room.")
alice.SetGlobalAccountData(t, "m.direct", map[string]interface{}{
"content": map[string][]string{
bob.UserID: {dmBob}, // dmBob reinstated
chris.UserID: {dmChris, dmBobChris},
},
})
t.Log("Alice syncs until she sees a new set of account data.")
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{
Extensions: extensions.Request{
AccountData: &extensions.AccountDataRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
}, func(response *sync3.Response) error {
if response.Extensions.AccountData == nil {
return fmt.Errorf("no account data yet")
}
if len(response.Extensions.AccountData.Global) == 0 {
return fmt.Errorf("no global account data yet")
}
return nil
})
t.Log("The room should have Bob's avatar again.")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBob, func(r sync3.Room) error {
if !r.IsDM {
return fmt.Errorf("dmBob is still not a DM")
}
return m.MatchRoomAvatar(bob.AvatarURL)(r)
}))
})
t.Run("See avatar when invited", func(t *testing.T) {
t.Log("Chris invites Alice to a DM.")
dmInvited := chris.CreateRoom(t, map[string]interface{}{
"preset": "trusted_private_chat",
"is_direct": true,
"invite": []string{alice.UserID},
})
t.Log("Alice syncs until she sees the invite.")
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmInvited, alice, "invite")
t.Log("The new room should use Chris's avatar.")
m.MatchResponse(t, res, m.MatchRoomSubscription(dmInvited, m.MatchRoomAvatar(chris.AvatarURL)))
})
}

View File

@ -6,6 +6,7 @@ import (
"net/url"
"os"
"reflect"
"sort"
"strings"
"sync/atomic"
"testing"
@ -13,6 +14,7 @@ import (
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
"github.com/tidwall/gjson"
)
var (
@ -159,13 +161,41 @@ func MatchRoomInviteState(events []Event, partial bool) m.RoomMatcher {
}
}
if !found {
return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist", want)
return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist or failed to pass equality checks", want)
}
}
return nil
}
}
// MatchGlobalAccountData builds a matcher which asserts that the account data in a sync
// response matches the given `globals`, with any ordering.
// If there is no account data extension in the response, the match fails.
func MatchGlobalAccountData(globals []Event) m.RespMatcher {
// sort want list by type
sort.Slice(globals, func(i, j int) bool {
return globals[i].Type < globals[j].Type
})
return func(res *sync3.Response) error {
if res.Extensions.AccountData == nil {
return fmt.Errorf("MatchGlobalAccountData: no account_data extension")
}
if len(globals) != len(res.Extensions.AccountData.Global) {
return fmt.Errorf("MatchGlobalAccountData: got %v global account data, want %v", len(res.Extensions.AccountData.Global), len(globals))
}
// sort the got list by type
got := res.Extensions.AccountData.Global
sort.Slice(got, func(i, j int) bool {
return gjson.GetBytes(got[i], "type").Str < gjson.GetBytes(got[j], "type").Str
})
if err := eventsEqual(globals, got); err != nil {
return fmt.Errorf("MatchGlobalAccountData: %s", err)
}
return nil
}
}
func registerNewUser(t *testing.T) *CSAPI {
return registerNamedUser(t, "user")
}

View File

@ -64,7 +64,22 @@ func TestRoomStateTransitions(t *testing.T) {
m.MatchRoomHighlightCount(1),
m.MatchRoomInitial(true),
m.MatchRoomRequiredState(nil),
// TODO m.MatchRoomInviteState(inviteStrippedState.InviteState.Events),
m.MatchInviteCount(1),
m.MatchJoinCount(1),
MatchRoomInviteState([]Event{
{
Type: "m.room.create",
StateKey: ptr(""),
// no content as it includes the room version which we don't want to guess/hardcode
},
{
Type: "m.room.join_rules",
StateKey: ptr(""),
Content: map[string]interface{}{
"join_rule": "public",
},
},
}, true),
},
joinRoomID: {},
}),
@ -105,6 +120,8 @@ func TestRoomStateTransitions(t *testing.T) {
},
}),
m.MatchRoomInitial(true),
m.MatchJoinCount(2),
m.MatchInviteCount(0),
m.MatchRoomHighlightCount(0),
))
}
@ -216,17 +233,18 @@ func TestInviteRejection(t *testing.T) {
}
func TestInviteAcceptance(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
alice := registerNamedUser(t, "alice")
bob := registerNamedUser(t, "bob")
// ensure that invite state correctly propagates. One room will already be in 'invite' state
// prior to the first proxy sync, whereas the 2nd will transition.
t.Logf("Alice creates two rooms and invites Bob to the first.")
firstInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "First"})
alice.InviteRoom(t, firstInviteRoomID, bob.UserID)
secondInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "Second"})
t.Logf("first %s second %s", firstInviteRoomID, secondInviteRoomID)
// sync as bob, we should see 1 invite
t.Log("Sync as Bob, requesting invites only. He should see 1 invite")
res := bob.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
@ -256,10 +274,12 @@ func TestInviteAcceptance(t *testing.T) {
},
}))
// now invite bob
t.Log("Alice invites bob to room 2.")
alice.InviteRoom(t, secondInviteRoomID, bob.UserID)
t.Log("Alice syncs until she sees Bob's invite.")
alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "invite")
t.Log("Bob syncs. He should see the invite to room 2 as well.")
res = bob.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
@ -287,13 +307,16 @@ func TestInviteAcceptance(t *testing.T) {
},
}))
// now accept the invites
t.Log("Bob accept the invites.")
bob.JoinRoom(t, firstInviteRoomID, nil)
bob.JoinRoom(t, secondInviteRoomID, nil)
t.Log("Alice syncs until she sees Bob join room 1.")
alice.SlidingSyncUntilMembership(t, "", firstInviteRoomID, bob, "join")
t.Log("Alice syncs until she sees Bob join room 2.")
alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "join")
// the list should be purged
t.Log("Bob does an incremental sync")
res = bob.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
@ -301,12 +324,13 @@ func TestInviteAcceptance(t *testing.T) {
},
},
}, WithPos(res.Pos))
t.Log("Both of his invites should be purged.")
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(
m.MatchV3DeleteOp(1),
m.MatchV3DeleteOp(0),
)))
// fresh sync -> no invites
t.Log("Bob makes a fresh sliding sync request.")
res = bob.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
@ -317,6 +341,7 @@ func TestInviteAcceptance(t *testing.T) {
},
},
})
t.Log("He should see no invites.")
m.MatchResponse(t, res, m.MatchNoV3Ops(), m.MatchRoomSubscriptionsStrict(nil), m.MatchList("a", m.MatchV3Count(0)))
}
@ -467,7 +492,7 @@ func TestMemberCounts(t *testing.T) {
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
secondRoomID: {
m.MatchRoomInitial(false),
m.MatchInviteCount(0),
m.MatchNoInviteCount(),
m.MatchJoinCount(0), // omitempty
},
}))
@ -486,7 +511,7 @@ func TestMemberCounts(t *testing.T) {
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
secondRoomID: {
m.MatchRoomInitial(false),
m.MatchInviteCount(0),
m.MatchNoInviteCount(),
m.MatchJoinCount(2),
},
}))

View File

@ -1,10 +1,13 @@
package syncv3_test
import (
"fmt"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
"github.com/tidwall/gjson"
)
func TestNumLive(t *testing.T) {
@ -126,3 +129,70 @@ func TestNumLive(t *testing.T) {
},
}))
}
// Test that if you constantly change req params, we still see live traffic. It does this by:
// - Creating 11 rooms.
// - Hitting /sync with a range [0,0], then [0,1], then [0,2], and so on. Each time this causes a new room to be returned.
// - Interleaving each /sync request with genuine events sent into a room.
// - Ensuring we see the genuine events by the time we finish.
func TestReqParamStarvation(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
roomID := alice.CreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
numOtherRooms := 10
for i := 0; i < numOtherRooms; i++ {
bob.CreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
}
bob.JoinRoom(t, roomID, nil)
res := bob.SlidingSyncUntilMembership(t, "", roomID, bob, "join")
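// wantEventIDs maps event ID -> whether we are still waiting to see it in a sync response.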
wantEventIDs := make(map[string]bool)
for i := 0; i < numOtherRooms; i++ {
res = bob.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, int64(i)}}, // [0,0], [0,1], ... [0,9]
},
},
}, WithPos(res.Pos))
// mark off any event we see in wantEventIDs
for _, r := range res.Rooms {
for _, ev := range r.Timeline {
gotEventID := gjson.GetBytes(ev, "event_id").Str
wantEventIDs[gotEventID] = false
}
}
// send an event in the first few syncs to add to wantEventIDs
// We do this for the first few /syncs and don't dictate which response they should arrive
// in, as we do not know and cannot force the proxy to deliver the event in a particular response.
if i < 3 {
eventID := alice.SendEventSynced(t, roomID, Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": fmt.Sprintf("msg %d", i),
},
})
wantEventIDs[eventID] = true
}
// It's possible the proxy won't have seen this event before the next /sync,
// in which case its absence is down to timing rather than starvation.
// To try to counter this, sleep a bit. This is why we sleep on every cycle and
// why we send the events early on.
time.Sleep(50 * time.Millisecond)
}
// at this point wantEventIDs should all have false values if we got the events
for evID, unseen := range wantEventIDs {
if unseen {
t.Errorf("failed to see event %v", evID)
}
}
}

View File

@ -100,15 +100,16 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
})
// check Alice sees both events
assertEventsEqual(t, []Event{
{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
kickEvent := Event{
Type: "m.room.member",
StateKey: ptr(eve.UserID),
Content: map[string]interface{}{
"membership": "leave",
},
Sender: alice.UserID,
}
assertEventsEqual(t, []Event{
kickEvent,
{
Type: "m.room.name",
StateKey: ptr(""),
@ -120,7 +121,6 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
},
}, timeline)
kickEvent := timeline[0]
// Ensure Eve doesn't see this message in the timeline, name calc or required_state
eveRes = eve.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
@ -140,7 +140,7 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
}, WithPos(eveRes.Pos))
// the room is deleted from eve's point of view and she sees up to and including her kick event
m.MatchResponse(t, eveRes, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(m.MatchV3DeleteOp(0))), m.MatchRoomSubscription(
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{kickEvent}),
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), MatchRoomTimelineMostRecent(1, []Event{kickEvent}),
))
}

180
tests-e2e/timestamp_test.go Normal file
View File

@ -0,0 +1,180 @@
package syncv3_test
import (
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils/m"
)
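// Test that per-list BumpEventTypes control which events bump a room's timestamp, and that
// newly-joined users see their join timestamp rather than older bumped timestamps.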
func TestTimestamp(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
charlie := registerNewUser(t)
roomID := alice.CreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
var gotTs, expectedTs uint64
lists := map[string]sync3.RequestList{
"myFirstList": {
Ranges: [][2]int64{{0, 1}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
},
BumpEventTypes: []string{"m.room.message"}, // only messages bump the timestamp
},
"mySecondList": {
Ranges: [][2]int64{{0, 1}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
},
BumpEventTypes: []string{"m.reaction"}, // only reactions bump the timestamp
},
}
// Init sync to get the latest timestamp
resAlice := alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 10,
},
},
})
m.MatchResponse(t, resAlice, m.MatchRoomSubscription(roomID, m.MatchRoomInitial(true)))
timestampBeforeBobJoined := resAlice.Rooms[roomID].Timestamp
bob.JoinRoom(t, roomID, nil)
resAlice = alice.SlidingSyncUntilMembership(t, resAlice.Pos, roomID, bob, "join")
resBob := bob.SlidingSync(t, sync3.Request{
Lists: lists,
})
// Bob just joined, so he should see the same timestamp as Alice does now (i.e. the join)
gotTs = resBob.Rooms[roomID].Timestamp
expectedTs = resAlice.Rooms[roomID].Timestamp
if gotTs != expectedTs {
t.Fatalf("expected timestamp to be equal, but got: %v vs %v", gotTs, expectedTs)
}
// ... the timestamp should still differ from what Alice received before the join
if gotTs == timestampBeforeBobJoined {
t.Fatalf("expected timestamp to differ, but got: %v vs %v", gotTs, timestampBeforeBobJoined)
}
// Send an event which should NOT bump Bob's timestamp, because it is not listed in
// any BumpEventTypes
emptyStateKey := ""
eventID := alice.SendEventSynced(t, roomID, Event{
Type: "m.room.topic",
StateKey: &emptyStateKey,
Content: map[string]interface{}{
"topic": "random topic",
},
})
time.Sleep(time.Millisecond)
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
gotTs = resBob.Rooms[roomID].Timestamp
expectedTs = resAlice.Rooms[roomID].Timestamp
if gotTs != expectedTs {
t.Fatalf("expected timestamps to be the same, but they aren't: %v vs %v", gotTs, expectedTs)
}
expectedTs = gotTs
// Now send a message which bumps the timestamp in myFirstList
eventID = alice.SendEventSynced(t, roomID, Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": "Hello, world!",
},
})
time.Sleep(time.Millisecond)
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
gotTs = resBob.Rooms[roomID].Timestamp
if expectedTs == gotTs {
t.Fatalf("expected timestamps to be different, but they aren't: %v vs %v", gotTs, expectedTs)
}
expectedTs = gotTs
// Now send a message which bumps the timestamp in mySecondList
eventID = alice.SendEventSynced(t, roomID, Event{
Type: "m.reaction",
Content: map[string]interface{}{
"m.relates.to": map[string]interface{}{
"event_id": eventID,
"key": "✅",
"rel_type": "m.annotation",
},
},
})
time.Sleep(time.Millisecond)
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
bobTimestampReaction := resBob.Rooms[roomID].Timestamp
if bobTimestampReaction == expectedTs {
t.Fatalf("expected timestamps to be different, but they aren't: %v vs %v", expectedTs, bobTimestampReaction)
}
expectedTs = bobTimestampReaction
// Send another event which should NOT bump Bobs timestamp
eventID = alice.SendEventSynced(t, roomID, Event{
Type: "m.room.name",
StateKey: &emptyStateKey,
Content: map[string]interface{}{
"name": "random name",
},
})
time.Sleep(time.Millisecond)
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
gotTs = resBob.Rooms[roomID].Timestamp
if gotTs != expectedTs {
t.Fatalf("expected timestamps to be the same, but they aren't: %v, expected %v", gotTs, expectedTs)
}
// Bob makes an initial sync again, he should still see the m.reaction timestamp
resBob = bob.SlidingSync(t, sync3.Request{
Lists: lists,
})
gotTs = resBob.Rooms[roomID].Timestamp
expectedTs = bobTimestampReaction
if gotTs != expectedTs {
t.Fatalf("initial sync contains wrong timestamp: %d, expected %d", gotTs, expectedTs)
}
// Charlie joins the room
charlie.JoinRoom(t, roomID, nil)
resAlice = alice.SlidingSyncUntilMembership(t, resAlice.Pos, roomID, charlie, "join")
resCharlie := charlie.SlidingSync(t, sync3.Request{
Lists: lists,
})
// Charlie just joined, so he should see the same timestamp as Alice. Even though
// Charlie requests the same BumpEventTypes as Bob, we don't leak Bob's timestamps to him.
gotTs = resCharlie.Rooms[roomID].Timestamp
expectedTs = resAlice.Rooms[roomID].Timestamp
if gotTs != expectedTs {
t.Fatalf("Charlie should see the timestamp they joined, but didn't: %d, expected %d", gotTs, expectedTs)
}
// Initial sync without bump types should see the most recent timestamp
resAlice = alice.SlidingSync(t, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 10,
},
},
})
// expected TS stays the same, i.e. the timestamp of Charlie's join
gotTs = resAlice.Rooms[roomID].Timestamp
if gotTs != expectedTs {
t.Fatalf("Alice should see the timestamp of Charlie joining, but didn't: %d, expected %d", gotTs, expectedTs)
}
}

View File

@ -32,33 +32,10 @@ func TestTransactionIDsAppear(t *testing.T) {
// we cannot use MatchTimeline here because the Unsigned section contains 'age' which is not
// deterministic and MatchTimeline does not do partial matches.
matchTransactionID := func(eventID, txnID string) m.RoomMatcher {
return func(r sync3.Room) error {
for _, ev := range r.Timeline {
var got Event
if err := json.Unmarshal(ev, &got); err != nil {
return fmt.Errorf("failed to unmarshal event: %s", err)
}
if got.ID != eventID {
continue
}
tx, ok := got.Unsigned["transaction_id"]
if !ok {
return fmt.Errorf("unsigned block for %s has no transaction_id", eventID)
}
gotTxnID := tx.(string)
if gotTxnID != txnID {
return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID)
}
t.Logf("%s has txn ID %s", eventID, gotTxnID)
return nil
}
return fmt.Errorf("not found event %s", eventID)
}
}
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
matchTransactionID(eventID, "foobar"),
matchTransactionID(t, eventID, "foobar"),
},
}))
@ -74,8 +51,85 @@ func TestTransactionIDsAppear(t *testing.T) {
res = client.SlidingSyncUntilEvent(t, res.Pos, sync3.Request{}, roomID, Event{ID: eventID})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
matchTransactionID(eventID, "foobar2"),
matchTransactionID(t, eventID, "foobar2"),
},
}))
}
// This test has 1 poller expecting a txn ID and 10 others that won't see one.
// We test that the sending device sees a txn ID. Without the TxnIDWaiter logic in place,
// this test is likely (but not guaranteed) to fail.
func TestTransactionIDsAppearWithMultiplePollers(t *testing.T) {
alice := registerNamedUser(t, "alice")
t.Log("Alice creates a room and syncs until she sees it.")
roomID := alice.CreateRoom(t, map[string]interface{}{})
res := alice.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
},
Ranges: sync3.SliceRanges{{0, 20}},
},
},
})
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID))
t.Log("Alice makes other devices and starts them syncing.")
for i := 0; i < 10; i++ {
device := *alice
device.Login(t, "password", fmt.Sprintf("device_%d", i))
device.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 10,
},
Ranges: sync3.SliceRanges{{0, 20}},
},
},
})
}
t.Log("Alice sends a message with a transaction ID.")
const txnID = "foobar"
sendRes := alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "rooms", roomID, "send", "m.room.message", txnID},
WithJSONBody(t, map[string]interface{}{
"msgtype": "m.text",
"body": "Hello, world!",
}))
body := ParseJSON(t, sendRes)
eventID := GetJSONFieldStr(t, body, "event_id")
t.Log("Alice syncs on her main devices until she sees her message.")
res = alice.SlidingSyncUntilEventID(t, res.Pos, roomID, eventID)
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, matchTransactionID(t, eventID, txnID)))
}
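// matchTransactionID builds a RoomMatcher asserting that the given event appears in the
// room timeline with the given transaction_id in its unsigned block.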
func matchTransactionID(t *testing.T, eventID, txnID string) m.RoomMatcher {
return func(r sync3.Room) error {
for _, ev := range r.Timeline {
var got Event
if err := json.Unmarshal(ev, &got); err != nil {
return fmt.Errorf("failed to unmarshal event: %s", err)
}
if got.ID != eventID {
continue
}
tx, ok := got.Unsigned["transaction_id"]
if !ok {
return fmt.Errorf("unsigned block for %s has no transaction_id", eventID)
}
gotTxnID := tx.(string)
if gotTxnID != txnID {
return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID)
}
t.Logf("%s has txn ID %s", eventID, gotTxnID)
return nil
}
return fmt.Errorf("not found event %s", eventID)
}
}

View File

@ -622,6 +622,20 @@ func TestSessionExpiryOnBufferFill(t *testing.T) {
if gjson.ParseBytes(body).Get("errcode").Str != "M_UNKNOWN_POS" {
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
}
// make sure we can sync from fresh (regression for when we deadlocked after this point)
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 1,
},
},
})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
m.MatchJoinCount(1),
},
}))
}
func TestExpiredAccessToken(t *testing.T) {

View File

@ -0,0 +1,103 @@
package syncv3
import (
"encoding/json"
"fmt"
"sync"
"testing"
"time"
syncv3 "github.com/matrix-org/sliding-sync"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
)
// Test that the proxy works fine with low max conns. Low max conns can be a problem
// if a single request needs 2 conns to respond but can only obtain 1, blocking
// forward progress on the server.
func TestMaxDBConns(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
opts := syncv3.Opts{
DBMaxConns: 1,
}
v3 := runTestServer(t, v2, pqString, opts)
defer v2.close()
defer v3.close()
testMaxDBConns := func() {
// make N users and drip feed some events, make sure they are all seen
numUsers := 5
var wg sync.WaitGroup
wg.Add(numUsers)
for i := 0; i < numUsers; i++ {
go func(n int) {
defer wg.Done()
userID := fmt.Sprintf("@maxconns_%d:localhost", n)
token := fmt.Sprintf("maxconns_%d", n)
roomID := fmt.Sprintf("!maxconns_%d", n)
v2.addAccount(t, userID, token)
v2.queueResponse(userID, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: createRoomState(t, userID, time.Now()),
}),
},
})
// initial sync
res := v3.mustDoV3Request(t, token, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 1,
},
}},
})
t.Logf("user %s has done an initial /sync OK", userID)
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(
m.MatchV3SyncOp(0, 0, []string{roomID}),
)), m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
m.MatchJoinCount(1),
},
}))
// drip feed and get update
dripMsg := testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{
"msgtype": "m.text",
"body": "drip drip",
})
v2.queueResponse(userID, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{
dripMsg,
},
}),
},
})
t.Logf("user %s has queued the drip", userID)
v2.waitUntilEmpty(t, userID)
t.Logf("user %s poller has received the drip", userID)
res = v3.mustDoV3RequestWithPos(t, token, res.Pos, sync3.Request{})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
roomID: {
m.MatchRoomTimelineMostRecent(1, []json.RawMessage{dripMsg}),
},
}))
t.Logf("user %s has received the drip", userID)
}(i)
}
wg.Wait()
}
testMaxDBConns()
v3.restart(t, v2, pqString, opts)
testMaxDBConns()
}

View File

@ -604,3 +604,171 @@ func TestExtensionLateEnable(t *testing.T) {
},
})
}
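// Test that when multiple pollers see the same room, typing notifications are taken from a
// single poller (whichever starts first), so typing data seen by other pollers is ignored.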
func TestTypingMultiplePoller(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
roomA := "!a:localhost"
v2.addAccountWithDeviceID(alice, "first", aliceToken)
v2.addAccountWithDeviceID(bob, "second", bobToken)
// Create the room state and join with Bob
roomState := createRoomState(t, alice, time.Now())
joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{
"membership": "join",
})
// Queue the response with Alice typing
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)},
},
},
},
},
})
// Queue another response for Bob with Bob typing.
v2.queueResponse(bobToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}},
},
},
},
})
// Start the pollers. Since Alice's poller is started first, it is in
// charge of handling typing notifications for roomA.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{})
// Get the response from v3
for _, token := range []string{aliceToken, bobToken} {
pos := aliceRes.Pos
if token == bobToken {
pos = bobRes.Pos
}
res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{
Extensions: extensions.Request{
Typing: &extensions.TypingRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// We expect only Alice to be typing, as only Alice's poller is "allowed"
// to update typing notifications.
m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice}))
if token == bobToken {
bobRes = res
}
if token == aliceToken {
aliceRes = res
}
}
// Queue the response with Bob typing
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)},
},
},
},
},
})
// Queue another response for Bob with Charlie typing.
// Since Alice's poller is in charge of handling typing notifications, this shouldn't
// show up on future responses.
v2.queueResponse(bobToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@charlie:localhost"]}}`)},
},
},
},
},
})
// Wait for the queued responses to be processed.
v2.waitUntilEmpty(t, aliceToken)
v2.waitUntilEmpty(t, bobToken)
// Check that only Bob is typing and not Charlie.
for _, token := range []string{aliceToken, bobToken} {
pos := aliceRes.Pos
if token == bobToken {
pos = bobRes.Pos
}
res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{
Extensions: extensions.Request{
Typing: &extensions.TypingRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// We expect only Bob to be typing, as only Alice's poller is "allowed"
// to update typing notifications.
m.MatchResponse(t, res, m.MatchTyping(roomA, []string{bob}))
}
}

View File

@ -0,0 +1,188 @@
package syncv3
import (
"encoding/json"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
"testing"
"time"
)
// Test that messages from ignored users are not sent to clients, even if they appear
// on someone else's poller first.
func TestIgnoredUsersDuringLiveUpdate(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
const nigel = "@nigel:localhost"
roomID := "!unimportant"
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
// Bob creates a room. Nigel and Alice join.
state := createRoomState(t, bob, time.Now())
state = append(state, testutils.NewStateEvent(t, "m.room.member", nigel, nigel, map[string]interface{}{
"membership": "join",
}))
aliceJoin := testutils.NewStateEvent(t, "m.room.member", alice, alice, map[string]interface{}{
"membership": "join",
})
t.Log("Alice and Bob's pollers sees Alice's join.")
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: state,
events: []json.RawMessage{aliceJoin},
}),
},
NextBatch: "alice_sync_1",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: state,
events: []json.RawMessage{aliceJoin},
}),
},
NextBatch: "bob_sync_1",
})
t.Log("Alice sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 10}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
},
},
})
t.Log("She should see her join.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceJoin})))
t.Log("Bob sliding syncs.")
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 10}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
},
},
})
t.Log("He should see Alice's join.")
m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceJoin})))
t.Log("Alice ignores Nigel.")
v2.queueResponse(alice, sync2.SyncResponse{
AccountData: sync2.EventsResponse{
Events: []json.RawMessage{
testutils.NewAccountData(t, "m.ignored_user_list", map[string]any{
"ignored_users": map[string]any{
nigel: map[string]any{},
},
}),
},
},
NextBatch: "alice_sync_2",
})
v2.waitUntilEmpty(t, alice)
t.Log("Bob's poller sees a message from Nigel, then a message from Alice.")
nigelMsg := testutils.NewMessageEvent(t, nigel, "naughty nigel")
aliceMsg := testutils.NewMessageEvent(t, alice, "angelic alice")
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelMsg, aliceMsg},
}),
},
NextBatch: "bob_sync_2",
})
v2.waitUntilEmpty(t, bob)
t.Log("Bob syncs. He should see both messages.")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{nigelMsg, aliceMsg})))
t.Log("Alice syncs. She should see her message, but not Nigel's.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceMsg})))
t.Log("Bob's poller sees Nigel set a custom state event")
nigelState := testutils.NewStateEvent(t, "com.example.fruit", "banana", nigel, map[string]any{})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelState},
}),
},
NextBatch: "bob_sync_3",
})
v2.waitUntilEmpty(t, bob)
t.Log("Alice syncs. She should see Nigel's state event.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{nigelState})))
t.Log("Bob's poller sees Alice send a message.")
aliceMsg2 := testutils.NewMessageEvent(t, alice, "angelic alice 2")
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{aliceMsg2},
}),
},
NextBatch: "bob_sync_4",
})
v2.waitUntilEmpty(t, bob)
t.Log("Alice syncs, making a new conn with a direct room subscription.")
aliceRes = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{},
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 20,
},
},
})
t.Log("Alice sees her join, her messages and nigel's state in the timeline.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{
aliceJoin, aliceMsg, nigelState, aliceMsg2,
})))
t.Log("Bob's poller sees Nigel and Alice send a message.")
nigelMsg2 := testutils.NewMessageEvent(t, nigel, "naughty nigel 3")
aliceMsg3 := testutils.NewMessageEvent(t, alice, "angelic alice 3")
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelMsg2, aliceMsg3},
}),
},
NextBatch: "bob_sync_5",
})
v2.waitUntilEmpty(t, bob)
t.Log("Alice syncs. She should only see her message.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceMsg3})))
}

View File

@ -0,0 +1,230 @@
package syncv3
import (
"encoding/json"
"testing"
"time"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
)
// catch-all file for any kind of regression test which doesn't fall into a unique category
// Regression test for https://github.com/matrix-org/sliding-sync/issues/192
// - Bob on his server invites Alice to a room.
// - Alice joins the room first over federation. The proxy does the right thing and sets her membership to join. There is no timeline, though, because nothing has been backfilled yet.
// - Alice's client backfills in the room, which pulls in the invite event, but the SS proxy doesn't see it because it arrives via backfill, not /sync.
// - Charlie joins the same room via SS, which makes the SS proxy see 50 timeline events, including the invite.
// As the proxy has never seen this invite event before, it assumes it is newer than the join event and inserts it, corrupting state.
//
// Manually confirmed this can happen with 3x Element clients. We need to make sure we drop those earlier events.
// The first join over federation presents itself as a single join event in the timeline, with the create event etc. in state.
func TestBackfillInviteDoesntCorruptState(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
fedBob := "@bob:over_federation"
charlie := "@charlie:localhost"
charlieToken := "CHARLIE_TOKEN"
joinEvent := testutils.NewJoinEvent(t, alice)
room := roomEvents{
roomID: "!TestBackfillInviteDoesntCorruptState:localhost",
events: []json.RawMessage{
joinEvent,
},
state: createRoomState(t, fedBob, time.Now()),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})
// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 5,
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))))
// Alice's client "backfills" new data in, meaning the next user who joins is going to see a different set of timeline events
dummyMsg := testutils.NewMessageEvent(t, fedBob, "you didn't see this before joining")
charlieJoinEvent := testutils.NewJoinEvent(t, charlie)
backfilledTimelineEvents := append(
room.state, []json.RawMessage{
dummyMsg,
testutils.NewStateEvent(t, "m.room.member", alice, fedBob, map[string]interface{}{
"membership": "invite",
}),
joinEvent,
charlieJoinEvent,
}...,
)
// now charlie also joins the room, causing a different response from /sync v2
v2.addAccount(t, charlie, charlieToken)
v2.queueResponse(charlie, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: room.roomID,
events: backfilledTimelineEvents,
}),
},
})
// and now charlie hits SS, which might corrupt membership state for alice.
charlieRes := v3.mustDoV3Request(t, charlieToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
},
},
})
m.MatchResponse(t, charlieRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))))
// alice should not see dummyMsg or the invite
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchNoV3Ops(), m.LogResponse(t), m.MatchRoomSubscriptionsStrict(
map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(3), // alice, bob, charlie,
m.MatchNoInviteCount(),
m.MatchNumLive(1),
m.MatchRoomTimeline([]json.RawMessage{charlieJoinEvent}),
},
},
))
}
func TestMalformedEventsTimeline(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
// unusual events ARE VALID EVENTS and should be sent to the client, but are unusual for some reason.
unusualEvents := []json.RawMessage{
testutils.NewStateEvent(t, "", "", alice, map[string]interface{}{
"empty string": "for event type",
}),
}
// malformed events are INVALID and should be ignored by the proxy.
malformedEvents := []json.RawMessage{
json.RawMessage(`{}`), // empty object
json.RawMessage(`{"type":5}`), // type is an integer
json.RawMessage(`{"type":"foo","content":{},"event_id":""}`), // 0-length string as event ID
json.RawMessage(`{"type":"foo","content":{}}`), // missing event ID
}
room := roomEvents{
roomID: "!TestMalformedEventsTimeline:localhost",
// append malformed after unusual. All malformed events should be dropped,
// leaving only unusualEvents.
events: append(unusualEvents, malformedEvents...),
state: createRoomState(t, alice, time.Now()),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})
// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: int64(len(unusualEvents)),
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))),
m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(1),
m.MatchRoomTimeline(unusualEvents),
},
}))
}
func TestMalformedEventsState(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()
// unusual events ARE VALID EVENTS and should be sent to the client, but are unusual for some reason.
unusualEvents := []json.RawMessage{
testutils.NewStateEvent(t, "", "", alice, map[string]interface{}{
"empty string": "for event type",
}),
}
// malformed events are INVALID and should be ignored by the proxy.
malformedEvents := []json.RawMessage{
json.RawMessage(`{}`), // empty object
json.RawMessage(`{"type":5,"content":{},"event_id":"f","state_key":""}`), // type is an integer
json.RawMessage(`{"type":"foo","content":{},"event_id":"","state_key":""}`), // 0-length string as event ID
json.RawMessage(`{"type":"foo","content":{},"state_key":""}`), // missing event ID
}
latestEvent := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"})
room := roomEvents{
roomID: "!TestMalformedEventsState:localhost",
events: []json.RawMessage{latestEvent},
// append malformed after unusual. All malformed events should be dropped,
// leaving only unusualEvents.
state: append(createRoomState(t, alice, time.Now()), append(unusualEvents, malformedEvents...)...),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})
// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: int64(len(unusualEvents)),
RequiredState: [][2]string{{"", ""}},
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))),
m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(1),
m.MatchRoomTimeline([]json.RawMessage{latestEvent}),
m.MatchRoomRequiredState([]json.RawMessage{
unusualEvents[0],
}),
},
}))
}

View File

@ -137,12 +137,9 @@ func TestRoomSubscriptionMisorderedTimeline(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
// TODO: this is the correct result, but due to how timeline loading works currently
// it will be returning the last 5 events BEFORE D,E, which isn't ideal but also isn't
// incorrect per se due to the fact that clients don't know when D,E have been processed
// on the server.
// m.MatchRoomTimeline(append(abcInitialEvents, deLiveEvents...)),
m.MatchRoomTimeline(append(roomState[len(roomState)-2:], abcInitialEvents...)),
// we append live events AFTER processing the new timeline limit, so 7 events not 5.
// TODO: ideally we'd just return abcde here.
m.MatchRoomTimeline(append(roomState[len(roomState)-2:], append(abcInitialEvents, deLiveEvents...)...)),
},
}), m.LogResponse(t))

View File

@ -3,6 +3,7 @@ package syncv3
import (
"encoding/json"
"fmt"
slidingsync "github.com/matrix-org/sliding-sync"
"testing"
"time"
@ -689,6 +690,247 @@ func TestTimelineTxnID(t *testing.T) {
))
}
// TestTimelineTxnID (above) checks that Alice sees her transaction_id if
// - Bob's poller sees Alice's event,
// - Alice's poller sees Alice's event with txn_id, and
// - Alice syncs.
//
// TestTimelineTxnIDBuffersForTxnID is similar but not identical. It checks that Alice sees her transaction_id if
// - Bob's poller sees Alice's event,
// - Alice does an incremental sync, which should omit her event,
// - Alice's poller sees Alice's event with txn_id, and
// - Alice syncs, seeing her event with txn_id.
func TestTimelineTxnIDBuffersForTxnID(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
// This needs to be greater than the request timeout, which is hardcoded to a
// minimum of 100ms in connStateLive.liveUpdate. This ensures that the
// liveUpdate call finishes before the TxnIDWaiter publishes the update,
// meaning that Alice doesn't see her event before the txn ID is known.
MaxTransactionIDDelay: 200 * time.Millisecond,
})
defer v2.close()
defer v3.close()
roomID := "!a:localhost"
latestTimestamp := time.Now()
t.Log("Alice and Bob are in the same room")
room := roomEvents{
roomID: roomID,
events: append(
createRoomState(t, alice, latestTimestamp),
testutils.NewJoinEvent(t, bob),
),
}
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "alice_after_initial_poll",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "bob_after_initial_poll",
})
t.Log("Alice and Bob make initial sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
t.Log("Alice has sent a message... but it arrives down Bob's poller first, without a transaction_id")
txnID := "m1234567890"
newEvent := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"}, testutils.WithUnsigned(map[string]interface{}{
"transaction_id": txnID,
}))
newEventNoUnsigned, err := sjson.DeleteBytes(newEvent, "unsigned")
if err != nil {
t.Fatalf("failed to delete bytes: %s", err)
}
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoUnsigned},
}),
},
})
t.Log("Bob's poller sees the message.")
v2.waitUntilEmpty(t, bob)
t.Log("Bob makes an incremental sliding sync")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
t.Log("Bob should see the message without a transaction_id")
m.MatchResponse(t, bobRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoUnsigned}),
))
t.Log("Alice requests an incremental sliding sync with no request changes.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice should see no messages.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscriptionsStrict(nil))
// Now the message arrives down Alice's poller.
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEvent},
}),
},
})
t.Log("Alice's poller sees the message with transaction_id.")
v2.waitUntilEmpty(t, alice)
t.Log("Alice makes another incremental sync request.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice's sync response includes the message with the txn ID.")
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEvent}),
))
}
// TestTimelineTxnIDRespectsAllClear is similar to TestTimelineTxnIDBuffersForTxnID. It checks that:
// - Bob's poller sees Alice's event,
// - Alice does an incremental sync, which should omit her event,
// - Alice's poller sees Alice's event without a txn_id, and
// - Alice syncs, seeing her event without txn_id.
// I.e. we're checking that the "all clear" empties out the buffer of events.
func TestTimelineTxnIDRespectsAllClear(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
// This needs to be greater than the request timeout, which is hardcoded to a
// minimum of 100ms in connStateLive.liveUpdate. This ensures that the
// liveUpdate call finishes before the TxnIDWaiter publishes the update,
// meaning that Alice doesn't see her event before the txn ID is known.
MaxTransactionIDDelay: 200 * time.Millisecond,
})
defer v2.close()
defer v3.close()
roomID := "!a:localhost"
latestTimestamp := time.Now()
t.Log("Alice and Bob are in the same room")
room := roomEvents{
roomID: roomID,
events: append(
createRoomState(t, alice, latestTimestamp),
testutils.NewJoinEvent(t, bob),
),
}
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "alice_after_initial_poll",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "bob_after_initial_poll",
})
t.Log("Alice and Bob make initial sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
t.Log("Alice has sent a message... but it arrives down Bob's poller first, without a transaction_id")
newEventNoTxn := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoTxn},
}),
},
})
t.Log("Bob's poller sees the message.")
v2.waitUntilEmpty(t, bob)
t.Log("Bob makes an incremental sliding sync")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
t.Log("Bob should see the message without a transaction_id")
m.MatchResponse(t, bobRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoTxn}),
))
t.Log("Alice requests an incremental sliding sync with no request changes.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice should see no messages.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscriptionsStrict(nil))
// Now the message arrives down Alice's poller.
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoTxn},
}),
},
})
t.Log("Alice's poller sees the message without transaction_id.")
v2.waitUntilEmpty(t, alice)
t.Log("Alice makes another incremental sync request.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice's sync response includes the event without a txn ID.")
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoTxn}),
))
}
// Executes a sync v3 request without a ?pos and asserts that the count, rooms and timeline events match the inputs given.
func testTimelineLoadInitialEvents(v3 *testV3Server, token string, count int, wantRooms []roomEvents, numTimelineEventsPerRoom int) func(t *testing.T) {
return func(t *testing.T) {

View File

@ -3,11 +3,12 @@ package syncv3
import (
"context"
"encoding/json"
"github.com/tidwall/gjson"
"net/http"
"testing"
"time"
"github.com/tidwall/gjson"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"

View File

@ -291,11 +291,11 @@ func (s *testV3Server) close() {
s.h2.Teardown()
}
func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string) {
func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string, opts ...syncv3.Opts) {
t.Helper()
log.Printf("restarting server")
s.close()
ss := runTestServer(t, v2, pq)
ss := runTestServer(t, v2, pq, opts...)
// replace all the fields which will be close()d to ensure we don't leak
s.srv = ss.srv
s.h2 = ss.h2
@ -366,20 +366,24 @@ func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postg
// tests often repeat requests. To ensure tests remain fast, reduce the spam protection limits.
sync3.SpamProtectionInterval = time.Millisecond
metricsEnabled := false
maxPendingEventUpdates := 200
combinedOpts := syncv3.Opts{
TestingSynchronousPubsub: true, // critical to avoid flakey tests
AddPrometheusMetrics: false,
MaxPendingEventUpdates: 200,
MaxTransactionIDDelay: 0, // disable the txnID buffering to avoid flakey tests
}
if len(opts) > 0 {
metricsEnabled = opts[0].AddPrometheusMetrics
if opts[0].MaxPendingEventUpdates > 0 {
maxPendingEventUpdates = opts[0].MaxPendingEventUpdates
opt := opts[0]
combinedOpts.AddPrometheusMetrics = opt.AddPrometheusMetrics
combinedOpts.DBConnMaxIdleTime = opt.DBConnMaxIdleTime
combinedOpts.DBMaxConns = opt.DBMaxConns
combinedOpts.MaxTransactionIDDelay = opt.MaxTransactionIDDelay
if opt.MaxPendingEventUpdates > 0 {
combinedOpts.MaxPendingEventUpdates = opt.MaxPendingEventUpdates
handler.BufferWaitTime = 5 * time.Millisecond
}
}
h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{
TestingSynchronousPubsub: true, // critical to avoid flakey tests
MaxPendingEventUpdates: maxPendingEventUpdates,
AddPrometheusMetrics: metricsEnabled,
})
h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), combinedOpts)
// for ease of use we don't start v2 pollers at startup in tests
r := mux.NewRouter()
r.Use(hlog.NewHandler(logger))

View File

@ -39,6 +39,39 @@ func MatchRoomName(name string) RoomMatcher {
}
}
// MatchRoomAvatar builds a RoomMatcher which checks that the given room response has
// set the room's avatar to the given value.
func MatchRoomAvatar(wantAvatar string) RoomMatcher {
return func(r sync3.Room) error {
if string(r.AvatarChange) != wantAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, wantAvatar)
}
return nil
}
}
// MatchRoomUnsetAvatar builds a RoomMatcher which checks that the given room has no
// avatar, or has had its avatar deleted.
func MatchRoomUnsetAvatar() RoomMatcher {
return func(r sync3.Room) error {
if r.AvatarChange != sync3.DeletedAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.DeletedAvatar)
}
return nil
}
}
// MatchRoomUnchangedAvatar builds a RoomMatcher which checks that the given room response
// reports no change to the room's avatar.
func MatchRoomUnchangedAvatar() RoomMatcher {
return func(r sync3.Room) error {
if r.AvatarChange != sync3.UnchangedAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.UnchangedAvatar)
}
return nil
}
}
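// Illustrative sketch (not part of this diff) of how a test might use the avatar
// matchers above; roomID, res and the mxc URI are hypothetical:
//
//	m.MatchResponse(t, res, m.MatchRoomSubscription(roomID,
//		m.MatchRoomAvatar("mxc://example.org/avatar"),
//	))
//	// and, after a later sync in which the avatar did not change:
//	m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, m.MatchRoomUnchangedAvatar()))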
func MatchJoinCount(count int) RoomMatcher {
return func(r sync3.Room) error {
if r.JoinedCount != count {
@ -48,10 +81,22 @@ func MatchJoinCount(count int) RoomMatcher {
}
}
func MatchNoInviteCount() RoomMatcher {
return func(r sync3.Room) error {
if r.InvitedCount != nil {
return fmt.Errorf("MatchInviteCount: invited_count is present when it should be missing: val=%v", *r.InvitedCount)
}
return nil
}
}
func MatchInviteCount(count int) RoomMatcher {
return func(r sync3.Room) error {
if r.InvitedCount != count {
return fmt.Errorf("MatchInviteCount: got %v want %v", r.InvitedCount, count)
if r.InvitedCount == nil {
return fmt.Errorf("MatchInviteCount: invited_count is missing")
}
if *r.InvitedCount != count {
return fmt.Errorf("MatchInviteCount: got %v want %v", *r.InvitedCount, count)
}
return nil
}
@ -210,11 +255,11 @@ func MatchRoomSubscription(roomID string, matchers ...RoomMatcher) RespMatcher {
return func(res *sync3.Response) error {
room, ok := res.Rooms[roomID]
if !ok {
return fmt.Errorf("MatchRoomSubscription: want sub for %s but it was missing", roomID)
return fmt.Errorf("MatchRoomSubscription[%s]: want sub but it was missing", roomID)
}
for _, m := range matchers {
if err := m(room); err != nil {
return fmt.Errorf("MatchRoomSubscription: %s", err)
return fmt.Errorf("MatchRoomSubscription[%s]: %s", roomID, err)
}
}
return nil
@ -633,6 +678,15 @@ func LogResponse(t *testing.T) RespMatcher {
}
}
// LogRooms is like LogResponse, but only logs the rooms section of the response.
func LogRooms(t *testing.T) RespMatcher {
return func(res *sync3.Response) error {
dump, _ := json.MarshalIndent(res.Rooms, "", " ")
t.Logf("Response rooms were: %s", dump)
return nil
}
}
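// Illustrative sketch (not part of this diff): a test can chain LogRooms before
// other matchers to dump just the rooms section of the response while debugging;
// roomID and res are hypothetical:
//
//	m.MatchResponse(t, res, m.LogRooms(t), m.MatchRoomSubscription(roomID, m.MatchJoinCount(2)))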
func CheckList(listKey string, res sync3.ResponseList, matchers ...ListMatcher) error {
for _, m := range matchers {
if err := m(res); err != nil {
@ -675,13 +729,16 @@ func MatchLists(matchers map[string][]ListMatcher) RespMatcher {
}
}
const AnsiRedForeground = "\x1b[31m"
const AnsiResetForeground = "\x1b[39m"
func MatchResponse(t *testing.T, res *sync3.Response, matchers ...RespMatcher) {
t.Helper()
for _, m := range matchers {
err := m(res)
if err != nil {
b, _ := json.Marshal(res)
t.Errorf("MatchResponse: %s\n%+v", err, string(b))
b, _ := json.MarshalIndent(res, "", " ")
t.Errorf("%vMatchResponse: %s\n%s%v", AnsiRedForeground, err, string(b), AnsiResetForeground)
}
}
}

v3.go
View File

@ -1,6 +1,7 @@
package slidingsync
import (
"embed"
"encoding/json"
"fmt"
"net/http"
@ -9,18 +10,24 @@ import (
"time"
"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync2/handler2"
"github.com/matrix-org/sliding-sync/sync3/handler"
"github.com/pressly/goose/v3"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
_ "github.com/matrix-org/sliding-sync/state/migrations"
)
//go:embed state/migrations/*
var EmbedMigrations embed.FS
var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
@ -36,6 +43,13 @@ type Opts struct {
// if true, publishing messages will block until the consumer has consumed it.
// Assumes a single producer and a single consumer.
TestingSynchronousPubsub bool
// MaxTransactionIDDelay is the longest amount of time that we will wait for
// confirmation of an event's transaction_id before sending it to its sender.
// Set to 0 to disable this delay mechanism entirely.
MaxTransactionIDDelay time.Duration
DBMaxConns int
DBConnMaxIdleTime time.Duration
}
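// A minimal sketch (illustrative, not part of this diff) of how a caller might set
// the new options; the concrete values here are assumptions:
//
//	opts := Opts{
//		MaxPendingEventUpdates: 2000,
//		MaxTransactionIDDelay:  200 * time.Millisecond, // wait up to 200ms for the sender's txn_id
//		DBMaxConns:             2,                       // cap the pool; Setup matches idle conns to this
//		DBConnMaxIdleTime:      time.Hour,               // recycle idle connections
//	}
//	h2, h3 := Setup(destHomeserver, postgresURI, secret, opts)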
type server struct {
@ -73,11 +87,38 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
},
DestinationServer: destHomeserver,
}
store := state.NewStorage(postgresURI)
storev2 := sync2.NewStore(postgresURI, secret)
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
sentry.CaptureException(err)
// TODO: if we panic(), will sentry have a chance to flush the event?
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
if opts.DBMaxConns > 0 {
// https://github.com/go-sql-driver/mysql#important-settings
// "db.SetMaxIdleConns() is recommended to be set same to db.SetMaxOpenConns(). When it is smaller
// than SetMaxOpenConns(), connections can be opened and closed much more frequently than you expect."
db.SetMaxOpenConns(opts.DBMaxConns)
db.SetMaxIdleConns(opts.DBMaxConns)
}
if opts.DBConnMaxIdleTime > 0 {
db.SetConnMaxIdleTime(opts.DBConnMaxIdleTime)
}
store := state.NewStorageWithDB(db)
storev2 := sync2.NewStoreWithDB(db, secret)
// Automatically execute migrations
goose.SetBaseFS(EmbedMigrations)
err = goose.Up(db.DB, "state/migrations", goose.WithAllowMissing())
if err != nil {
logger.Panic().Err(err).Msg("failed to execute migrations")
}
bufferSize := 50
deviceDataUpdateFrequency := time.Second
if opts.TestingSynchronousPubsub {
bufferSize = 0
deviceDataUpdateFrequency = 0 // don't batch
}
if opts.MaxPendingEventUpdates == 0 {
opts.MaxPendingEventUpdates = 2000
@ -86,14 +127,14 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
// create v2 handler
h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
h2, err := handler2.NewHandler(pMap, storev2, store, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
if err != nil {
panic(err)
}
pMap.SetCallbacks(h2)
// create v3 handler
h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates, opts.MaxTransactionIDDelay)
if err != nil {
panic(err)
}