Mirror of https://github.com/matrix-org/sliding-sync.git, synced 2025-03-10 13:37:11 +00:00

Commit a51230d852: Merge remote-tracking branch 'origin/main' into dmr/extension-scoping-fix

.github/workflows/docker.yml (vendored): 2 changes
@@ -26,7 +26,7 @@ jobs:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v1
        uses: docker/setup-qemu-action@v2
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Login to GitHub Containers
.github/workflows/tests.yml (vendored): 137 changes
@@ -59,7 +59,7 @@ jobs:
      - name: Test
        run: |
          set -euo pipefail
          go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt
          go test -count=1 -covermode=atomic -coverpkg ./... -p 1 -v -json $(go list ./... | grep -v tests-e2e) -coverprofile synccoverage.out 2>&1 | tee ./test-integration.log | gotestfmt -hide all
        shell: bash
        env:
          POSTGRES_HOST: localhost
@@ -88,6 +88,10 @@ jobs:
          if-no-files-found: error
  end_to_end:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # test with unlimited + 1 + 2 max db conns. If we end up double transacting in the tests anywhere, conn=1 tests will fail.
        max_db_conns: [0,1,2]
    services:
      synapse:
        # Custom image built from https://github.com/matrix-org/synapse/tree/v1.72.0/docker/complement with a dummy /complement/ca set
@@ -97,6 +101,11 @@ jobs:
          SERVER_NAME: synapse
        ports:
          - 8008:8008
        # Set health checks to wait until synapse has started
        options: >-
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      # Label used to access the service container
      postgres:
        # Docker Hub image
@@ -135,13 +144,14 @@ jobs:
      - name: Run end-to-end tests
        run: |
          set -euo pipefail
          ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt
          ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt -hide all
        working-directory: tests-e2e
        shell: bash
        env:
          SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
          SYNCV3_SERVER: http://localhost:8008
          SYNCV3_SECRET: itsasecret
          SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
          E2E_TEST_SERVER_STDOUT: test-e2e-server.log

      - name: Upload test log
@@ -164,11 +174,14 @@ jobs:
      - uses: actions/checkout@v3
        with:
          repository: matrix-org/matrix-react-sdk
          ref: "v3.71.0" # later versions break the SS E2E tests which need to be fixed :(
      - uses: actions/setup-node@v3
        with:
          cache: 'yarn'
      - name: Fetch layered build
        run: scripts/ci/layered.sh
        env:
          JS_SDK_GITHUB_BASE_REF: "v25.0.0-rc.1"
      - name: Copy config
        run: cp element.io/develop/config.json config.json
        working-directory: ./element-web
@@ -195,3 +208,123 @@ jobs:
          path: |
            cypress/screenshots
            cypress/videos


  upgrade-test:
    runs-on: ubuntu-latest

    env:
      PREV_VERSION: "v0.99.4"

    # Service containers to run with `container-job`
    services:
      synapse:
        # Custom image built from https://github.com/matrix-org/synapse/tree/v1.72.0/docker/complement with a dummy /complement/ca set
        image: ghcr.io/matrix-org/synapse-service:v1.72.0
        env:
          SYNAPSE_COMPLEMENT_DATABASE: sqlite
          SERVER_NAME: synapse
        ports:
          - 8008:8008
        # Set health checks to wait until synapse has started
        options: >-
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      # Label used to access the service container
      postgres:
        # Docker Hub image
        image: postgres:13-alpine
        # Provide the password for postgres
        env:
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: syncv3
        ports:
          # Maps tcp port 5432 on service container to the host
          - 5432:5432
        # Set health checks to wait until postgres has started
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Install Go
        uses: actions/setup-go@v4
        with:
          go-version: 1.19

      # E2E tests with ${{env.PREV_VERSION}}
      - uses: actions/checkout@v3
        with:
          ref: ${{env.PREV_VERSION}}

      - name: Set up gotestfmt
        uses: GoTestTools/gotestfmt-action@v2
        with:
          # Note: constrained to `packages:read` only at the top of the file
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Build ${{env.PREV_VERSION}}
        run: go build ./cmd/syncv3

      - name: Run end-to-end tests
        run: |
          set -euo pipefail
          ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner-${{env.PREV_VERSION}}.log | gotestfmt -hide all
        working-directory: tests-e2e
        shell: bash
        env:
          SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
          SYNCV3_SERVER: http://localhost:8008
          SYNCV3_SECRET: itsasecret
          SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
          E2E_TEST_SERVER_STDOUT: test-e2e-server-${{env.PREV_VERSION}}.log

      - name: Upload test log ${{env.PREV_VERSION}}
        uses: actions/upload-artifact@v3
        if: failure()
        with:
          name: E2E test logs upgrade ${{env.PREV_VERSION}}
          path: |
            ./tests-e2e/test-e2e-runner-${{env.PREV_VERSION}}.log
            ./tests-e2e/test-e2e-server-${{env.PREV_VERSION}}.log
          if-no-files-found: error


      # E2E tests with current commit
      - uses: actions/checkout@v3

      - name: Set up gotestfmt
        uses: GoTestTools/gotestfmt-action@v2
        with:
          # Note: constrained to `packages:read` only at the top of the file
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Build
        run: go build ./cmd/syncv3

      - name: Run end-to-end tests
        run: |
          set -euo pipefail
          ./run-tests.sh -count=1 -v -json . 2>&1 | tee test-e2e-runner.log | gotestfmt -hide all
        working-directory: tests-e2e
        shell: bash
        env:
          SYNCV3_DB: user=postgres dbname=syncv3 sslmode=disable password=postgres host=localhost
          SYNCV3_SERVER: http://localhost:8008
          SYNCV3_SECRET: itsasecret
          SYNCV3_MAX_DB_CONN: ${{ matrix.max_db_conns }}
          E2E_TEST_SERVER_STDOUT: test-e2e-server.log

      - name: Upload test log
        uses: actions/upload-artifact@v3
        if: failure()
        with:
          name: E2E test logs upgrade
          path: |
            ./tests-e2e/test-e2e-runner.log
            ./tests-e2e/test-e2e-server.log
          if-no-files-found: error
README.md: 93 changes
@@ -2,7 +2,11 @@

Run a sliding sync proxy. An implementation of [MSC3575](https://github.com/matrix-org/matrix-doc/blob/kegan/sync-v3/proposals/3575-sync.md).

Proxy version to MSC API specification:
## Proxy version to MSC API specification

This describes which proxy versions implement which version of the API drafted
in MSC3575. See https://github.com/matrix-org/sliding-sync/releases for the
changes in the proxy itself.

- Version 0.1.x: [2022/04/01](https://github.com/matrix-org/matrix-spec-proposals/blob/615e8f5a7bfe4da813bc2db661ed0bd00bccac20/proposals/3575-sync.md)
  - First release
@@ -21,29 +25,102 @@ Proxy version to MSC API specification:
  - Support for `errcode` when sessions expire.
- Version 0.99.1 [2023/01/20](https://github.com/matrix-org/matrix-spec-proposals/blob/b4b4e7ff306920d2c862c6ff4d245110f6fa5bc7/proposals/3575-sync.md)
  - Preparing for major v1.x release: lists-as-keys support.
- Version 0.99.2 [2024/07/27](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
- Version 0.99.2 [2023/03/31](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
  - Experimental support for `bump_event_types` when ordering rooms by recency.
  - Support for opting in to extensions on a per-list and per-room basis.
  - Sentry support.
- Version 0.99.3 [2023/05/23](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
  - Support for per-list `bump_event_types`.
  - Support for [`conn_id`](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md#concurrent-connections) for distinguishing multiple concurrent connections.
- Version 0.99.4 [2023/07/12](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
  - Support for `SYNCV3_MAX_DB_CONN`, and reduce the amount of concurrent connections required during normal operation.
  - Add more metrics and logs. Reduce log spam.
  - Improve performance when handling changed device lists.
  - Responses will consume from the live buffer even when clients change their request parameters to more speedily send new events down.
  - Bugfix: return `invited_count` correctly when it transitions to 0.
  - Bugfix: fix a data corruption bug when 2 users join a federated room where the first user was invited to said room.

## Usage

### Setup
Requires Postgres 13+.

First, you must create a Postgres database and secret:
```bash
$ createdb syncv3
$ echo -n "$(openssl rand -hex 32)" > .secret # this MUST remain the same throughout the lifetime of the database created above.
```

Compiling from source and running:
The Sliding Sync proxy requires some environment variables set to function. They are described when the proxy is run with missing variables.

Here is a short description of each, as of writing:
```
SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org'
SYNCV3_DB Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
SYNCV3_SECRET Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database.
SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on.
SYNCV3_TLS_CERT Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address.
SYNCV3_TLS_KEY Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file.
SYNCV3_PPROF Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen.
SYNCV3_PROM Default: unset. The bind addr for Prometheus metrics, which will be accessible at /metrics at this address.
SYNCV3_JAEGER_URL Default: unset. The Jaeger URL to send spans to e.g http://localhost:14268/api/traces - if unset does not send OTLP traces.
SYNCV3_SENTRY_DSN Default: unset. The Sentry DSN to report events to e.g https://sliding-sync@sentry.example.com/123 - if unset does not send sentry events.
SYNCV3_LOG_LEVEL Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
SYNCV3_MAX_DB_CONN Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
```

It is easiest to host the proxy on a separate hostname than the Matrix server, though it is possible to use the same hostname by forwarding the used endpoints.

In both cases, the path `https://example.com/.well-known/matrix/client` must return a JSON with at least the following contents:
```json
{
    "m.server": {
        "base_url": "https://example.com"
    },
    "m.homeserver": {
        "base_url": "https://example.com"
    },
    "org.matrix.msc3575.proxy": {
        "url": "https://syncv3.example.com"
    }
}
```

#### Same hostname
The following nginx configuration can be used to pass the required endpoints to the sync proxy, running on local port 8009 (so as to not conflict with Synapse):
```nginx
location ~ ^/(client/|_matrix/client/unstable/org.matrix.msc3575/sync) {
    proxy_pass http://localhost:8009;
    proxy_set_header X-Forwarded-For $remote_addr;
    proxy_set_header X-Forwarded-Proto $scheme;
    proxy_set_header Host $host;
}

location ~ ^(\/_matrix|\/_synapse\/client) {
    proxy_pass http://localhost:8008;
    proxy_set_header X-Forwarded-For $remote_addr;
    proxy_set_header X-Forwarded-Proto $scheme;
    proxy_set_header Host $host;
}

location /.well-known/matrix/client {
    add_header Access-Control-Allow-Origin *;
}
```

### Running
There are two ways to run the proxy:
- Compiling from source:
```
$ go build ./cmd/syncv3
$ SYNCV3_SECRET=$(cat .secret) SYNCV3_SERVER="https://matrix-client.matrix.org" SYNCV3_DB="user=$(whoami) dbname=syncv3 sslmode=disable" SYNCV3_BINDADDR=0.0.0.0:8008 ./syncv3
$ SYNCV3_SECRET=$(cat .secret) SYNCV3_SERVER="https://matrix-client.matrix.org" SYNCV3_DB="user=$(whoami) dbname=syncv3 sslmode=disable password='DATABASE_PASSWORD_HERE'" SYNCV3_BINDADDR=0.0.0.0:8008 ./syncv3
```
Using a Docker image:

- Using a Docker image:
```
docker run --rm -e "SYNCV3_SERVER=https://matrix-client.matrix.org" -e "SYNCV3_SECRET=$(cat .secret)" -e "SYNCV3_BINDADDR=:8008" -e "SYNCV3_DB=user=$(whoami) dbname=syncv3 sslmode=disable host=host.docker.internal" -p 8008:8008 ghcr.io/matrix-org/sliding-sync:latest
docker run --rm -e "SYNCV3_SERVER=https://matrix-client.matrix.org" -e "SYNCV3_SECRET=$(cat .secret)" -e "SYNCV3_BINDADDR=:8008" -e "SYNCV3_DB=user=$(whoami) dbname=syncv3 sslmode=disable host=host.docker.internal password='DATABASE_PASSWORD_HERE'" -p 8008:8008 ghcr.io/matrix-org/sliding-sync:latest
```


Optionally also set `SYNCV3_TLS_CERT=path/to/cert.pem` and `SYNCV3_TLS_KEY=path/to/key.pem` to listen on HTTPS instead of HTTP.
Make sure to tweak the `SYNCV3_DB` environment variable if the Postgres database isn't running on the host.

@@ -157,4 +234,4 @@ Run end-to-end tests:
# to ghcr and pull the image.
docker run --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.72.0
(go build ./cmd/syncv3 && dropdb syncv3_test && createdb syncv3_test && cd tests-e2e && ./run-tests.sh -count=1 .)
```
```
@@ -1,27 +1,36 @@
package main

import (
    "flag"
    "fmt"
    "log"
    "net/http"
    _ "net/http/pprof"
    "os"
    "os/signal"
    "strconv"
    "strings"
    "syscall"
    "time"

    "github.com/getsentry/sentry-go"
    sentryhttp "github.com/getsentry/sentry-go/http"
    syncv3 "github.com/matrix-org/sliding-sync"
    "github.com/matrix-org/sliding-sync/internal"
    "github.com/matrix-org/sliding-sync/sync2"
    "github.com/pressly/goose/v3"
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "github.com/rs/zerolog"
    "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
    "net/http"
    _ "net/http/pprof"
    "os"
    "os/signal"
    "strings"
    "syscall"
    "time"
)

var GitCommit string

const version = "0.99.2"
const version = "0.99.5"

var (
    flags = flag.NewFlagSet("goose", flag.ExitOnError)
)

const (
    // Required fields
@@ -39,6 +48,7 @@ const (
    EnvJaeger = "SYNCV3_JAEGER_URL"
    EnvSentryDsn = "SYNCV3_SENTRY_DSN"
    EnvLogLevel = "SYNCV3_LOG_LEVEL"
    EnvMaxConns = "SYNCV3_MAX_DB_CONN"
)

var helpMsg = fmt.Sprintf(`
@@ -54,7 +64,8 @@ Environment var
%s Default: unset. The Jaeger URL to send spans to e.g http://localhost:14268/api/traces - if unset does not send OTLP traces.
%s Default: unset. The Sentry DSN to report events to e.g https://sliding-sync@sentry.example.com/123 - if unset does not send sentry events.
%s Default: info. The level of verbosity for messages logged. Available values are trace, debug, info, warn, error and fatal
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvJaeger, EnvSentryDsn, EnvLogLevel)
%s Default: unset. Max database connections to use when communicating with postgres. Unset or 0 means no limit.
`, EnvServer, EnvDB, EnvSecret, EnvBindAddr, EnvTLSCert, EnvTLSKey, EnvPPROF, EnvPrometheus, EnvJaeger, EnvSentryDsn, EnvLogLevel, EnvMaxConns)

func defaulting(in, dft string) string {
    if in == "" {
@@ -67,6 +78,12 @@ func main() {
    fmt.Printf("Sync v3 [%s] (%s)\n", version, GitCommit)
    sync2.ProxyVersion = version
    syncv3.Version = fmt.Sprintf("%s (%s)", version, GitCommit)

    if len(os.Args) > 1 && os.Args[1] == "migrate" {
        executeMigrations()
        return
    }

    args := map[string]string{
        EnvServer: os.Getenv(EnvServer),
        EnvDB: os.Getenv(EnvDB),
@@ -80,6 +97,7 @@ func main() {
        EnvJaeger: os.Getenv(EnvJaeger),
        EnvSentryDsn: os.Getenv(EnvSentryDsn),
        EnvLogLevel: os.Getenv(EnvLogLevel),
        EnvMaxConns: defaulting(os.Getenv(EnvMaxConns), "0"),
    }
    requiredEnvVars := []string{EnvServer, EnvDB, EnvSecret, EnvBindAddr}
    for _, requiredEnvVar := range requiredEnvVars {
@@ -135,6 +153,8 @@ func main() {
        }
    }

    fmt.Printf("Debug=%v LogLevel=%v MaxConns=%v\n", args[EnvDebug] == "1", args[EnvLogLevel], args[EnvMaxConns])

    if args[EnvDebug] == "1" {
        zerolog.SetGlobalLevel(zerolog.TraceLevel)
    } else {
@@ -161,8 +181,15 @@ func main() {
        panic(err)
    }

    maxConnsInt, err := strconv.Atoi(args[EnvMaxConns])
    if err != nil {
        panic("invalid value for " + EnvMaxConns + ": " + args[EnvMaxConns])
    }
    h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{
        AddPrometheusMetrics: args[EnvPrometheus] != "",
        AddPrometheusMetrics: args[EnvPrometheus] != "",
        DBMaxConns: maxConnsInt,
        DBConnMaxIdleTime: time.Hour,
        MaxTransactionIDDelay: time.Second,
    })

    go h2.StartV2Pollers()
@@ -203,3 +230,49 @@ func WaitForShutdown(sentryInUse bool) {

    fmt.Printf("Exiting now")
}

func executeMigrations() {
    envArgs := map[string]string{
        EnvDB: os.Getenv(EnvDB),
    }
    requiredEnvVars := []string{EnvDB}
    for _, requiredEnvVar := range requiredEnvVars {
        if envArgs[requiredEnvVar] == "" {
            fmt.Print(helpMsg)
            fmt.Printf("\n%s is not set", requiredEnvVar)
            fmt.Printf("\n%s must be set\n", strings.Join(requiredEnvVars, ", "))
            os.Exit(1)
        }
    }

    flags.Parse(os.Args[1:])
    args := flags.Args()

    if len(args) < 2 {
        flags.Usage()
        return
    }

    command := args[1]

    db, err := goose.OpenDBWithDriver("postgres", envArgs[EnvDB])
    if err != nil {
        log.Fatalf("goose: failed to open DB: %v\n", err)
    }

    defer func() {
        if err := db.Close(); err != nil {
            log.Fatalf("goose: failed to close DB: %v\n", err)
        }
    }()

    arguments := []string{}
    if len(args) > 2 {
        arguments = append(arguments, args[2:]...)
    }

    goose.SetBaseFS(syncv3.EmbedMigrations)
    if err := goose.Run(command, db, "state/migrations", arguments...); err != nil {
        log.Fatalf("goose %v: %v", command, err)
    }
}
go.mod: 34 changes
@@ -7,44 +7,46 @@ require (
    github.com/getsentry/sentry-go v0.20.0
    github.com/go-logr/zerologr v1.2.3
    github.com/gorilla/mux v1.8.0
    github.com/hashicorp/golang-lru v0.5.4
    github.com/jmoiron/sqlx v1.3.3
    github.com/lib/pq v1.10.1
    github.com/lib/pq v1.10.9
    github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab
    github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
    github.com/pressly/goose/v3 v3.14.0
    github.com/prometheus/client_golang v1.13.0
    github.com/rs/zerolog v1.29.0
    github.com/tidwall/gjson v1.14.3
    github.com/tidwall/sjson v1.2.5
    go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0
    go.opentelemetry.io/otel v1.13.0
    go.opentelemetry.io/otel/exporters/jaeger v1.13.0
    go.opentelemetry.io/otel/sdk v1.13.0
    go.opentelemetry.io/otel/trace v1.13.0
    go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0
    go.opentelemetry.io/otel v1.16.0
    go.opentelemetry.io/otel/exporters/jaeger v1.16.0
    go.opentelemetry.io/otel/sdk v1.16.0
    go.opentelemetry.io/otel/trace v1.16.0
)

require (
    github.com/beorn7/perks v1.0.1 // indirect
    github.com/cespare/xxhash/v2 v2.1.2 // indirect
    github.com/felixge/httpsnoop v1.0.3 // indirect
    github.com/go-logr/logr v1.2.3 // indirect
    github.com/fxamacker/cbor/v2 v2.5.0 // indirect
    github.com/go-logr/logr v1.2.4 // indirect
    github.com/go-logr/stdr v1.2.2 // indirect
    github.com/golang/protobuf v1.5.2 // indirect
    github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 // indirect
    github.com/mattn/go-colorable v0.1.13 // indirect
    github.com/mattn/go-isatty v0.0.17 // indirect
    github.com/mattn/go-isatty v0.0.19 // indirect
    github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
    github.com/prometheus/client_model v0.2.0 // indirect
    github.com/prometheus/common v0.37.0 // indirect
    github.com/prometheus/procfs v0.8.0 // indirect
    github.com/prometheus/procfs v0.11.0 // indirect
    github.com/rs/xid v1.4.0 // indirect
    github.com/sirupsen/logrus v1.9.0 // indirect
    github.com/sirupsen/logrus v1.9.3 // indirect
    github.com/tidwall/match v1.1.1 // indirect
    github.com/tidwall/pretty v1.2.0 // indirect
    go.opentelemetry.io/otel/metric v0.36.0 // indirect
    golang.org/x/crypto v0.7.0 // indirect
    golang.org/x/sync v0.0.0-20220907140024-f12130a52804 // indirect
    golang.org/x/sys v0.6.0 // indirect
    golang.org/x/text v0.8.0 // indirect
    github.com/x448/float16 v0.8.4 // indirect
    go.opentelemetry.io/otel/metric v1.16.0 // indirect
    golang.org/x/crypto v0.10.0 // indirect
    golang.org/x/sync v0.2.0 // indirect
    golang.org/x/sys v0.10.0 // indirect
    golang.org/x/text v0.11.0 // indirect
    google.golang.org/protobuf v1.29.1 // indirect
)
go.sum: 91 changes
@@ -57,12 +57,15 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE=
github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo=
github.com/getsentry/sentry-go v0.20.0 h1:bwXW98iMRIWxn+4FgPW7vMrjmbym6HblXALmhjHmQaQ=
github.com/getsentry/sentry-go v0.20.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
@@ -78,14 +81,14 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zerologr v1.2.3 h1:up5N9vcH9Xck3jJkXzgyOxozT14R47IyDODz8LM1KSs=
github.com/go-logr/zerologr v1.2.3/go.mod h1:BxwGo7y5zgSHYR1BjbnHPyF/5ZjVKfKxAZANVu6E8Ho=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@@ -140,6 +143,7 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
@@ -147,8 +151,6 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jmoiron/sqlx v1.3.3 h1:j82X0bf7oQ27XeqxicSZsTU5suPwKElg3oyxNn43iTk=
github.com/jmoiron/sqlx v1.3.3/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=
@@ -161,6 +163,7 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -169,8 +172,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.1 h1:6VXZrLU0jHBYyAqrSPa+MgPfnSvTPuMgK+k0o5kVFWo=
github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8=
@@ -182,8 +185,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
@@ -203,6 +206,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose/v3 v3.14.0 h1:gNrFLLDF+fujdq394rcdYK3WPxp3VKWifTajlZwInJM=
github.com/pressly/goose/v3 v3.14.0/go.mod h1:uwSpREK867PbIsdE9GS6pRk1LUPB7gwMkmvk9/hbIMA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
@@ -226,8 +231,9 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/prometheus/procfs v0.11.0 h1:5EAgkfkMl659uZPbe9AS2N68a7Cc1TJbPEuGzFuRbyk=
github.com/prometheus/procfs v0.11.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
@@ -236,8 +242,8 @@ github.com/rs/zerolog v1.29.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6us
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
@@ -245,7 +251,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -255,6 +261,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -264,18 +272,18 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0 h1:vFEBG7SieZJzvnRWQ81jxpuEqe6J8Ex+hgc9CqOTzHc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.39.0/go.mod h1:9rgTcOKdIhDOC0IcAu8a+R+FChqSUBihKpM1lVNi6T0=
go.opentelemetry.io/otel v1.13.0 h1:1ZAKnNQKwBBxFtww/GwxNUyTf0AxkZzrukO8MeXqe4Y=
go.opentelemetry.io/otel v1.13.0/go.mod h1:FH3RtdZCzRkJYFTCsAKDy9l/XYjMdNv6QrkFFB8DvVg=
go.opentelemetry.io/otel/exporters/jaeger v1.13.0 h1:VAMoGujbVV8Q0JNM/cEbhzUIWWBxnEqH45HP9iBKN04=
go.opentelemetry.io/otel/exporters/jaeger v1.13.0/go.mod h1:fHwbmle6mBFJA1p2ZIhilvffCdq/dM5UTIiCOmEjS+w=
go.opentelemetry.io/otel/metric v0.36.0 h1:t0lgGI+L68QWt3QtOIlqM9gXoxqxWLhZ3R/e5oOAY0Q=
go.opentelemetry.io/otel/metric v0.36.0/go.mod h1:wKVw57sd2HdSZAzyfOM9gTqqE8v7CbqWsYL6AyrH9qk=
go.opentelemetry.io/otel/sdk v1.13.0 h1:BHib5g8MvdqS65yo2vV1s6Le42Hm6rrw08qU6yz5JaM=
go.opentelemetry.io/otel/sdk v1.13.0/go.mod h1:YLKPx5+6Vx/o1TCUYYs+bpymtkmazOMT6zoRrC7AQ7I=
go.opentelemetry.io/otel/trace v1.13.0 h1:CBgRZ6ntv+Amuj1jDsMhZtlAPT6gbyIRdaIzFhfBSdY=
go.opentelemetry.io/otel/trace v1.13.0/go.mod h1:muCvmmO9KKpvuXSf3KKAXXB2ygNYHQ+ZfI5X08d3tds=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8=
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
go.opentelemetry.io/otel/exporters/jaeger v1.16.0 h1:YhxxmXZ011C0aDZKoNw+juVWAmEfv/0W2XBOv9aHTaA=
go.opentelemetry.io/otel/exporters/jaeger v1.16.0/go.mod h1:grYbBo/5afWlPpdPZYhyn78Bk04hnvxn2+hvxQhKIQM=
go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo=
go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4=
go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE=
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -284,8 +292,8 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -318,7 +326,7 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -351,7 +359,7 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -370,8 +378,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220907140024-f12130a52804 h1:0SH2R3f1b1VmIMG7BXbEZCBUu2dKmHschSmjqGUrW8A=
golang.org/x/sync v0.0.0-20220907140024-f12130a52804/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -415,8 +423,9 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -426,8 +435,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -473,7 +482,7 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -577,6 +586,16 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
lukechampine.com/uint128 v1.3.0 h1:cDdUVfRwDUDovz610ABgFD17nXD4/uDgVHl2sC3+sbo=
modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q=
modernc.org/ccgo/v3 v3.16.14 h1:af6KNtFgsVmnDYrWk3PQCS9XT6BXe7o3ZFJKkIKvXNQ=
modernc.org/libc v1.24.1 h1:uvJSeCKL/AgzBo2yYIPPTy82v21KgGnizcGYfBHaNuM=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/memory v1.6.0 h1:i6mzavxrE9a30whzMfwf7XWVODx2r5OYXvU46cirX7o=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/sqlite v1.24.0 h1:EsClRIWHGhLTCX44p+Ri/JLD+vFGo0QGjasg2/F9TlI=
modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
@@ -2,6 +2,8 @@ package internal

import (
    "context"
    "fmt"

    "github.com/getsentry/sentry-go"

    "github.com/rs/zerolog"
@@ -16,6 +18,9 @@ var (
// logging metadata for a single request
type data struct {
    userID string
    deviceID string
    bufferSummary string
    connID string
    since int64
    next int64
    numRooms int
@@ -24,6 +29,9 @@ type data struct {
    numGlobalAccountData int
    numChangedDevices int
    numLeftDevices int
    numLists int
    roomSubs int
    roomUnsubs int
}

// prepare a request context so it can contain syncv3 info
@@ -37,13 +45,14 @@ func RequestContext(ctx context.Context) context.Context {
}

// add the user ID to this request context. Need to have called RequestContext first.
func SetRequestContextUserID(ctx context.Context, userID string) {
func SetRequestContextUserID(ctx context.Context, userID, deviceID string) {
    d := ctx.Value(ctxData)
    if d == nil {
        return
    }
    da := d.(*data)
    da.userID = userID
    da.deviceID = deviceID
    if hub := sentry.GetHubFromContext(ctx); hub != nil {
        sentry.ConfigureScope(func(scope *sentry.Scope) {
            scope.SetUser(sentry.User{Username: userID})
@@ -51,9 +60,18 @@ func SetRequestContextUserID(ctx context.Context, userID string) {
    }
}

func SetConnBufferInfo(ctx context.Context, bufferLen, nextLen, bufferCap int) {
    d := ctx.Value(ctxData)
    if d == nil {
        return
    }
    da := d.(*data)
    da.bufferSummary = fmt.Sprintf("%d/%d/%d", bufferLen, nextLen, bufferCap)
}

func SetRequestContextResponseInfo(
    ctx context.Context, since, next int64, numRooms int, txnID string, numToDeviceEvents, numGlobalAccountData int,
    numChangedDevices, numLeftDevices int,
    numChangedDevices, numLeftDevices int, connID string, numLists int, roomSubs, roomUnsubs int,
) {
    d := ctx.Value(ctxData)
    if d == nil {
@@ -68,6 +86,10 @@ func SetRequestContextResponseInfo(
    da.numGlobalAccountData = numGlobalAccountData
    da.numChangedDevices = numChangedDevices
    da.numLeftDevices = numLeftDevices
    da.connID = connID
    da.numLists = numLists
    da.roomSubs = roomSubs
    da.roomUnsubs = roomUnsubs
}

func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
@@ -79,6 +101,9 @@ func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
    if da.userID != "" {
        l = l.Str("u", da.userID)
    }
    if da.deviceID != "" {
        l = l.Str("dev", da.deviceID)
    }
    if da.since >= 0 {
        l = l.Int64("p", da.since)
    }
@@ -103,5 +128,19 @@ func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event {
    if da.numLeftDevices > 0 {
        l = l.Int("dl-l", da.numLeftDevices)
    }
    if da.bufferSummary != "" {
        l = l.Str("b", da.bufferSummary)
    }
    if da.roomSubs > 0 {
        l = l.Int("sub", da.roomSubs)
    }
    if da.roomUnsubs > 0 {
        l = l.Int("usub", da.roomUnsubs)
    }
    if da.numLists > 0 {
        l = l.Int("l", da.numLists)
    }
    // always log the connection ID so we know when it isn't set
    l = l.Str("c", da.connID)
    return l
}
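The request-context changes above thread the device ID, connection ID, list and subscription counts, and the poller buffer summary through to a single decorated log line. A rough sketch of how these helpers compose; the handler, logger wiring and the concrete values passed in are invented for illustration, and only the `internal` functions shown in the diff are real:

```go
package example

import (
    "net/http"
    "os"

    "github.com/matrix-org/sliding-sync/internal"
    "github.com/rs/zerolog"
)

// handleSync is a made-up handler showing the intended call order of the
// context helpers changed in this diff.
func handleSync(w http.ResponseWriter, req *http.Request) {
    ctx := internal.RequestContext(req.Context())
    // userID and deviceID now travel together (logged as "u" and "dev").
    internal.SetRequestContextUserID(ctx, "@alice:example.com", "DEVICE123")
    // bufferLen/nextLen/bufferCap end up in the log line as "b" = "3/1/50".
    internal.SetConnBufferInfo(ctx, 3, 1, 50)
    // ... handle the request, then record response stats, including the new
    // connID, list count and room sub/unsub counts (all values invented) ...
    internal.SetRequestContextResponseInfo(ctx, 10, 11, 2, "txn1", 0, 1, 0, 0, "conn-abc", 1, 2, 0)
    logger := zerolog.New(os.Stdout)
    internal.DecorateLogger(ctx, logger.Info()).Msg("request completed")
}
```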
@@ -23,7 +23,7 @@ func isBitSet(n int, bit int) bool {
type DeviceData struct {
    // Contains the latest device_one_time_keys_count values.
    // Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
    OTKCounts map[string]int `json:"otk"`
    OTKCounts MapStringInt `json:"otk"`
    // Contains the latest device_unused_fallback_key_types value
    // Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
    FallbackKeyTypes []string `json:"fallback"`
@@ -1,5 +1,10 @@
package internal

import (
    "database/sql/driver"
    "encoding/json"
)

const (
    DeviceListChanged = 1
    DeviceListLeft = 2
@@ -7,8 +12,19 @@ const (

type DeviceLists struct {
    // map user_id -> DeviceList enum
    New map[string]int `json:"n"`
    Sent map[string]int `json:"s"`
    New MapStringInt `json:"n"`
    Sent MapStringInt `json:"s"`
}

type MapStringInt map[string]int

// Value implements driver.Valuer
func (dl MapStringInt) Value() (driver.Value, error) {
    if len(dl) == 0 {
        return "{}", nil
    }
    v, err := json.Marshal(dl)
    return v, err
}

func (dl DeviceLists) Combine(newer DeviceLists) DeviceLists {
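Because `MapStringInt` now satisfies `driver.Valuer`, `database/sql` (and therefore sqlx) will serialise these maps to JSON whenever they are bound as query parameters, which is what lets `OTKCounts`, `New` and `Sent` switch to the new type. A sketch of a matching read side; this `Scan` method is not part of the diff and the proxy may read the column differently, so treat it purely as an illustration of the pattern:

```go
package internal

import (
    "encoding/json"
    "fmt"
)

// Scan is a hypothetical sql.Scanner counterpart to Value(): it would let a
// JSON/JSONB column be read straight back into a MapStringInt.
func (dl *MapStringInt) Scan(src interface{}) error {
    if src == nil {
        *dl = nil
        return nil
    }
    b, ok := src.([]byte)
    if !ok {
        return fmt.Errorf("MapStringInt.Scan: expected []byte, got %T", src)
    }
    return json.Unmarshal(b, dl)
}
```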
@@ -71,10 +71,18 @@ func ExpiredSessionError() *HandlerError {
// Which then produces:
//
// assertion failed: list is not empty
func Assert(msg string, expr bool) {
//
// An optional debugContext map can be provided. If it is present and sentry is configured,
// it is added as context to the sentry events generated for failed assertions.
func Assert(msg string, expr bool, debugContext ...map[string]interface{}) {
    assert(msg, expr)
    if !expr {
        sentry.CaptureException(fmt.Errorf("assertion failed: %s", msg))
        sentry.WithScope(func(scope *sentry.Scope) {
            if len(debugContext) > 0 {
                scope.SetContext(SentryCtxKey, debugContext[0])
            }
            sentry.CaptureException(fmt.Errorf("assertion failed: %s", msg))
        })
    }
}
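With the variadic `debugContext` parameter, call sites can attach arbitrary key/value context to the Sentry event raised when an assertion fails, without changing existing callers. A hypothetical call site; the surrounding function and the map keys are invented for illustration:

```go
package example

import "github.com/matrix-org/sliding-sync/internal"

// checkList is an invented helper showing how a caller might pass debug
// context along with an assertion.
func checkList(userID string, list []string) {
    internal.Assert("list is not empty", len(list) > 0, map[string]interface{}{
        "user_id":   userID,
        "list_size": len(list),
    })
}
```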
67
internal/pool.go
Normal file
67
internal/pool.go
Normal file
@ -0,0 +1,67 @@
package internal

type WorkerPool struct {
	N  int
	ch chan func()
}

// Create a new worker pool of size N. Up to N work can be done concurrently.
// The size of N depends on the expected frequency of work and contention for
// shared resources. Large values of N allow more frequent work at the cost of
// more contention for shared resources like cpu, memory and fds. Small values
// of N allow less frequent work but control the amount of shared resource contention.
// Ideally this value will be derived from whatever shared resource constraints you
// are hitting up against, rather than set to a fixed value. For example, if you have
// a database connection limit of 100, then setting N to some fraction of the limit is
// preferred to setting this to an arbitrary number < 100. If more than N work is requested,
// eventually WorkerPool.Queue will block until some work is done.
//
// The larger N is, the larger the up front memory costs are due to the implementation of WorkerPool.
func NewWorkerPool(n int) *WorkerPool {
	return &WorkerPool{
		N: n,
		// If we have N workers, we can process N work concurrently.
		// If we have >N work, we need to apply backpressure to stop us
		// making more and more work which takes up more and more memory.
		// By setting the channel size to N, we ensure that backpressure is
		// being applied on the producer, stopping it from creating more work,
		// and hence bounding memory consumption. Work is still being produced
		// upstream on the homeserver, but we will consume it when we're ready
		// rather than gobble it all at once.
		//
		// Note: we aren't forced to set this to N, it just serves as a useful
		// metric which scales on the number of workers. The amount of in-flight
		// work is N, so it makes sense to allow up to N work to be queued up before
		// applying backpressure. If the channel buffer is < N then the channel can
		// become the bottleneck in the case where we have lots of instantaneous work
		// to do. If the channel buffer is too large, we needlessly consume memory as
		// make() will allocate a backing array of whatever size you give it up front (sad face)
		ch: make(chan func(), n),
	}
}

// Start the workers. Only call this once.
func (wp *WorkerPool) Start() {
	for i := 0; i < wp.N; i++ {
		go wp.worker()
	}
}

// Stop the worker pool. Only really useful for tests as a worker pool should be started once
// and persist for the lifetime of the process, else it causes needless goroutine churn.
// Only call this once.
func (wp *WorkerPool) Stop() {
	close(wp.ch)
}

// Queue some work on the pool. May or may not block until some work is processed.
func (wp *WorkerPool) Queue(fn func()) {
	wp.ch <- fn
}

// worker impl
func (wp *WorkerPool) worker() {
	for fn := range wp.ch {
		fn()
	}
}
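A minimal usage sketch of the pool above (illustrative only; the pool size and the body of each job are hypothetical). Because the channel buffer is n, Queue blocks once n jobs are running and another n are buffered, which is the backpressure described in NewWorkerPool:

// Hypothetical example; would live inside package internal (e.g. as an example test).
// Assumes "sync" is imported.
func ExampleWorkerPool() {
	wp := NewWorkerPool(32) // e.g. some fraction of a 100-connection DB limit
	wp.Start()
	defer wp.Stop()

	var wg sync.WaitGroup
	for i := 0; i < 1000; i++ {
		wg.Add(1)
		wp.Queue(func() {
			defer wg.Done()
			// ... one unit of work, e.g. a single DB write ...
		})
	}
	wg.Wait()
}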
|
186
internal/pool_test.go
Normal file
@ -0,0 +1,186 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test basic functions of WorkerPool
|
||||
func TestWorkerPool(t *testing.T) {
|
||||
wp := NewWorkerPool(2)
|
||||
wp.Start()
|
||||
defer wp.Stop()
|
||||
|
||||
// we should process this concurrently as N=2 so it should take 1s not 2s
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
start := time.Now()
|
||||
wp.Queue(func() {
|
||||
time.Sleep(time.Second)
|
||||
wg.Done()
|
||||
})
|
||||
wp.Queue(func() {
|
||||
time.Sleep(time.Second)
|
||||
wg.Done()
|
||||
})
|
||||
wg.Wait()
|
||||
took := time.Since(start)
|
||||
if took > 2*time.Second {
|
||||
t.Fatalf("took %v for queued work, it should have been faster than 2s", took)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) {
|
||||
wp := NewWorkerPool(2)
|
||||
|
||||
// return channel to use to see when work is done
|
||||
ch := make(chan int, 2)
|
||||
wp.Queue(func() {
|
||||
ch <- 1
|
||||
})
|
||||
wp.Queue(func() {
|
||||
ch <- 2
|
||||
})
|
||||
|
||||
// the work should not be done yet
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if len(ch) > 0 {
|
||||
t.Fatalf("Queued work was done before Start()")
|
||||
}
|
||||
|
||||
// the work should be starting now
|
||||
wp.Start()
|
||||
defer wp.Stop()
|
||||
|
||||
sum := 0
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timed out waiting for work to be done")
|
||||
case val := <-ch:
|
||||
sum += val
|
||||
}
|
||||
if sum == 3 { // 2 + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type workerState struct {
|
||||
id int
|
||||
state int // not running, queued, running, finished
|
||||
unblock *sync.WaitGroup // decrement to unblock this worker
|
||||
}
|
||||
|
||||
func TestWorkerPoolBackpressure(t *testing.T) {
|
||||
// this test assumes backpressure starts at n*2+1 due to a chan buffer of size n, and n in-flight work.
|
||||
n := 2
|
||||
wp := NewWorkerPool(n)
|
||||
wp.Start()
|
||||
defer wp.Stop()
|
||||
|
||||
var mu sync.Mutex
|
||||
stateNotRunning := 0
|
||||
stateQueued := 1
|
||||
stateRunning := 2
|
||||
stateFinished := 3
|
||||
size := (2 * n) + 1
|
||||
running := make([]*workerState, size)
|
||||
|
||||
go func() {
|
||||
// we test backpressure by scheduling (n*2)+1 work and ensuring that we see the following running states:
|
||||
// [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE
|
||||
// [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued
|
||||
// [3,3,2,2,1] <-- 2 finished, 2 running , 1 queued
|
||||
// [3,3,3,2,2] <-- 3 finished, 2 running
|
||||
for i := 0; i < size; i++ {
|
||||
// set initial state of this piece of work
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
state := &workerState{
|
||||
id: i,
|
||||
state: stateNotRunning,
|
||||
unblock: wg,
|
||||
}
|
||||
mu.Lock()
|
||||
running[i] = state
|
||||
mu.Unlock()
|
||||
|
||||
// queue the work on the pool. The final piece of work will block here and remain in
|
||||
// stateNotRunning and not transition to stateQueued until the first piece of work is done.
|
||||
wp.Queue(func() {
|
||||
mu.Lock()
|
||||
if running[state.id].state != stateQueued {
|
||||
// we ran work in the worker faster than the code underneath .Queue, so let it catch up
|
||||
mu.Unlock()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
mu.Lock()
|
||||
}
|
||||
running[state.id].state = stateRunning
|
||||
mu.Unlock()
|
||||
|
||||
running[state.id].unblock.Wait()
|
||||
mu.Lock()
|
||||
running[state.id].state = stateFinished
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// mark this work as queued
|
||||
mu.Lock()
|
||||
running[i].state = stateQueued
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for the workers to be doing work and assert the states of each task
|
||||
time.Sleep(time.Second)
|
||||
|
||||
assertStates(t, &mu, running, []int{
|
||||
stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning,
|
||||
})
|
||||
|
||||
// now let the first task complete
|
||||
running[0].unblock.Done()
|
||||
// wait for the pool to grab more work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// assert new states
|
||||
assertStates(t, &mu, running, []int{
|
||||
stateFinished, stateRunning, stateRunning, stateQueued, stateQueued,
|
||||
})
|
||||
|
||||
// now let the second task complete
|
||||
running[1].unblock.Done()
|
||||
// wait for the pool to grab more work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// assert new states
|
||||
assertStates(t, &mu, running, []int{
|
||||
stateFinished, stateFinished, stateRunning, stateRunning, stateQueued,
|
||||
})
|
||||
|
||||
// now let the third task complete
|
||||
running[2].unblock.Done()
|
||||
// wait for the pool to grab more work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// assert new states
|
||||
assertStates(t, &mu, running, []int{
|
||||
stateFinished, stateFinished, stateFinished, stateRunning, stateRunning,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) {
|
||||
t.Helper()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(running) != len(wantStates) {
|
||||
t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running))
|
||||
}
|
||||
for i := range running {
|
||||
state := running[i]
|
||||
wantVal := wantStates[i]
|
||||
if state.state != wantVal {
|
||||
t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal)
|
||||
}
|
||||
}
|
||||
}
|
@ -13,7 +13,12 @@ type EventMetadata struct {
|
||||
Timestamp uint64
|
||||
}
|
||||
|
||||
// RoomMetadata holds room-scoped data. It is primarily used in two places:
|
||||
// RoomMetadata holds room-scoped data.
|
||||
// TODO: This is a lie: we sometimes remove a user U from the list of heroes
|
||||
// when calculating the sync response for that user U. Grep for `RemoveHero`.
|
||||
//
|
||||
// It is primarily used in two places:
|
||||
//
|
||||
// - in the caches.GlobalCache, to hold the latest version of data that is consistent
|
||||
// between all users in the room; and
|
||||
// - in the sync3.RoomConnMetadata struct, to hold the version of data last seen by
|
||||
@ -25,6 +30,7 @@ type RoomMetadata struct {
|
||||
RoomID string
|
||||
Heroes []Hero
|
||||
NameEvent string // the content of m.room.name, NOT the calculated name
|
||||
AvatarEvent string // the content of m.room.avatar, NOT the resolved avatar
|
||||
CanonicalAlias string
|
||||
JoinCount int
|
||||
InviteCount int
|
||||
@ -54,6 +60,32 @@ func NewRoomMetadata(roomID string) *RoomMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
// CopyHeroes returns a version of the current RoomMetadata whose Heroes field is
|
||||
// a brand-new copy of the original Heroes. The return value's Heroes field can be
|
||||
// safely modified by the caller, but it is NOT safe for the caller to modify any other
|
||||
// fields.
|
||||
func (m *RoomMetadata) CopyHeroes() *RoomMetadata {
|
||||
newMetadata := *m
|
||||
|
||||
// XXX: We're doing this because we end up calling RemoveHero() to omit the
|
||||
// currently-syncing user in various places. But this seems smelly. The set of
|
||||
// heroes in the room is a global, room-scoped fact: it is a property of the room
|
||||
// state and nothing else, and all users see the same set of heroes.
|
||||
//
|
||||
// I think the data model would be cleaner if we made the hero-reading functions
|
||||
// aware of the currently syncing user, in order to ignore them without having to
|
||||
// change the underlying data.
|
||||
//
|
||||
// copy the heroes or else we may modify the same slice which would be bad :(
|
||||
newMetadata.Heroes = make([]Hero, len(m.Heroes))
|
||||
copy(newMetadata.Heroes, m.Heroes)
|
||||
|
||||
// ⚠️ NB: there are other pointer fields (e.g. PredecessorRoomID *string)
|
||||
// and pointer-backed fields (e.g. LatestEventsByType map[string]EventMetadata)
|
||||
// which are not deepcopied here.
|
||||
return &newMetadata
|
||||
}
|
||||
|
||||
// SameRoomName checks if the fields relevant for room names have changed between the two metadatas.
|
||||
// Returns true if there are no changes.
|
||||
func (m *RoomMetadata) SameRoomName(other *RoomMetadata) bool {
|
||||
@ -62,7 +94,13 @@ func (m *RoomMetadata) SameRoomName(other *RoomMetadata) bool {
|
||||
m.CanonicalAlias == other.CanonicalAlias &&
|
||||
m.JoinCount == other.JoinCount &&
|
||||
m.InviteCount == other.InviteCount &&
|
||||
sameHeroes(m.Heroes, other.Heroes))
|
||||
sameHeroNames(m.Heroes, other.Heroes))
|
||||
}
|
||||
|
||||
// SameRoomAvatar checks if the fields relevant for room avatars have changed between the two metadatas.
|
||||
// Returns true if there are no changes.
|
||||
func (m *RoomMetadata) SameRoomAvatar(other *RoomMetadata) bool {
|
||||
return m.AvatarEvent == other.AvatarEvent && sameHeroAvatars(m.Heroes, other.Heroes)
|
||||
}
|
||||
|
||||
func (m *RoomMetadata) SameJoinCount(other *RoomMetadata) bool {
|
||||
@ -73,7 +111,7 @@ func (m *RoomMetadata) SameInviteCount(other *RoomMetadata) bool {
|
||||
return m.InviteCount == other.InviteCount
|
||||
}
|
||||
|
||||
func sameHeroes(a, b []Hero) bool {
|
||||
func sameHeroNames(a, b []Hero) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
@ -88,6 +126,21 @@ func sameHeroes(a, b []Hero) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func sameHeroAvatars(a, b []Hero) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].ID != b[i].ID {
|
||||
return false
|
||||
}
|
||||
if a[i].Avatar != b[i].Avatar {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *RoomMetadata) RemoveHero(userID string) {
|
||||
for i, h := range m.Heroes {
|
||||
if h.ID == userID {
|
||||
@ -102,8 +155,9 @@ func (m *RoomMetadata) IsSpace() bool {
|
||||
}
|
||||
|
||||
type Hero struct {
|
||||
ID string
|
||||
Name string
|
||||
ID string
|
||||
Name string
|
||||
Avatar string
|
||||
}
|
||||
|
||||
func CalculateRoomName(heroInfo *RoomMetadata, maxNumNamesPerRoom int) string {
|
||||
@ -190,3 +244,18 @@ func disambiguate(heroes []Hero) []string {
|
||||
}
|
||||
return disambiguatedNames
|
||||
}
|
||||
|
||||
const noAvatar = ""
|
||||
|
||||
// CalculateAvatar computes the avatar for the room, based on the global room metadata.
|
||||
// Assumption: metadata.RemoveHero has been called to remove the user who is syncing
|
||||
// from the list of heroes.
|
||||
func CalculateAvatar(metadata *RoomMetadata) string {
|
||||
if metadata.AvatarEvent != "" {
|
||||
return metadata.AvatarEvent
|
||||
}
|
||||
if len(metadata.Heroes) == 1 {
|
||||
return metadata.Heroes[0].Avatar
|
||||
}
|
||||
return noAvatar
|
||||
}
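Illustrative only (hypothetical room and avatar values): assuming RemoveHero has already dropped the syncing user, a DM with no explicit m.room.avatar falls back to the one remaining hero's avatar, and an explicit room avatar always wins:

// Hypothetical sketch of the fallback behaviour.
md := NewRoomMetadata("!dm:example.org")
md.Heroes = []Hero{{ID: "@friend:example.org", Avatar: "mxc://example.org/abc"}}
avatar := CalculateAvatar(md) // "mxc://example.org/abc": no m.room.avatar, exactly one hero
md.AvatarEvent = "mxc://example.org/room"
avatar = CalculateAvatar(md) // "mxc://example.org/room": the explicit avatar wins
_ = avatar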
|
||||
|
@ -247,3 +247,45 @@ func TestCalculateRoomName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyHeroes(t *testing.T) {
|
||||
const alice = "@alice:test"
|
||||
const bob = "@bob:test"
|
||||
const chris = "@chris:test"
|
||||
m1 := RoomMetadata{Heroes: []Hero{
|
||||
{ID: alice},
|
||||
{ID: bob},
|
||||
{ID: chris},
|
||||
}}
|
||||
|
||||
m2 := m1.CopyHeroes()
|
||||
// Uncomment this to see why CopyHeroes is necessary!
|
||||
//m2 := m1
|
||||
|
||||
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
|
||||
|
||||
t.Log("Remove chris from m1")
|
||||
m1.RemoveHero(chris)
|
||||
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
|
||||
|
||||
assertSliceIDs(t, "m1.Heroes", m1.Heroes, []string{alice, bob})
|
||||
assertSliceIDs(t, "m2.Heroes", m2.Heroes, []string{alice, bob, chris})
|
||||
|
||||
t.Log("Remove alice from m1")
|
||||
m1.RemoveHero(alice)
|
||||
t.Logf("Compare heroes:\n\tm1=%v\n\tm2=%v", m1.Heroes, m2.Heroes)
|
||||
|
||||
assertSliceIDs(t, "m1.Heroes", m1.Heroes, []string{bob})
|
||||
assertSliceIDs(t, "m2.Heroes", m2.Heroes, []string{alice, bob, chris})
|
||||
}
|
||||
|
||||
func assertSliceIDs(t *testing.T, desc string, h []Hero, ids []string) {
|
||||
if len(h) != len(ids) {
|
||||
t.Errorf("%s has length %d, expected %d", desc, len(h), len(ids))
|
||||
}
|
||||
for index, id := range ids {
|
||||
if h[index].ID != id {
|
||||
t.Errorf("%s[%d] ID is %s, expected %s", desc, index, h[index].ID, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
18
pubsub/v2.go
@ -41,12 +41,15 @@ type V2Accumulate struct {
|
||||
|
||||
func (*V2Accumulate) Type() string { return "V2Accumulate" }
|
||||
|
||||
// V2TransactionID is emitted by a poller when it sees an event with a transaction ID.
|
||||
// V2TransactionID is emitted by a poller when it sees an event with a transaction ID,
|
||||
// or when it is certain that no other poller will see a transaction ID for this event
|
||||
// (the "all-clear").
|
||||
type V2TransactionID struct {
|
||||
EventID string
|
||||
UserID string
|
||||
RoomID string
|
||||
UserID string // of the sender
|
||||
DeviceID string
|
||||
TransactionID string
|
||||
TransactionID string // Note: an empty transaction ID represents the all-clear.
|
||||
NID int64
|
||||
}
|
||||
|
||||
@ -70,8 +73,9 @@ type V2AccountData struct {
|
||||
func (*V2AccountData) Type() string { return "V2AccountData" }
|
||||
|
||||
type V2LeaveRoom struct {
|
||||
UserID string
|
||||
RoomID string
|
||||
UserID string
|
||||
RoomID string
|
||||
LeaveEvent json.RawMessage
|
||||
}
|
||||
|
||||
func (*V2LeaveRoom) Type() string { return "V2LeaveRoom" }
|
||||
@ -91,9 +95,7 @@ type V2InitialSyncComplete struct {
|
||||
func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" }
|
||||
|
||||
type V2DeviceData struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
Pos int64
|
||||
UserIDToDeviceIDs map[string][]string
|
||||
}
|
||||
|
||||
func (*V2DeviceData) Type() string { return "V2DeviceData" }
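Illustrative only (hypothetical user and device IDs): with the new shape, a single V2DeviceData payload can batch the pending device-data work for several users at once:

// Hypothetical sketch of constructing the batched payload (inside package pubsub).
payload := &V2DeviceData{
	UserIDToDeviceIDs: map[string][]string{
		"@alice:example.org": {"DEVICE1", "DEVICE2"},
		"@bob:example.org":   {"PHONE"},
	},
}
_ = payload.Type() // "V2DeviceData"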
|
||||
|
@ -207,8 +207,9 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
|
||||
IsState: true,
|
||||
}
|
||||
}
|
||||
if err := ensureFieldsSet(events); err != nil {
|
||||
return fmt.Errorf("events malformed: %s", err)
|
||||
events = filterAndEnsureFieldsSet(events)
|
||||
if len(events) == 0 {
|
||||
return fmt.Errorf("failed to insert events, all events were filtered out: %w", err)
|
||||
}
|
||||
eventIDToNID, err := a.eventsTable.Insert(txn, events, false)
|
||||
if err != nil {
|
||||
@ -292,35 +293,51 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
|
||||
// to exist in the database, and the sync stream is already linearised for us.
|
||||
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
|
||||
// - It adds entries to the membership log for membership events.
|
||||
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
// Insert the events. Check for duplicates which can happen in the real world when joining
|
||||
// Matrix HQ on Synapse.
|
||||
dedupedEvents := make([]Event, 0, len(timeline))
|
||||
seenEvents := make(map[string]struct{})
|
||||
for i := range timeline {
|
||||
e := Event{
|
||||
JSON: timeline[i],
|
||||
RoomID: roomID,
|
||||
}
|
||||
if err := e.ensureFieldsSetOnEvent(); err != nil {
|
||||
return 0, nil, fmt.Errorf("event malformed: %s", err)
|
||||
}
|
||||
if _, ok := seenEvents[e.ID]; ok {
|
||||
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
|
||||
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
|
||||
)
|
||||
continue
|
||||
}
|
||||
if i == 0 && prevBatch != "" {
|
||||
// tag the first timeline event with the prev batch token
|
||||
e.PrevBatch = sql.NullString{
|
||||
String: prevBatch,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
dedupedEvents = append(dedupedEvents, e)
|
||||
seenEvents[e.ID] = struct{}{}
|
||||
func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
// The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly
|
||||
// we expect:
|
||||
// - there to be no duplicate events
|
||||
// - if there are events we have not seen before, they are genuinely new (i.e. not old events being redelivered to us).
|
||||
// Both of these assumptions can be false for different reasons
|
||||
dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("filterTimelineEvents: %w", err)
|
||||
return
|
||||
}
|
||||
if len(dedupedEvents) == 0 {
|
||||
return 0, nil, err // nothing to do
|
||||
}
|
||||
|
||||
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
|
||||
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
|
||||
// E1,E2,S3 => SNAP0
|
||||
// E4, S5 => (SNAP0 + S3)
|
||||
// S6 => (SNAP0 + S3 + S5)
|
||||
// E7 => (SNAP0 + S3 + S5 + S6)
|
||||
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
|
||||
// the timeline until we hit a state event, at which point we make a new snapshot but critically
|
||||
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
|
||||
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// if we have just got a leave event for the polling user, and there is no snapshot for this room already, then
|
||||
// we do NOT want to add this event to the events table, nor do we want to make a room snapshot. This is because
|
||||
// this leave event is an invite rejection, rather than a normal event. Invite rejections cannot be processed in
|
||||
// a normal way because we lack room state (no create event, PLs, etc). If we were to process the invite rejection,
|
||||
// the room state would just be a single event: this leave event, which is wrong.
|
||||
if len(dedupedEvents) == 1 &&
|
||||
dedupedEvents[0].Type == "m.room.member" &&
|
||||
(dedupedEvents[0].Membership == "leave" || dedupedEvents[0].Membership == "_leave") &&
|
||||
dedupedEvents[0].StateKey == userID &&
|
||||
snapID == 0 {
|
||||
logger.Info().Str("event_id", dedupedEvents[0].ID).Str("room_id", roomID).Str("user_id", userID).Err(err).Msg(
|
||||
"Accumulator: skipping processing of leave event, as no snapshot exists",
|
||||
)
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
@ -352,19 +369,6 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
|
||||
}
|
||||
}
|
||||
|
||||
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
|
||||
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
|
||||
// E1,E2,S3 => SNAP0
|
||||
// E4, S5 => (SNAP0 + S3)
|
||||
// S6 => (SNAP0 + S3 + S5)
|
||||
// E7 => (SNAP0 + S3 + S5 + S6)
|
||||
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
|
||||
// the timeline until we hit a state event, at which point we make a new snapshot but critically
|
||||
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
|
||||
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
for _, ev := range newEvents {
|
||||
var replacesNID int64
|
||||
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
|
||||
@ -413,24 +417,90 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string,
|
||||
return numNew, timelineNIDs, nil
|
||||
}
|
||||
|
||||
// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`.
|
||||
// Returns the latest NID of the last event (most recent)
|
||||
func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) {
|
||||
txn, err := a.db.Beginx()
|
||||
// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it:
|
||||
// - removes duplicate events: this is just a bug which has been seen on Synapse on matrix.org
|
||||
// - removes old events: this is an edge case when joining rooms over federation, see https://github.com/matrix-org/sliding-sync/issues/192
|
||||
// - parses it and returns Event structs.
|
||||
// - check which events are unknown. If all events are known, filter them all out.
|
||||
func (a *Accumulator) filterAndParseTimelineEvents(txn *sqlx.Tx, roomID string, timeline []json.RawMessage, prevBatch string) ([]Event, error) {
|
||||
// Check for duplicates which can happen in the real world when joining
|
||||
// Matrix HQ on Synapse, as well as when you join rooms for the first time over federation.
|
||||
dedupedEvents := make([]Event, 0, len(timeline))
|
||||
seenEvents := make(map[string]struct{})
|
||||
for i := range timeline {
|
||||
e := Event{
|
||||
JSON: timeline[i],
|
||||
RoomID: roomID,
|
||||
}
|
||||
if err := e.ensureFieldsSetOnEvent(); err != nil {
|
||||
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Err(err).Msg(
|
||||
"Accumulator.filterAndParseTimelineEvents: failed to parse event, ignoring",
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, ok := seenEvents[e.ID]; ok {
|
||||
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
|
||||
"Accumulator.filterAndParseTimelineEvents: seen the same event ID twice, ignoring",
|
||||
)
|
||||
continue
|
||||
}
|
||||
if i == 0 && prevBatch != "" {
|
||||
// tag the first timeline event with the prev batch token
|
||||
e.PrevBatch = sql.NullString{
|
||||
String: prevBatch,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
dedupedEvents = append(dedupedEvents, e)
|
||||
seenEvents[e.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// if we only have a single timeline event we cannot determine if it is old or not, as we rely on already seen events
|
||||
// being after (higher index) than it.
|
||||
if len(dedupedEvents) <= 1 {
|
||||
return dedupedEvents, nil
|
||||
}
|
||||
|
||||
// Figure out which of these events are unseen and hence brand new live events.
|
||||
// In some cases, we may have unseen OLD events - see https://github.com/matrix-org/sliding-sync/issues/192
|
||||
// in which case we need to drop those events.
|
||||
dedupedEventIDs := make([]string, 0, len(seenEvents))
|
||||
for evID := range seenEvents {
|
||||
dedupedEventIDs = append(dedupedEventIDs, evID)
|
||||
}
|
||||
unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, dedupedEventIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, fmt.Errorf("filterAndParseTimelineEvents: failed to SelectUnknownEventIDs: %w", err)
|
||||
}
|
||||
defer txn.Commit()
|
||||
events, err := a.eventsTable.SelectEventsBetween(txn, roomID, lastEventNID, EventsEnd, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
if len(unknownEventIDs) == 0 {
|
||||
// every event has been seen already, no work to do
|
||||
return nil, nil
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return nil, lastEventNID, nil
|
||||
|
||||
// In the happy case, we expect to see timeline arrays like this: (SEEN=S, UNSEEN=U)
|
||||
// [S,S,U,U] -> want last 2
|
||||
// [U,U,U] -> want all
|
||||
// In the backfill edge case, we might see:
|
||||
// [U,S,S,S] -> want none
|
||||
// [U,S,S,U] -> want last 1
|
||||
// We should never see scenarios like:
|
||||
// [U,S,S,U,S,S] <- we should only see 1 contiguous block of seen events.
|
||||
// If we do, we'll just ignore all unseen events less than the highest seen event.
|
||||
|
||||
// The algorithm starts at the end and just looks for the first S event, returning the subslice after that S event (which may be [])
|
||||
seenIndex := -1
|
||||
for i := len(dedupedEvents) - 1; i >= 0; i-- {
|
||||
_, unseen := unknownEventIDs[dedupedEvents[i].ID]
|
||||
if !unseen {
|
||||
seenIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
eventsJSON = make([]json.RawMessage, len(events))
|
||||
for i := range events {
|
||||
eventsJSON[i] = events[i].JSON
|
||||
}
|
||||
return eventsJSON, int64(events[len(events)-1].NID), nil
|
||||
// seenIndex can be -1 if all are unseen, or len-1 if all are seen, either way if we +1 this slices correctly:
|
||||
// no seen events s[A,B,C] => s[-1+1:] => [A,B,C]
|
||||
// C is seen event s[A,B,C] => s[2+1:] => []
|
||||
// B is seen event s[A,B,C] => s[1+1:] => [C]
|
||||
// A is seen event s[A,B,C] => s[0+1:] => [B,C]
|
||||
return dedupedEvents[seenIndex+1:], nil
|
||||
}
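The suffix-slicing rule above can be shown in isolation. A minimal sketch, assuming a precomputed set of unknown event IDs stands in for eventsTable.SelectUnknownEventIDs (the helper name is hypothetical):

// Hypothetical standalone sketch of "keep only events after the last already-seen event".
func keepNewSuffix(eventIDs []string, unknown map[string]struct{}) []string {
	seenIndex := -1
	for i := len(eventIDs) - 1; i >= 0; i-- {
		if _, isUnknown := unknown[eventIDs[i]]; !isUnknown {
			seenIndex = i // highest-index event we have already stored
			break
		}
	}
	// [U,U,U]   -> seenIndex=-1 -> keep all
	// [S,S,U,U] -> seenIndex=1  -> keep the last 2
	// [U,S,S,S] -> seenIndex=3  -> keep none
	return eventIDs[seenIndex+1:]
}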
|
||||
|
@ -11,10 +11,13 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
userID = "@me:localhost"
|
||||
)
|
||||
|
||||
func TestAccumulatorInitialise(t *testing.T) {
|
||||
roomID := "!TestAccumulatorInitialise:localhost"
|
||||
roomEvents := []json.RawMessage{
|
||||
@ -119,7 +122,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
var numNew int
|
||||
var latestNIDs []int64
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
|
||||
numNew, latestNIDs, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -193,7 +196,7 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
|
||||
// subsequent calls do nothing and are not an error
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", newEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -201,59 +204,6 @@ func TestAccumulatorAccumulate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulatorDelta(t *testing.T) {
|
||||
roomID := "!TestAccumulatorDelta:localhost"
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
accumulator := NewAccumulator(db)
|
||||
_, err := accumulator.Initialise(roomID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Initialise accumulator: %s", err)
|
||||
}
|
||||
roomEvents := []json.RawMessage{
|
||||
[]byte(`{"event_id":"aD", "type":"m.room.create", "state_key":"", "content":{"creator":"@TestAccumulatorDelta:localhost"}}`),
|
||||
[]byte(`{"event_id":"aE", "type":"m.room.member", "state_key":"@TestAccumulatorDelta:localhost", "content":{"membership":"join"}}`),
|
||||
[]byte(`{"event_id":"aF", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
[]byte(`{"event_id":"aG", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`),
|
||||
[]byte(`{"event_id":"aH", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
|
||||
[]byte(`{"event_id":"aI", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
|
||||
// Draw the create event, tests limits
|
||||
events, position, err := accumulator.Delta(roomID, EventsStart, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Delta: %s", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("failed to get events from Delta, got %d want 1", len(events))
|
||||
}
|
||||
if gjson.GetBytes(events[0], "event_id").Str != gjson.GetBytes(roomEvents[0], "event_id").Str {
|
||||
t.Fatalf("failed to draw first event, got %s want %s", string(events[0]), string(roomEvents[0]))
|
||||
}
|
||||
if position == 0 {
|
||||
t.Errorf("Delta returned zero position")
|
||||
}
|
||||
|
||||
// Draw up to the end
|
||||
events, position, err = accumulator.Delta(roomID, position, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Delta: %s", err)
|
||||
}
|
||||
if len(events) != len(roomEvents)-1 {
|
||||
t.Fatalf("failed to get events from Delta, got %d want %d", len(events), len(roomEvents)-1)
|
||||
}
|
||||
if position == 0 {
|
||||
t.Errorf("Delta returned zero position")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulatorMembershipLogs(t *testing.T) {
|
||||
roomID := "!TestAccumulatorMembershipLogs:localhost"
|
||||
db, close := connectToDB(t)
|
||||
@ -282,7 +232,7 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
|
||||
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
|
||||
}
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", roomEvents)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -409,7 +359,7 @@ func TestAccumulatorDupeEvents(t *testing.T) {
|
||||
}
|
||||
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
|
||||
_, _, err = accumulator.Accumulate(txn, userID, roomID, "", joinRoom.Timeline.Events)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@ -417,86 +367,6 @@ func TestAccumulatorDupeEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for corrupt state snapshots.
|
||||
// This seems to have happened in the wild, whereby the snapshot exhibited 2 things:
|
||||
// - A message event having a event_replaces_nid. This should be impossible as messages are not state.
|
||||
// - Duplicate events in the state snapshot.
|
||||
//
|
||||
// We can reproduce a message event having a event_replaces_nid by doing the following:
|
||||
// - Create a room with initial state A,C
|
||||
// - Accumulate events D, A, B(msg). This should be impossible because we already got A initially but whatever, roll with it, blame state resets or something.
|
||||
// - This leads to A,B being processed and D ignored if you just take the newest results.
|
||||
//
|
||||
// This can then be tested by:
|
||||
// - Query the current room snapshot. This will include B(msg) when it shouldn't.
|
||||
func TestAccumulatorMisorderedGraceful(t *testing.T) {
|
||||
alice := "@alice:localhost"
|
||||
bob := "@bob:localhost"
|
||||
|
||||
eventA := testutils.NewStateEvent(t, "m.room.member", alice, alice, map[string]interface{}{"membership": "join"})
|
||||
eventC := testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{})
|
||||
eventD := testutils.NewStateEvent(
|
||||
t, "m.room.member", bob, "join", map[string]interface{}{"membership": "join"},
|
||||
)
|
||||
eventBMsg := testutils.NewEvent(
|
||||
t, "m.room.message", bob, map[string]interface{}{"body": "hello"},
|
||||
)
|
||||
t.Logf("A=member-alice, B=msg, C=create, D=member-bob")
|
||||
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
accumulator := NewAccumulator(db)
|
||||
roomID := "!TestAccumulatorStateReset:localhost"
|
||||
// Create a room with initial state A,C
|
||||
_, err := accumulator.Initialise(roomID, []json.RawMessage{
|
||||
eventA, eventC,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Initialise accumulator: %s", err)
|
||||
}
|
||||
|
||||
// Accumulate events D, A, B(msg).
|
||||
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
_, _, err = accumulator.Accumulate(txn, roomID, "", []json.RawMessage{eventD, eventA, eventBMsg})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to Accumulate: %s", err)
|
||||
}
|
||||
|
||||
eventIDs := []string{
|
||||
gjson.GetBytes(eventA, "event_id").Str,
|
||||
gjson.GetBytes(eventBMsg, "event_id").Str,
|
||||
gjson.GetBytes(eventC, "event_id").Str,
|
||||
gjson.GetBytes(eventD, "event_id").Str,
|
||||
}
|
||||
t.Logf("Events A,B,C,D: %v", eventIDs)
|
||||
txn := accumulator.db.MustBeginTx(context.Background(), nil)
|
||||
idsToNIDs, err := accumulator.eventsTable.SelectNIDsByIDs(txn, eventIDs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to SelectNIDsByIDs: %s", err)
|
||||
}
|
||||
if len(idsToNIDs) != len(eventIDs) {
|
||||
t.Errorf("SelectNIDsByIDs: asked for %v got %v", eventIDs, idsToNIDs)
|
||||
}
|
||||
t.Logf("Events: %v", idsToNIDs)
|
||||
|
||||
wantEventNIDs := []int64{
|
||||
idsToNIDs[eventIDs[0]], idsToNIDs[eventIDs[2]], idsToNIDs[eventIDs[3]],
|
||||
}
|
||||
sort.Slice(wantEventNIDs, func(i, j int) bool {
|
||||
return wantEventNIDs[i] < wantEventNIDs[j]
|
||||
})
|
||||
// Query the current room snapshot
|
||||
gotSnapshotEvents := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID)
|
||||
if len(gotSnapshotEvents) != len(wantEventNIDs) { // events A,C,D
|
||||
t.Errorf("corrupt snapshot, got %v want %v", gotSnapshotEvents, wantEventNIDs)
|
||||
}
|
||||
if !reflect.DeepEqual(wantEventNIDs, gotSnapshotEvents) {
|
||||
t.Errorf("got %v want %v", gotSnapshotEvents, wantEventNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for corrupt state snapshots.
|
||||
// This seems to have happened in the wild, whereby the snapshot exhibited 2 things:
|
||||
// - A message event having a event_replaces_nid. This should be impossible as messages are not state.
|
||||
@ -689,7 +559,7 @@ func TestAccumulatorConcurrency(t *testing.T) {
|
||||
defer wg.Done()
|
||||
subset := newEvents[:(i + 1)] // i=0 => [1], i=1 => [1,2], etc
|
||||
err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, _, err := accumulator.Accumulate(txn, roomID, "", subset)
|
||||
numNew, _, err := accumulator.Accumulate(txn, userID, roomID, "", subset)
|
||||
totalNumNew += numNew
|
||||
return err
|
||||
})
|
||||
|
@ -1,10 +1,11 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
@ -25,14 +26,15 @@ type DeviceDataTable struct {
|
||||
|
||||
func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
||||
db.MustExec(`
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_device_data_seq;
|
||||
CREATE TABLE IF NOT EXISTS syncv3_device_data (
|
||||
id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_device_data_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
data BYTEA NOT NULL,
|
||||
UNIQUE(user_id, device_id)
|
||||
);
|
||||
-- Set the fillfactor to 90%, to allow for HOT updates (e.g. we only
|
||||
-- change the data, not anything indexed like the id)
|
||||
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
|
||||
`)
|
||||
return &DeviceDataTable{
|
||||
db: db,
|
||||
@ -44,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
|
||||
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
|
||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||
var row DeviceDataRow
|
||||
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
|
||||
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// if there is no device data for this user, it's not an error.
|
||||
@ -53,7 +55,7 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
|
||||
return err
|
||||
}
|
||||
// unmarshal to swap
|
||||
if err = json.Unmarshal(row.Data, &result); err != nil {
|
||||
if err = cbor.Unmarshal(row.Data, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
result.UserID = userID
|
||||
@ -67,18 +69,19 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
|
||||
writeBack.DeviceLists.New = make(map[string]int)
|
||||
writeBack.ChangedBits = 0
|
||||
|
||||
// re-marshal and write
|
||||
data, err := json.Marshal(writeBack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bytes.Equal(data, row.Data) {
|
||||
if reflect.DeepEqual(result, &writeBack) {
|
||||
// The update to the DB would be a no-op; don't bother with it.
|
||||
// This helps reduce write usage and the contention on the unique index for
|
||||
// the device_data table.
|
||||
return nil
|
||||
}
|
||||
_, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
|
||||
// re-marshal and write
|
||||
data, err := cbor.Marshal(writeBack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -90,18 +93,18 @@ func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
|
||||
}
|
||||
|
||||
// Upsert combines what is in the database for this user|device with the partial entry `dd`
|
||||
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) {
|
||||
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
|
||||
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
|
||||
// select what already exists
|
||||
var row DeviceDataRow
|
||||
err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
|
||||
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// unmarshal and combine
|
||||
var tempDD internal.DeviceData
|
||||
if len(row.Data) > 0 {
|
||||
if err = json.Unmarshal(row.Data, &tempDD); err != nil {
|
||||
if err = cbor.Unmarshal(row.Data, &tempDD); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -115,16 +118,19 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error)
|
||||
}
|
||||
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)
|
||||
|
||||
data, err := json.Marshal(tempDD)
|
||||
data, err := cbor.Marshal(tempDD)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = t.db.QueryRow(
|
||||
_, err = txn.Exec(
|
||||
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
|
||||
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3, id=nextval('syncv3_device_data_seq') RETURNING id`,
|
||||
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
|
||||
dd.UserID, dd.DeviceID, data,
|
||||
).Scan(&pos)
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
func assertVal(t *testing.T, msg string, got, want interface{}) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("%s: got %v want %v", msg, got, want)
|
||||
t.Errorf("%s: got\n%#v want\n%#v", msg, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
|
||||
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
|
||||
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
|
||||
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
|
||||
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists)
|
||||
}
|
||||
|
||||
func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
@ -59,12 +60,12 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: internal.ToDeviceListChangesMap([]string{"bob"}, nil),
|
||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, dd := range deltas {
|
||||
_, err := table.Upsert(&dd)
|
||||
err := table.Upsert(&dd)
|
||||
assertNoError(t, err)
|
||||
}
|
||||
|
||||
@ -76,7 +77,8 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
},
|
||||
FallbackKeyTypes: []string{"foobar"},
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: internal.ToDeviceListChangesMap([]string{"alice", "bob"}, nil),
|
||||
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
||||
Sent: map[string]int{},
|
||||
},
|
||||
}
|
||||
want.SetFallbackKeysChanged()
|
||||
@ -87,38 +89,43 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
assertNoError(t, err)
|
||||
assertDeviceData(t, *got, want)
|
||||
}
|
||||
// now swap-er-roo
|
||||
// now swap-er-roo, at this point we still expect the "old" data,
|
||||
// as it is the first time we swap
|
||||
got, err := table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
want2 := want
|
||||
want2.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
|
||||
New: nil,
|
||||
}
|
||||
assertDeviceData(t, *got, want2)
|
||||
assertDeviceData(t, *got, want)
|
||||
|
||||
// changed bits were reset when we swapped
|
||||
want2 := want
|
||||
want2.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
||||
New: map[string]int{},
|
||||
}
|
||||
want2.ChangedBits = 0
|
||||
want.ChangedBits = 0
|
||||
|
||||
// this is permanent, read-only views show this too
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
// this is permanent, read-only views show this too.
|
||||
// Since we have swapped previously, we now expect New to be empty
|
||||
// and Sent to be set. Swap again to clear Sent.
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
assertDeviceData(t, *got, want2)
|
||||
|
||||
// another swap causes sent to be cleared out
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
// We now expect empty DeviceLists, as we swapped twice.
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
want3 := want2
|
||||
want3.DeviceLists = internal.DeviceLists{
|
||||
Sent: nil,
|
||||
New: nil,
|
||||
Sent: map[string]int{},
|
||||
New: map[string]int{},
|
||||
}
|
||||
assertDeviceData(t, *got, want3)
|
||||
|
||||
// get back the original state
|
||||
//err = table.DeleteDevice(userID, deviceID)
|
||||
assertNoError(t, err)
|
||||
for _, dd := range deltas {
|
||||
_, err = table.Upsert(&dd)
|
||||
err = table.Upsert(&dd)
|
||||
assertNoError(t, err)
|
||||
}
|
||||
want.SetFallbackKeysChanged()
|
||||
@ -128,13 +135,14 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
assertDeviceData(t, *got, want)
|
||||
|
||||
// swap once then add once so both sent and new are populated
|
||||
// Moves Alice and Bob to Sent
|
||||
_, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
_, err = table.Upsert(&internal.DeviceData{
|
||||
err = table.Upsert(&internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: internal.ToDeviceListChangesMap([]string{"bob"}, []string{"charlie"}),
|
||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
|
||||
},
|
||||
})
|
||||
assertNoError(t, err)
|
||||
@ -143,15 +151,17 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
|
||||
want4 := want
|
||||
want4.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
|
||||
New: internal.ToDeviceListChangesMap([]string{"bob"}, []string{"charlie"}),
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
|
||||
}
|
||||
// Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
assertNoError(t, err)
|
||||
assertDeviceData(t, *got, want4)
|
||||
|
||||
// another append then consume
|
||||
_, err = table.Upsert(&internal.DeviceData{
|
||||
// This results in dave to be added to New
|
||||
err = table.Upsert(&internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
DeviceLists: internal.DeviceLists{
|
||||
@ -163,8 +173,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
|
||||
assertNoError(t, err)
|
||||
want5 := want4
|
||||
want5.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"bob", "dave"}, []string{"charlie", "dave"}),
|
||||
New: nil,
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
|
||||
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
|
||||
}
|
||||
assertDeviceData(t, *got, want5)
|
||||
|
||||
// Swapping again clears New
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
want5 = want4
|
||||
want5.DeviceLists = internal.DeviceLists{
|
||||
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
|
||||
New: map[string]int{},
|
||||
}
|
||||
assertDeviceData(t, *got, want5)
|
||||
|
||||
@ -190,11 +210,13 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
||||
"foo": 100,
|
||||
"bar": 92,
|
||||
},
|
||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
||||
}
|
||||
fallbakKeyUpdate := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
FallbackKeyTypes: []string{"foo", "bar"},
|
||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
||||
}
|
||||
bothUpdate := internal.DeviceData{
|
||||
UserID: userID,
|
||||
@ -203,9 +225,10 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
||||
OTKCounts: map[string]int{
|
||||
"both": 100,
|
||||
},
|
||||
DeviceLists: internal.DeviceLists{New: map[string]int{}, Sent: map[string]int{}},
|
||||
}
|
||||
|
||||
_, err := table.Upsert(&otkUpdate)
|
||||
err := table.Upsert(&otkUpdate)
|
||||
assertNoError(t, err)
|
||||
got, err := table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
@ -217,7 +240,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
||||
otkUpdate.ChangedBits = 0
|
||||
assertDeviceData(t, *got, otkUpdate)
|
||||
// now same for fallback keys, but we won't swap them so it should return those diffs
|
||||
_, err = table.Upsert(&fallbakKeyUpdate)
|
||||
err = table.Upsert(&fallbakKeyUpdate)
|
||||
assertNoError(t, err)
|
||||
fallbakKeyUpdate.OTKCounts = otkUpdate.OTKCounts
|
||||
got, err = table.Select(userID, deviceID, false)
|
||||
@ -229,7 +252,7 @@ func TestDeviceDataTableBitset(t *testing.T) {
|
||||
fallbakKeyUpdate.SetFallbackKeysChanged()
|
||||
assertDeviceData(t, *got, fallbakKeyUpdate)
|
||||
// updating both works
|
||||
_, err = table.Upsert(&bothUpdate)
|
||||
err = table.Upsert(&bothUpdate)
|
||||
assertNoError(t, err)
|
||||
got, err = table.Select(userID, deviceID, true)
|
||||
assertNoError(t, err)
|
||||
|
@ -57,7 +57,7 @@ func (ev *Event) ensureFieldsSetOnEvent() error {
|
||||
}
|
||||
if ev.Type == "" {
|
||||
typeResult := evJSON.Get("type")
|
||||
if !typeResult.Exists() || typeResult.Str == "" {
|
||||
if !typeResult.Exists() || typeResult.Type != gjson.String { // empty strings for 'type' are valid apparently
|
||||
return fmt.Errorf("event JSON missing type key")
|
||||
}
|
||||
ev.Type = typeResult.Str
|
||||
@ -153,7 +153,7 @@ func (t *EventTable) SelectHighestNID() (highest int64, err error) {
|
||||
// we insert new events A and B in that order, then NID(A) < NID(B).
|
||||
func (t *EventTable) Insert(txn *sqlx.Tx, events []Event, checkFields bool) (map[string]int64, error) {
|
||||
if checkFields {
|
||||
ensureFieldsSet(events)
|
||||
events = filterAndEnsureFieldsSet(events)
|
||||
}
|
||||
result := make(map[string]int64)
|
||||
for i := range events {
|
||||
@ -317,10 +317,10 @@ func (t *EventTable) LatestEventInRooms(txn *sqlx.Tx, roomIDs []string, highestN
|
||||
return
|
||||
}
|
||||
|
||||
func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
|
||||
func (t *EventTable) LatestEventNIDInRooms(txn *sqlx.Tx, roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
|
||||
// the position (event nid) may be for a random different room, so we need to find the highest nid <= this position for this room
|
||||
var events []Event
|
||||
err = t.db.Select(
|
||||
err = txn.Select(
|
||||
&events,
|
||||
`SELECT event_nid, room_id FROM syncv3_events
|
||||
WHERE event_nid IN (SELECT max(event_nid) FROM syncv3_events WHERE event_nid <= $1 AND room_id = ANY($2) GROUP BY room_id)`,
|
||||
@ -336,14 +336,6 @@ func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (
|
||||
return
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
|
||||
var events []Event
|
||||
err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`,
|
||||
lowerExclusive, upperInclusive, roomID, limit,
|
||||
)
|
||||
return events, err
|
||||
}
|
||||
|
||||
func (t *EventTable) SelectLatestEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) {
|
||||
var events []Event
|
||||
// do not pull in events which were in the v2 state block
|
||||
@ -438,8 +430,8 @@ func (t *EventTable) SelectClosestPrevBatchByID(roomID string, eventID string) (
|
||||
|
||||
// Select the closest prev batch token for the provided event NID. Returns the empty string if there
|
||||
// is no closest.
|
||||
func (t *EventTable) SelectClosestPrevBatch(roomID string, eventNID int64) (prevBatch string, err error) {
|
||||
err = t.db.QueryRow(
|
||||
func (t *EventTable) SelectClosestPrevBatch(txn *sqlx.Tx, roomID string, eventNID int64) (prevBatch string, err error) {
|
||||
err = txn.QueryRow(
|
||||
`SELECT prev_batch FROM syncv3_events WHERE prev_batch IS NOT NULL AND room_id=$1 AND event_nid >= $2 LIMIT 1`, roomID, eventNID,
|
||||
).Scan(&prevBatch)
|
||||
if err == sql.ErrNoRows {
|
||||
@ -457,14 +449,18 @@ func (c EventChunker) Subslice(i, j int) sqlutil.Chunker {
|
||||
return c[i:j]
|
||||
}
|
||||
|
||||
func ensureFieldsSet(events []Event) error {
|
||||
func filterAndEnsureFieldsSet(events []Event) []Event {
|
||||
result := make([]Event, 0, len(events))
|
||||
// ensure fields are set
|
||||
for i := range events {
|
||||
ev := events[i]
|
||||
ev := &events[i]
|
||||
if err := ev.ensureFieldsSetOnEvent(); err != nil {
|
||||
return err
|
||||
logger.Warn().Str("event_id", ev.ID).Err(err).Msg(
|
||||
"filterAndEnsureFieldsSet: failed to parse event, ignoring",
|
||||
)
|
||||
continue
|
||||
}
|
||||
events[i] = ev
|
||||
result = append(result, *ev)
|
||||
}
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
@ -297,125 +297,6 @@ func TestEventTableDupeInsert(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventTableSelectEventsBetween(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
txn, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
table := NewEventTable(db)
|
||||
searchRoomID := "!0TestEventTableSelectEventsBetween:localhost"
|
||||
eventIDs := []string{
|
||||
"100TestEventTableSelectEventsBetween",
|
||||
"101TestEventTableSelectEventsBetween",
|
||||
"102TestEventTableSelectEventsBetween",
|
||||
"103TestEventTableSelectEventsBetween",
|
||||
"104TestEventTableSelectEventsBetween",
|
||||
}
|
||||
events := []Event{
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[0] + `","type": "T1", "state_key":"S1", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[1] + `","type": "T2", "state_key":"S2", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[2] + `","type": "T3", "state_key":"", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
{
|
||||
// different room
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[3] + `","type": "T4", "state_key":"", "room_id":"!1TestEventTableSelectEventsBetween:localhost"}`),
|
||||
},
|
||||
{
|
||||
JSON: []byte(`{"event_id":"` + eventIDs[4] + `","type": "T5", "state_key":"", "room_id":"` + searchRoomID + `"}`),
|
||||
},
|
||||
}
|
||||
idToNID, err := table.Insert(txn, events, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert failed: %s", err)
|
||||
}
|
||||
if len(idToNID) != len(events) {
|
||||
t.Fatalf("failed to insert events: got %d want %d", len(idToNID), len(events))
|
||||
}
|
||||
txn.Commit()
|
||||
|
||||
t.Run("subgroup", func(t *testing.T) {
|
||||
t.Run("selecting multiple events known lower bound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn2, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn2.Rollback()
|
||||
events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]})
|
||||
if err != nil || len(events) == 0 {
|
||||
t.Fatalf("failed to extract event for lower bound: %s", err)
|
||||
}
|
||||
events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// 3 as 1 is from a different room
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("wanted 3 events, got %d", len(events))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn3, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn3.Rollback()
|
||||
events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]})
|
||||
if err != nil || len(events) == 0 {
|
||||
t.Fatalf("failed to extract event for lower/upper bound: %s", err)
|
||||
}
|
||||
events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// eventIDs[1] and eventIDs[2]
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("wanted 2 events, got %d", len(events))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn4, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn4.Rollback()
|
||||
gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
// one less as one event is for a different room
|
||||
if len(gotEvents) != (len(events) - 1) {
|
||||
t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents))
|
||||
}
|
||||
})
|
||||
t.Run("selecting multiple events hitting the limit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
txn5, err := db.Beginx()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start txn: %s", err)
|
||||
}
|
||||
defer txn5.Rollback()
|
||||
limit := 2
|
||||
gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectEventsBetween: %s", err)
|
||||
}
|
||||
if len(gotEvents) != limit {
|
||||
t.Fatalf("wanted %d events, got %d", limit, len(gotEvents))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventTableMembershipDetection(t *testing.T) {
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
@ -778,10 +659,14 @@ func TestEventTablePrevBatch(t *testing.T) {
|
||||
}
|
||||
|
||||
assertPrevBatch := func(roomID string, index int, wantPrevBatch string) {
|
||||
gotPrevBatch, err := table.SelectClosestPrevBatch(roomID, int64(idToNID[events[index].ID]))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
|
||||
}
|
||||
var gotPrevBatch string
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
gotPrevBatch, err = table.SelectClosestPrevBatch(txn, roomID, int64(idToNID[events[index].ID]))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if wantPrevBatch != "" {
|
||||
if gotPrevBatch == "" || gotPrevBatch != wantPrevBatch {
|
||||
t.Fatalf("SelectClosestPrevBatch: got %v want %v", gotPrevBatch, wantPrevBatch)
|
||||
@ -947,7 +832,11 @@ func TestLatestEventNIDInRooms(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
gotRoomToNID, err := table.LatestEventNIDInRooms(tc.roomIDs, int64(tc.highestNID))
|
||||
var gotRoomToNID map[string]int64
|
||||
err = sqlutil.WithTransaction(table.db, func(txn *sqlx.Tx) error {
|
||||
gotRoomToNID, err = table.LatestEventNIDInRooms(txn, tc.roomIDs, int64(tc.highestNID))
|
||||
return err
|
||||
})
|
||||
assertNoError(t, err)
|
||||
want := make(map[string]int64) // map event IDs to nids
|
||||
for roomID, eventID := range tc.wantMap {
|
||||
|
11
state/migrations/20230728114555_device_data_drop_id.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS id;
|
||||
DROP SEQUENCE IF EXISTS syncv3_device_data_seq;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
CREATE SEQUENCE IF NOT EXISTS syncv3_device_data_seq;
|
||||
ALTER TABLE syncv3_device_data ADD COLUMN IF NOT EXISTS id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('syncv3_device_data_seq') ;
|
||||
-- +goose StatementEnd
|
97
state/migrations/20230802121023_device_data_jsonb.go
Normal file
@ -0,0 +1,97 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(upJSONB, downJSONB)
|
||||
}
|
||||
|
||||
func upJSONB(ctx context.Context, tx *sql.Tx) error {
|
||||
// check if we even need to do anything
|
||||
var dataType string
|
||||
err := tx.QueryRow("select data_type from information_schema.columns where table_name = 'syncv3_device_data' AND column_name = 'data'").Scan(&dataType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The table/column doesn't exist and is likely going to be created soon with the
|
||||
// correct schema
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if strings.ToLower(dataType) == "jsonb" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS dataj JSONB;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// abusing PollerID here
|
||||
var deviceData sync2.PollerID
|
||||
var data []byte
|
||||
|
||||
// map from PollerID -> deviceData
|
||||
deviceDatas := make(map[sync2.PollerID][]byte)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
deviceDatas[deviceData] = data
|
||||
}
|
||||
|
||||
for dd, d := range deviceDatas {
|
||||
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET dataj = $1 WHERE user_id = $2 AND device_id = $3;", d, dd.UserID, dd.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN dataj TO data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func downJSONB(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS datab BYTEA;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET datab = (data::TEXT)::BYTEA;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN datab TO data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
80
state/migrations/20230802121023_device_data_jsonb_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
)
|
||||
|
||||
var postgresConnectionString = "user=xxxxx dbname=syncv3_test sslmode=disable"
|
||||
|
||||
func connectToDB(t *testing.T) (*sqlx.DB, func()) {
|
||||
postgresConnectionString = testutils.PrepareDBConnectionString()
|
||||
db, err := sqlx.Open("postgres", postgresConnectionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open SQL db: %s", err)
|
||||
}
|
||||
return db, func() {
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONBMigration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
|
||||
// Create the table in the old format (data = BYTEA instead of JSONB)
|
||||
_, err := db.Exec(`CREATE TABLE syncv3_device_data (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
data BYTEA NOT NULL,
|
||||
UNIQUE(user_id, device_id)
|
||||
);`)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
// insert some "invalid" data
|
||||
dd := internal.DeviceData{
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: map[string]int{"@💣:localhost": 1},
|
||||
Sent: map[string]int{},
|
||||
},
|
||||
OTKCounts: map[string]int{},
|
||||
FallbackKeyTypes: []string{},
|
||||
}
|
||||
data, err := json.Marshal(dd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, "bob", "bobDev", data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// validate that invalid data can be migrated upwards
|
||||
err = upJSONB(ctx, tx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// and downgrade again
|
||||
err = downJSONB(ctx, tx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
142
state/migrations/20230814183302_cbor_device_data.go
Normal file
@ -0,0 +1,142 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(upCborDeviceData, downCborDeviceData)
|
||||
}
|
||||
|
||||
func upCborDeviceData(ctx context.Context, tx *sql.Tx) error {
|
||||
// check if we even need to do anything
|
||||
var dataType string
|
||||
err := tx.QueryRow("select data_type from information_schema.columns where table_name = 'syncv3_device_data' AND column_name = 'data'").Scan(&dataType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The table/column doesn't exist and is likely going to be created soon with the
|
||||
// correct schema
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if strings.ToLower(dataType) == "bytea" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS datab BYTEA;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// abusing PollerID here
|
||||
var deviceData sync2.PollerID
|
||||
var data []byte
|
||||
|
||||
// map from PollerID -> deviceData
|
||||
deviceDatas := make(map[sync2.PollerID][]byte)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
deviceDatas[deviceData] = data
|
||||
}
|
||||
|
||||
for dd, jsonBytes := range deviceDatas {
|
||||
var data internal.DeviceData
|
||||
if err := json.Unmarshal(jsonBytes, &data); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal JSON: %v -> %v", string(jsonBytes), err)
|
||||
}
|
||||
cborBytes, err := cbor.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal as CBOR: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET datab = $1 WHERE user_id = $2 AND device_id = $3;", cborBytes, dd.UserID, dd.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN datab TO data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func downCborDeviceData(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data ADD COLUMN IF NOT EXISTS dataj JSONB;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := tx.Query("SELECT user_id, device_id, data FROM syncv3_device_data")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// abusing PollerID here
|
||||
var deviceData sync2.PollerID
|
||||
var data []byte
|
||||
|
||||
// map from PollerID -> deviceData
|
||||
deviceDatas := make(map[sync2.PollerID][]byte)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&deviceData.UserID, &deviceData.DeviceID, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
deviceDatas[deviceData] = data
|
||||
}
|
||||
|
||||
for dd, cborBytes := range deviceDatas {
|
||||
var data internal.DeviceData
|
||||
if err := cbor.Unmarshal(cborBytes, &data); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal CBOR: %v", err)
|
||||
}
|
||||
jsonBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal as JSON: %v", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "UPDATE syncv3_device_data SET dataj = $1 WHERE user_id = $2 AND device_id = $3;", jsonBytes, dd.UserID, dd.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data DROP COLUMN IF EXISTS data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, "ALTER TABLE IF EXISTS syncv3_device_data RENAME COLUMN dataj TO data;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
121
state/migrations/20230814183302_cbor_device_data_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/state"
|
||||
)
|
||||
|
||||
func TestCBORBMigration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db, close := connectToDB(t)
|
||||
defer close()
|
||||
|
||||
// Create the table in the old format (data = JSONB instead of BYTEA)
|
||||
// and insert some data: we'll make sure that this data is preserved
|
||||
// after migrating.
|
||||
_, err := db.Exec(`CREATE TABLE syncv3_device_data (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
data JSONB NOT NULL,
|
||||
UNIQUE(user_id, device_id)
|
||||
);`)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rowData := []internal.DeviceData{
|
||||
{
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: map[string]int{"@bob:localhost": 2},
|
||||
Sent: map[string]int{},
|
||||
},
|
||||
ChangedBits: 2,
|
||||
OTKCounts: map[string]int{"bar": 42},
|
||||
FallbackKeyTypes: []string{"narp"},
|
||||
DeviceID: "ALICE",
|
||||
UserID: "@alice:localhost",
|
||||
},
|
||||
{
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: map[string]int{"@💣:localhost": 1, "@bomb:localhost": 2},
|
||||
Sent: map[string]int{"@sent:localhost": 1},
|
||||
},
|
||||
OTKCounts: map[string]int{"foo": 100},
|
||||
FallbackKeyTypes: []string{"yep"},
|
||||
DeviceID: "BOB",
|
||||
UserID: "@bob:localhost",
|
||||
},
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, dd := range rowData {
|
||||
data, err := json.Marshal(dd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO syncv3_device_data (user_id, device_id, data) VALUES ($1, $2, $3)`, dd.UserID, dd.DeviceID, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// validate that invalid data can be migrated upwards
|
||||
err = upCborDeviceData(ctx, tx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
// ensure we can now select it
|
||||
table := state.NewDeviceDataTable(db)
|
||||
for _, want := range rowData {
|
||||
got, err := table.Select(want.UserID, want.DeviceID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(*got, want) {
|
||||
t.Fatalf("got %+v\nwant %+v", *got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// and downgrade again
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = downCborDeviceData(ctx, tx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure it is what we originally inserted
|
||||
for _, want := range rowData {
|
||||
var got internal.DeviceData
|
||||
var gotBytes []byte
|
||||
err = tx.QueryRow(`SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, want.UserID, want.DeviceID).Scan(&gotBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = json.Unmarshal(gotBytes, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got.DeviceID = want.DeviceID
|
||||
got.UserID = want.UserID
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %+v\nwant %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
}
|
56
state/migrations/MIGRATIONS.md
Normal file
@ -0,0 +1,56 @@
|
||||
# Database migrations in the Sliding Sync Proxy
|
||||
|
||||
Database migrations are handled using https://github.com/pressly/goose, with an integrated `migrate` command in the `syncv3` binary.
|
||||
|
||||
All commands below require the `SYNCV3_DB` environment variable to determine which database to use:
|
||||
|
||||
```bash
|
||||
$ export SYNCV3_DB="user=postgres dbname=syncv3 sslmode=disable password=yourpassword"
|
||||
```
|
||||
|
||||
## Upgrading
|
||||
|
||||
It is sufficient to run the proxy itself; upgrading is done automatically. If you still need to upgrade manually, you can
|
||||
use one of the following commands to upgrade:
|
||||
|
||||
```bash
|
||||
# Check which versions have been applied
|
||||
$ ./syncv3 migrate status
|
||||
|
||||
# Execute all existing migrations
|
||||
$ ./syncv3 migrate up
|
||||
|
||||
# Upgrade to 20230802121023 (which is 20230802121023_device_data_jsonb.go - migrating from bytea to jsonb)
|
||||
$ ./syncv3 migrate up-to 20230802121023
|
||||
|
||||
# Upgrade by one
|
||||
$ ./syncv3 migrate up-by-one
|
||||
```
|
||||
|
||||
## Downgrading
|
||||
|
||||
If you wish to downgrade, execute one of the following commands. If you downgrade, make sure you also start
|
||||
the required older version of the proxy, as otherwise the schema will automatically be upgraded again:
|
||||
|
||||
```bash
|
||||
# Check which versions have been applied
|
||||
$ ./syncv3 migrate status
|
||||
|
||||
# Undo the latest migration
|
||||
$ ./syncv3 migrate down
|
||||
|
||||
# Downgrade to 20230728114555 (which is 20230728114555_device_data_drop_id.sql - dropping the id column)
|
||||
$ ./syncv3 migrate down-to 20230728114555
|
||||
```
|
||||
|
||||
## Creating new migrations
|
||||
|
||||
Migrations can either be created as plain SQL or as Go functions.
|
||||
|
||||
```bash
|
||||
# Create a new SQL migration with the name "mymigration"
|
||||
$ ./syncv3 migrate create mymigration sql
|
||||
|
||||
# Same as above, but as Go functions
|
||||
$ ./syncv3 migrate create mymigration go
|
||||
```
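
As a rough sketch (the migration and function names below are placeholders), a Go migration created this way is expected to follow the same pattern as the Go migrations added in this commit: register an up and a down function with goose, each receiving a context and the transaction to operate in.

```go
package migrations

import (
	"context"
	"database/sql"

	"github.com/pressly/goose/v3"
)

func init() {
	goose.AddMigrationContext(upMyMigration, downMyMigration)
}

func upMyMigration(ctx context.Context, tx *sql.Tx) error {
	// Apply the schema or data change here, e.g. with tx.ExecContext(ctx, "...").
	return nil
}

func downMyMigration(ctx context.Context, tx *sql.Tx) error {
	// Revert the change here.
	return nil
}
```

The actual schema statements belong in the up/down bodies, as in `20230802121023_device_data_jsonb.go`.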
|
245
state/storage.go
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
|
||||
@ -38,6 +37,22 @@ type LatestEvents struct {
|
||||
LatestNID int64
|
||||
}
|
||||
|
||||
// DiscardIgnoredMessages modifies the struct in-place, replacing the Timeline with
|
||||
// a copy that has all ignored events omitted. The order of the remaining events is preserved.
|
||||
func (e *LatestEvents) DiscardIgnoredMessages(shouldIgnore func(sender string) bool) {
|
||||
// A little bit sad to be effectively doing a copy here---most of the time there
|
||||
// won't be any messages to ignore (and the timeline is likely short). But that copy
|
||||
// is unlikely to be a bottleneck.
|
||||
newTimeline := make([]json.RawMessage, 0, len(e.Timeline))
|
||||
for _, ev := range e.Timeline {
|
||||
parsed := gjson.ParseBytes(ev)
|
||||
if parsed.Get("state_key").Exists() || !shouldIgnore(parsed.Get("sender").Str) {
|
||||
newTimeline = append(newTimeline, ev)
|
||||
}
|
||||
}
|
||||
e.Timeline = newTimeline
|
||||
}
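
// Illustrative usage sketch (not part of this change): callers are expected to
// build the shouldIgnore callback from the requesting user's ignored users. The
// ignoredUsers map below is hypothetical; only the callback shape comes from the
// method above.
//
//	ignoredUsers := map[string]bool{"@spammer:example.org": true}
//	latest.DiscardIgnoredMessages(func(sender string) bool {
//		return ignoredUsers[sender]
//	})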
|
||||
|
||||
type Storage struct {
|
||||
Accumulator *Accumulator
|
||||
EventsTable *EventTable
|
||||
@ -58,9 +73,10 @@ func NewStorage(postgresURI string) *Storage {
|
||||
// TODO: if we panic(), will sentry have a chance to flush the event?
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
db.SetMaxOpenConns(100)
|
||||
db.SetMaxIdleConns(80)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
return NewStorageWithDB(db)
|
||||
}
|
||||
|
||||
func NewStorageWithDB(db *sqlx.DB) *Storage {
|
||||
acc := &Accumulator{
|
||||
db: db,
|
||||
roomsTable: NewRoomsTable(db),
|
||||
@ -178,26 +194,12 @@ func (s *Storage) GlobalSnapshot() (ss StartupSnapshot, err error) {
|
||||
|
||||
// Extract hero info for all rooms. Requires a prepared snapshot in order to be called.
|
||||
func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result map[string]internal.RoomMetadata) error {
|
||||
// Select the invited member counts
|
||||
rows, err := txn.Query(`
|
||||
SELECT room_id, count(state_key) FROM syncv3_events INNER JOIN ` + tempTableName + ` ON membership_nid=event_nid
|
||||
WHERE (membership='_invite' OR membership = 'invite') AND event_type='m.room.member' GROUP BY room_id`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
var inviteCount int
|
||||
if err := rows.Scan(&roomID, &inviteCount); err != nil {
|
||||
return err
|
||||
}
|
||||
loadMetadata := func(roomID string) internal.RoomMetadata {
|
||||
metadata, ok := result[roomID]
|
||||
if !ok {
|
||||
metadata = *internal.NewRoomMetadata(roomID)
|
||||
}
|
||||
metadata.InviteCount = inviteCount
|
||||
result[roomID] = metadata
|
||||
return metadata
|
||||
}
|
||||
|
||||
// work out latest timestamps
|
||||
@ -206,10 +208,8 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
|
||||
return err
|
||||
}
|
||||
for _, ev := range events {
|
||||
metadata, ok := result[ev.RoomID]
|
||||
if !ok {
|
||||
metadata = *internal.NewRoomMetadata(ev.RoomID)
|
||||
}
|
||||
metadata := loadMetadata(ev.RoomID)
|
||||
|
||||
// For a given room, we'll see many events (one for each event type in the
|
||||
// room's state). We need to pick the largest of these events' timestamps here.
|
||||
ts := gjson.ParseBytes(ev.JSON).Get("origin_server_ts").Uint()
|
||||
@ -232,68 +232,32 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
|
||||
|
||||
// Select the name / canonical alias for all rooms
|
||||
roomIDToStateEvents, err := s.currentNotMembershipStateEventsInAllRooms(txn, []string{
|
||||
"m.room.name", "m.room.canonical_alias",
|
||||
"m.room.name", "m.room.canonical_alias", "m.room.avatar",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load state events for all rooms: %s", err)
|
||||
}
|
||||
for roomID, stateEvents := range roomIDToStateEvents {
|
||||
metadata := result[roomID]
|
||||
metadata := loadMetadata(roomID)
|
||||
for _, ev := range stateEvents {
|
||||
if ev.Type == "m.room.name" && ev.StateKey == "" {
|
||||
metadata.NameEvent = gjson.ParseBytes(ev.JSON).Get("content.name").Str
|
||||
} else if ev.Type == "m.room.canonical_alias" && ev.StateKey == "" {
|
||||
metadata.CanonicalAlias = gjson.ParseBytes(ev.JSON).Get("content.alias").Str
|
||||
} else if ev.Type == "m.room.avatar" && ev.StateKey == "" {
|
||||
metadata.AvatarEvent = gjson.ParseBytes(ev.JSON).Get("content.url").Str
|
||||
}
|
||||
}
|
||||
result[roomID] = metadata
|
||||
}
|
||||
|
||||
// Select the most recent members for each room to serve as Heroes. The spec is ambiguous here:
|
||||
// "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited."
|
||||
// Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent
|
||||
// ones, and select 6 of them so we can always use 5 no matter who is requesting the room name.
|
||||
rows, err = txn.Query(`
|
||||
SELECT rf.* FROM (
|
||||
SELECT room_id, event, rank() OVER (
|
||||
PARTITION BY room_id ORDER BY event_nid DESC
|
||||
) FROM syncv3_events INNER JOIN ` + tempTableName + ` ON membership_nid=event_nid WHERE (
|
||||
membership='join' OR membership='invite' OR membership='_join'
|
||||
) AND event_type='m.room.member'
|
||||
) rf WHERE rank <= 6;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query heroes: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
seen := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
var event json.RawMessage
|
||||
var rank int
|
||||
if err := rows.Scan(&roomID, &event, &rank); err != nil {
|
||||
return err
|
||||
}
|
||||
ev := gjson.ParseBytes(event)
|
||||
targetUser := ev.Get("state_key").Str
|
||||
key := roomID + " " + targetUser
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
metadata := result[roomID]
|
||||
metadata.Heroes = append(metadata.Heroes, internal.Hero{
|
||||
ID: targetUser,
|
||||
Name: ev.Get("content.displayname").Str,
|
||||
})
|
||||
result[roomID] = metadata
|
||||
}
|
||||
roomInfos, err := s.Accumulator.roomsTable.SelectRoomInfos(txn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select room infos: %s", err)
|
||||
}
|
||||
var spaceRoomIDs []string
|
||||
for _, info := range roomInfos {
|
||||
metadata := result[info.ID]
|
||||
metadata := loadMetadata(info.ID)
|
||||
metadata.Encrypted = info.IsEncrypted
|
||||
metadata.UpgradedRoomID = info.UpgradedRoomID
|
||||
metadata.PredecessorRoomID = info.PredecessorRoomID
|
||||
@ -310,7 +274,13 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result
|
||||
return fmt.Errorf("failed to select space children: %s", err)
|
||||
}
|
||||
for roomID, relations := range spaceRoomToRelations {
|
||||
metadata := result[roomID]
|
||||
if _, exists := result[roomID]; !exists {
|
||||
// this can happen when you join a space (so it populates the spaces table) then leave the space,
|
||||
// so there are no joined members left in the space and `result` doesn't include the room. In this case,
|
||||
// we don't want to have a stub metadata with just the space children, so skip it.
|
||||
continue
|
||||
}
|
||||
metadata := loadMetadata(roomID)
|
||||
metadata.ChildSpaceRooms = make(map[string]struct{}, len(relations))
|
||||
for _, r := range relations {
|
||||
// For now we only honour child state events, but we store all the mappings just in case.
|
||||
@ -353,12 +323,12 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
func (s *Storage) Accumulate(userID, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
|
||||
if len(timeline) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
|
||||
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, roomID, prevBatch, timeline)
|
||||
numNew, timelineNIDs, err = s.Accumulator.Accumulate(txn, userID, roomID, prevBatch, timeline)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -545,7 +515,7 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to form sql query: %s", err)
|
||||
}
|
||||
rows, err := s.Accumulator.db.Query(s.Accumulator.db.Rebind(query), args...)
|
||||
rows, err := txn.Query(txn.Rebind(query), args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query: %s", err)
|
||||
}
|
||||
@ -630,7 +600,7 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
|
||||
}
|
||||
if earliestEventNID != 0 {
|
||||
// the oldest event needs a prev batch token, so find one now
|
||||
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(roomID, earliestEventNID)
|
||||
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(txn, roomID, earliestEventNID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err)
|
||||
}
|
||||
@ -810,32 +780,121 @@ func (s *Storage) RoomMembershipDelta(roomID string, from, to int64, limit int)
|
||||
}
|
||||
|
||||
// Extract all rooms with joined members, and include the joined user list. Requires a prepared snapshot in order to be called.
|
||||
func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (result map[string][]string, metadata map[string]internal.RoomMetadata, err error) {
|
||||
// Populates the join/invite count and heroes for the returned metadata.
|
||||
func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMembers map[string][]string, metadata map[string]internal.RoomMetadata, err error) {
|
||||
// Select the most recent members for each room to serve as Heroes. The spec is ambiguous here:
|
||||
// "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited."
|
||||
// Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent
|
||||
// ones, and select 6 of them so we can always use 5 no matter who is requesting the room name.
|
||||
rows, err := txn.Query(
|
||||
`SELECT room_id, state_key from ` + tempTableName + ` INNER JOIN syncv3_events on membership_nid = event_nid WHERE membership='join' OR membership='_join' ORDER BY event_nid ASC`,
|
||||
`SELECT membership_nid, room_id, state_key, membership from ` + tempTableName + ` INNER JOIN syncv3_events
|
||||
on membership_nid = event_nid WHERE membership='join' OR membership='_join' OR membership='invite' OR membership='_invite' ORDER BY event_nid ASC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
result = make(map[string][]string)
|
||||
joinedMembers = make(map[string][]string)
|
||||
inviteCounts := make(map[string]int)
|
||||
heroNIDs := make(map[string]*circularSlice)
|
||||
var stateKey string
|
||||
var membership string
|
||||
var roomID string
|
||||
var joinedUserID string
|
||||
var nid int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&roomID, &joinedUserID); err != nil {
|
||||
if err := rows.Scan(&nid, &roomID, &stateKey, &membership); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
users := result[roomID]
|
||||
users = append(users, joinedUserID)
|
||||
result[roomID] = users
|
||||
heroes := heroNIDs[roomID]
|
||||
if heroes == nil {
|
||||
heroes = &circularSlice{max: 6}
|
||||
heroNIDs[roomID] = heroes
|
||||
}
|
||||
switch membership {
|
||||
case "join":
|
||||
fallthrough
|
||||
case "_join":
|
||||
users := joinedMembers[roomID]
|
||||
users = append(users, stateKey)
|
||||
joinedMembers[roomID] = users
|
||||
heroes.append(nid)
|
||||
case "invite":
|
||||
fallthrough
|
||||
case "_invite":
|
||||
inviteCounts[roomID] = inviteCounts[roomID] + 1
|
||||
heroes.append(nid)
|
||||
}
|
||||
}
|
||||
|
||||
// now select the membership events for the heroes
|
||||
var allHeroNIDs []int64
|
||||
for _, nids := range heroNIDs {
|
||||
allHeroNIDs = append(allHeroNIDs, nids.vals...)
|
||||
}
|
||||
heroEvents, err := s.EventsTable.SelectByNIDs(txn, true, allHeroNIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
heroes := make(map[string][]internal.Hero)
|
||||
// loop backwards so the most recent hero is first in the hero list
|
||||
for i := len(heroEvents) - 1; i >= 0; i-- {
|
||||
ev := heroEvents[i]
|
||||
evJSON := gjson.ParseBytes(ev.JSON)
|
||||
roomHeroes := heroes[ev.RoomID]
|
||||
roomHeroes = append(roomHeroes, internal.Hero{
|
||||
ID: ev.StateKey,
|
||||
Name: evJSON.Get("content.displayname").Str,
|
||||
Avatar: evJSON.Get("content.avatar_url").Str,
|
||||
})
|
||||
heroes[ev.RoomID] = roomHeroes
|
||||
}
|
||||
|
||||
metadata = make(map[string]internal.RoomMetadata)
|
||||
for roomID, joinedMembers := range result {
|
||||
for roomID, members := range joinedMembers {
|
||||
m := internal.NewRoomMetadata(roomID)
|
||||
m.JoinCount = len(joinedMembers)
|
||||
m.JoinCount = len(members)
|
||||
m.InviteCount = inviteCounts[roomID]
|
||||
m.Heroes = heroes[roomID]
|
||||
metadata[roomID] = *m
|
||||
}
|
||||
return result, metadata, nil
|
||||
return joinedMembers, metadata, nil
|
||||
}
|
||||
|
||||
func (s *Storage) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
|
||||
roomToNID = make(map[string]int64)
|
||||
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
|
||||
// Pull out the latest nids for all the rooms. If they are < highestNID then use them, else we need to query the
|
||||
// events table (slow) for the latest nid in this room which is < highestNID.
|
||||
fastRoomToLatestNIDs, err := s.Accumulator.roomsTable.LatestNIDs(txn, roomIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var slowRooms []string
|
||||
for _, roomID := range roomIDs {
|
||||
nid := fastRoomToLatestNIDs[roomID]
|
||||
if nid > 0 && nid <= highestNID {
|
||||
roomToNID[roomID] = nid
|
||||
} else {
|
||||
// we need to do a slow query for this
|
||||
slowRooms = append(slowRooms, roomID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(slowRooms) == 0 {
|
||||
return nil // no work to do
|
||||
}
|
||||
logger.Warn().Int("slow_rooms", len(slowRooms)).Msg("LatestEventNIDInRooms: pos value provided is far behind the database copy, performance degraded")
|
||||
|
||||
slowRoomToLatestNIDs, err := s.EventsTable.LatestEventNIDInRooms(txn, slowRooms, highestNID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for roomID, nid := range slowRoomToLatestNIDs {
|
||||
roomToNID[roomID] = nid
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return roomToNID, err
|
||||
}
|
||||
|
||||
// Returns a map from joined room IDs to EventMetadata, which is nil iff a non-nil error
|
||||
@ -895,3 +954,27 @@ func (s *Storage) Teardown() {
|
||||
panic("Storage.Teardown: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// circularSlice is a slice which can be appended to and which will wrap around at `max`.
|
||||
// Mostly useful for lazily calculating heroes. The values returned aren't sorted.
|
||||
type circularSlice struct {
|
||||
i int
|
||||
vals []int64
|
||||
max int
|
||||
}
|
||||
|
||||
func (s *circularSlice) append(val int64) {
|
||||
if len(s.vals) < s.max {
|
||||
// populate up to max
|
||||
s.vals = append(s.vals, val)
|
||||
s.i++
|
||||
return
|
||||
}
|
||||
// wraparound
|
||||
if s.i == s.max {
|
||||
s.i = 0
|
||||
}
|
||||
// replace this entry
|
||||
s.vals[s.i] = val
|
||||
s.i++
|
||||
}
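
// Worked example (illustrative only): with max=3, appending 10, 11 and 12 fills
// vals; a fourth append(13) resets the write index to 0 and overwrites the oldest
// entry, leaving vals as [13, 11, 12]. TestCircularSlice further down in this
// commit exercises the same wraparound behaviour.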
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
@ -30,7 +31,7 @@ func TestStorageRoomStateBeforeAndAfterEventPosition(t *testing.T) {
|
||||
testutils.NewStateEvent(t, "m.room.join_rules", "", alice, map[string]interface{}{"join_rule": "invite"}),
|
||||
testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{"membership": "invite"}),
|
||||
}
|
||||
_, latestNIDs, err := store.Accumulate(roomID, "", events)
|
||||
_, latestNIDs, err := store.Accumulate(userID, roomID, "", events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate returned error: %s", err)
|
||||
}
|
||||
@ -160,7 +161,7 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
|
||||
var latestNIDs []int64
|
||||
var err error
|
||||
for roomID, eventMap := range roomIDToEventMap {
|
||||
_, latestNIDs, err = store.Accumulate(roomID, "", eventMap)
|
||||
_, latestNIDs, err = store.Accumulate(userID, roomID, "", eventMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", roomID, err)
|
||||
}
|
||||
@ -210,18 +211,19 @@ func TestStorageJoinedRoomsAfterPosition(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
newMetadata := func(roomID string, joinCount int) internal.RoomMetadata {
|
||||
newMetadata := func(roomID string, joinCount, inviteCount int) internal.RoomMetadata {
|
||||
m := internal.NewRoomMetadata(roomID)
|
||||
m.JoinCount = joinCount
|
||||
m.InviteCount = inviteCount
|
||||
return *m
|
||||
}
|
||||
|
||||
// also test MetadataForAllRooms
|
||||
roomIDToMetadata := map[string]internal.RoomMetadata{
|
||||
joinedRoomID: newMetadata(joinedRoomID, 1),
|
||||
invitedRoomID: newMetadata(invitedRoomID, 1),
|
||||
banRoomID: newMetadata(banRoomID, 1),
|
||||
bobJoinedRoomID: newMetadata(bobJoinedRoomID, 2),
|
||||
joinedRoomID: newMetadata(joinedRoomID, 1, 0),
|
||||
invitedRoomID: newMetadata(invitedRoomID, 1, 1),
|
||||
banRoomID: newMetadata(banRoomID, 1, 0),
|
||||
bobJoinedRoomID: newMetadata(bobJoinedRoomID, 2, 0),
|
||||
}
|
||||
|
||||
tempTableName, err := store.PrepareSnapshot(txn)
|
||||
@ -349,7 +351,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
@ -452,7 +454,7 @@ func TestVisibleEventNIDsBetween(t *testing.T) {
|
||||
t.Fatalf("LatestEventNID: %s", err)
|
||||
}
|
||||
for _, tl := range timelineInjections {
|
||||
numNew, _, err := store.Accumulate(tl.RoomID, "", tl.Events)
|
||||
numNew, _, err := store.Accumulate(userID, tl.RoomID, "", tl.Events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate on %s failed: %s", tl.RoomID, err)
|
||||
}
|
||||
@ -532,7 +534,7 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
|
||||
}
|
||||
eventIDs := []string{}
|
||||
for _, timeline := range timelines {
|
||||
_, _, err = store.Accumulate(roomID, timeline.prevBatch, timeline.timeline)
|
||||
_, _, err = store.Accumulate(userID, roomID, timeline.prevBatch, timeline.timeline)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to accumulate: %s", err)
|
||||
}
|
||||
@ -566,10 +568,15 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) {
|
||||
wantPrevBatch := wantPrevBatches[i]
|
||||
eventNID := idsToNIDs[eventIDs[i]]
|
||||
// closest batch to the last event in the chunk (latest nid) is always the next prev batch token
|
||||
pb, err := store.EventsTable.SelectClosestPrevBatch(roomID, eventNID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
|
||||
}
|
||||
var pb string
|
||||
_ = sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) {
|
||||
pb, err = store.EventsTable.SelectClosestPrevBatch(txn, roomID, eventNID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to SelectClosestPrevBatch: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if pb != wantPrevBatch {
|
||||
t.Fatalf("SelectClosestPrevBatch: got %v want %v", pb, wantPrevBatch)
|
||||
}
|
||||
@ -682,6 +689,189 @@ func TestGlobalSnapshot(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllJoinedMembers(t *testing.T) {
|
||||
assertNoError(t, cleanDB(t))
|
||||
store := NewStorage(postgresConnectionString)
|
||||
defer store.Teardown()
|
||||
|
||||
alice := "@alice:localhost"
|
||||
bob := "@bob:localhost"
|
||||
charlie := "@charlie:localhost"
|
||||
doris := "@doris:localhost"
|
||||
eve := "@eve:localhost"
|
||||
frank := "@frank:localhost"
|
||||
|
||||
// Alice is always the creator and the inviter for simplicity's sake
|
||||
testCases := []struct {
|
||||
Name string
|
||||
InitMemberships [][2]string
|
||||
AccumulateMemberships [][2]string
|
||||
RoomID string // tests set this dynamically
|
||||
WantJoined []string
|
||||
WantInvited []string
|
||||
}{
|
||||
{
|
||||
Name: "basic joined users",
|
||||
InitMemberships: [][2]string{{alice, "join"}},
|
||||
AccumulateMemberships: [][2]string{{bob, "join"}},
|
||||
WantJoined: []string{alice, bob},
|
||||
},
|
||||
{
|
||||
Name: "basic invited users",
|
||||
InitMemberships: [][2]string{{alice, "join"}, {charlie, "invite"}},
|
||||
AccumulateMemberships: [][2]string{{bob, "invite"}},
|
||||
WantJoined: []string{alice},
|
||||
WantInvited: []string{bob, charlie},
|
||||
},
|
||||
{
|
||||
Name: "many join/leaves, use latest",
|
||||
InitMemberships: [][2]string{{alice, "join"}, {charlie, "join"}, {frank, "join"}},
|
||||
AccumulateMemberships: [][2]string{{bob, "join"}, {charlie, "leave"}, {frank, "leave"}, {charlie, "join"}, {eve, "join"}},
|
||||
WantJoined: []string{alice, bob, charlie, eve},
|
||||
},
|
||||
{
|
||||
Name: "many invites, use latest",
|
||||
InitMemberships: [][2]string{{alice, "join"}, {doris, "join"}},
|
||||
AccumulateMemberships: [][2]string{{doris, "leave"}, {charlie, "invite"}, {doris, "invite"}},
|
||||
WantJoined: []string{alice},
|
||||
WantInvited: []string{charlie, doris},
|
||||
},
|
||||
{
|
||||
Name: "invite and rejection in accumulate",
|
||||
InitMemberships: [][2]string{{alice, "join"}},
|
||||
AccumulateMemberships: [][2]string{{frank, "invite"}, {frank, "leave"}},
|
||||
WantJoined: []string{alice},
|
||||
},
|
||||
{
|
||||
Name: "invite in initial, rejection in accumulate",
|
||||
InitMemberships: [][2]string{{alice, "join"}, {frank, "invite"}},
|
||||
AccumulateMemberships: [][2]string{{frank, "leave"}},
|
||||
WantJoined: []string{alice},
|
||||
},
|
||||
}
|
||||
|
||||
serialise := func(memberships [][2]string) []json.RawMessage {
|
||||
var result []json.RawMessage
|
||||
for _, userWithMembership := range memberships {
|
||||
target := userWithMembership[0]
|
||||
sender := userWithMembership[0]
|
||||
membership := userWithMembership[1]
|
||||
if membership == "invite" {
|
||||
// Alice is always the inviter
|
||||
sender = alice
|
||||
}
|
||||
result = append(result, testutils.NewStateEvent(t, "m.room.member", target, sender, map[string]interface{}{
|
||||
"membership": membership,
|
||||
}))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
roomID := fmt.Sprintf("!TestAllJoinedMembers_%d:localhost", i)
|
||||
_, err := store.Initialise(roomID, append([]json.RawMessage{
|
||||
testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{
|
||||
"creator": alice, // alice is always the creator
|
||||
}),
|
||||
}, serialise(tc.InitMemberships)...))
|
||||
assertNoError(t, err)
|
||||
|
||||
_, _, err = store.Accumulate(userID, roomID, "foo", serialise(tc.AccumulateMemberships))
|
||||
assertNoError(t, err)
|
||||
testCases[i].RoomID = roomID // remember this for later
|
||||
}
|
||||
|
||||
// should get all joined members correctly
|
||||
var joinedMembers map[string][]string
|
||||
// should set join/invite counts correctly
|
||||
var roomMetadatas map[string]internal.RoomMetadata
|
||||
err := sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) error {
|
||||
tableName, err := store.PrepareSnapshot(txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
joinedMembers, roomMetadatas, err = store.AllJoinedMembers(txn, tableName)
|
||||
return err
|
||||
})
|
||||
assertNoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
roomID := tc.RoomID
|
||||
if roomID == "" {
|
||||
t.Fatalf("test case has no room id set: %+v", tc)
|
||||
}
|
||||
// make sure joined members match
|
||||
sort.Strings(joinedMembers[roomID])
|
||||
sort.Strings(tc.WantJoined)
|
||||
if !reflect.DeepEqual(joinedMembers[roomID], tc.WantJoined) {
|
||||
t.Errorf("%v: got joined members %v want %v", tc.Name, joinedMembers[roomID], tc.WantJoined)
|
||||
}
|
||||
// make sure join/invite counts match
|
||||
wantJoined := len(tc.WantJoined)
|
||||
wantInvited := len(tc.WantInvited)
|
||||
metadata, ok := roomMetadatas[roomID]
|
||||
if !ok {
|
||||
t.Fatalf("no room metadata for room %v", roomID)
|
||||
}
|
||||
if metadata.InviteCount != wantInvited {
|
||||
t.Errorf("%v: got invite count %d want %d", tc.Name, metadata.InviteCount, wantInvited)
|
||||
}
|
||||
if metadata.JoinCount != wantJoined {
|
||||
t.Errorf("%v: got join count %d want %d", tc.Name, metadata.JoinCount, wantJoined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircularSlice(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
max int
|
||||
appends []int64
|
||||
want []int64 // these get sorted in the test
|
||||
}{
|
||||
{
|
||||
name: "wraparound",
|
||||
max: 5,
|
||||
appends: []int64{9, 8, 7, 6, 5, 4, 3, 2},
|
||||
want: []int64{2, 3, 4, 5, 6},
|
||||
},
|
||||
{
|
||||
name: "exact",
|
||||
max: 5,
|
||||
appends: []int64{9, 8, 7, 6, 5},
|
||||
want: []int64{5, 6, 7, 8, 9},
|
||||
},
|
||||
{
|
||||
name: "unfilled",
|
||||
max: 5,
|
||||
appends: []int64{9, 8, 7},
|
||||
want: []int64{7, 8, 9},
|
||||
},
|
||||
{
|
||||
name: "wraparound x2",
|
||||
max: 5,
|
||||
appends: []int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10},
|
||||
want: []int64{0, 1, 2, 3, 10},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
cs := &circularSlice{
|
||||
max: tc.max,
|
||||
}
|
||||
for _, val := range tc.appends {
|
||||
cs.append(val)
|
||||
}
|
||||
sort.Slice(cs.vals, func(i, j int) bool {
|
||||
return cs.vals[i] < cs.vals[j]
|
||||
})
|
||||
if !reflect.DeepEqual(cs.vals, tc.want) {
|
||||
t.Errorf("%s: got %v want %v", tc.name, cs.vals, tc.want)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func cleanDB(t *testing.T) error {
|
||||
// make a fresh DB which is unpolluted from other tests
|
||||
db, close := connectToDB(t)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -69,7 +70,15 @@ func (v *HTTPClient) DoSyncV2(ctx context.Context, accessToken, since string, is
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoSyncV2: NewRequest failed: %w", err)
|
||||
}
|
||||
res, err := v.Client.Do(req)
|
||||
var res *http.Response
|
||||
if isFirst {
|
||||
longTimeoutClient := &http.Client{
|
||||
Timeout: 30 * time.Minute,
|
||||
}
|
||||
res, err = longTimeoutClient.Do(req)
|
||||
} else {
|
||||
res, err = v.Client.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("DoSyncV2: request failed: %w", err)
|
||||
}
|
||||
|
90
sync2/device_data_ticker.go
Normal file
@ -0,0 +1,90 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
)
|
||||
|
||||
// This struct remembers user+device IDs to notify for, then periodically
|
||||
// emits them all to the caller. Use it to rate limit the frequency of device list
|
||||
// updates.
|
||||
type DeviceDataTicker struct {
|
||||
// data structures to periodically notify downstream about device data updates
|
||||
// The ticker controls the frequency of updates. The done channel is used to stop ticking
|
||||
// and clean up the goroutine. The notify map contains the values to notify for.
|
||||
ticker *time.Ticker
|
||||
done chan struct{}
|
||||
notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying
|
||||
fn func(payload *pubsub.V2DeviceData)
|
||||
}
|
||||
|
||||
// Create a new device data ticker, which batches calls to Remember and invokes a callback every
|
||||
// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which
|
||||
// is useful for testing.
|
||||
func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker {
|
||||
ddt := &DeviceDataTicker{
|
||||
done: make(chan struct{}),
|
||||
notifyMap: &sync.Map{},
|
||||
}
|
||||
if d != 0 {
|
||||
ddt.ticker = time.NewTicker(d)
|
||||
}
|
||||
return ddt
|
||||
}
|
||||
|
||||
// Stop ticking.
|
||||
func (t *DeviceDataTicker) Stop() {
|
||||
if t.ticker != nil {
|
||||
t.ticker.Stop()
|
||||
}
|
||||
close(t.done)
|
||||
}
|
||||
|
||||
// Set the function which should be called when the tick happens.
|
||||
func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) {
|
||||
t.fn = fn
|
||||
}
|
||||
|
||||
// Remember this user/device ID, and emit it later on.
|
||||
func (t *DeviceDataTicker) Remember(pid PollerID) {
|
||||
t.notifyMap.Store(pid, true)
|
||||
if t.ticker == nil {
|
||||
t.emitUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DeviceDataTicker) emitUpdate() {
|
||||
var p pubsub.V2DeviceData
|
||||
p.UserIDToDeviceIDs = make(map[string][]string)
|
||||
// populate the pubsub payload
|
||||
t.notifyMap.Range(func(key, value any) bool {
|
||||
pid := key.(PollerID)
|
||||
devices := p.UserIDToDeviceIDs[pid.UserID]
|
||||
devices = append(devices, pid.DeviceID)
|
||||
p.UserIDToDeviceIDs[pid.UserID] = devices
|
||||
// clear the map of this value
|
||||
t.notifyMap.Delete(key)
|
||||
return true // keep enumerating
|
||||
})
|
||||
// notify if we have entries
|
||||
if len(p.UserIDToDeviceIDs) > 0 {
|
||||
t.fn(&p)
|
||||
}
|
||||
}
|
||||
|
||||
// Blocks forever, ticking until Stop() is called.
|
||||
func (t *DeviceDataTicker) Run() {
|
||||
if t.ticker == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-t.done:
|
||||
return
|
||||
case <-t.ticker.C:
|
||||
t.emitUpdate()
|
||||
}
|
||||
}
|
||||
}
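
// Usage sketch (illustrative, not part of this change): publishDeviceData below is
// a hypothetical callback; the ticker API itself is as defined above.
//
//	ticker := NewDeviceDataTicker(time.Second)
//	ticker.SetCallback(publishDeviceData) // func(p *pubsub.V2DeviceData)
//	go ticker.Run()
//	defer ticker.Stop()
//	ticker.Remember(PollerID{UserID: "@alice:example.org", DeviceID: "DEVICEID"})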
|
125
sync2/device_data_ticker_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
)
|
||||
|
||||
func TestDeviceTickerBasic(t *testing.T) {
|
||||
duration := time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
t.Log("starting the ticker")
|
||||
ticker.Run()
|
||||
wg.Done()
|
||||
}()
|
||||
time.Sleep(duration * 2) // wait until the ticker is consuming
|
||||
t.Log("remembering a poller")
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "b",
|
||||
})
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(payloads))
|
||||
}
|
||||
want := map[string][]string{
|
||||
"a": {"b"},
|
||||
}
|
||||
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
|
||||
// check stopping works
|
||||
payloads = []*pubsub.V2DeviceData{}
|
||||
ticker.Stop()
|
||||
wg.Wait()
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 0 {
|
||||
t.Fatalf("got extra payloads: %+v", payloads)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceTickerBatchesCorrectly(t *testing.T) {
|
||||
duration := 100 * time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
})
|
||||
go ticker.Run()
|
||||
defer ticker.Stop()
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "b",
|
||||
})
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "bb", // different device, same user
|
||||
})
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "b", // dupe poller ID
|
||||
})
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "x",
|
||||
DeviceID: "y", // new device and user
|
||||
})
|
||||
time.Sleep(duration * 2)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("expected 1 callback, got %d", len(payloads))
|
||||
}
|
||||
want := map[string][]string{
|
||||
"a": {"b", "bb"},
|
||||
"x": {"y"},
|
||||
}
|
||||
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
|
||||
}
|
||||
|
||||
func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
|
||||
duration := time.Millisecond
|
||||
ticker := NewDeviceDataTicker(duration)
|
||||
var payloads []*pubsub.V2DeviceData
|
||||
|
||||
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
|
||||
payloads = append(payloads, payload)
|
||||
})
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "b",
|
||||
})
|
||||
|
||||
go ticker.Run()
|
||||
defer ticker.Stop()
|
||||
ticker.Remember(PollerID{
|
||||
UserID: "a",
|
||||
DeviceID: "b",
|
||||
})
|
||||
time.Sleep(10 * duration)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("got %d payloads, want 1", len(payloads))
|
||||
}
|
||||
}
|
||||
|
||||
func assertPayloadEqual(t *testing.T, got, want map[string][]string) {
|
||||
t.Helper()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("got %+v\nwant %+v\n", got, want)
|
||||
}
|
||||
for userID, wantDeviceIDs := range want {
|
||||
gotDeviceIDs := got[userID]
|
||||
sort.Strings(wantDeviceIDs)
|
||||
sort.Strings(gotDeviceIDs)
|
||||
if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) {
|
||||
t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs)
|
||||
}
|
||||
}
|
||||
}
|
@ -32,8 +32,8 @@ func NewDevicesTable(db *sqlx.DB) *DevicesTable {
|
||||
|
||||
// InsertDevice creates a new devices row with a blank since token if no such row
|
||||
// exists. Otherwise, it does nothing.
|
||||
func (t *DevicesTable) InsertDevice(userID, deviceID string) error {
|
||||
_, err := t.db.Exec(
|
||||
func (t *DevicesTable) InsertDevice(txn *sqlx.Tx, userID, deviceID string) error {
|
||||
_, err := txn.Exec(
|
||||
` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3)
|
||||
ON CONFLICT (user_id, device_id) DO NOTHING`,
|
||||
userID, deviceID, "",
|
||||
|
@ -2,6 +2,7 @@ package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
@ -41,18 +42,25 @@ func TestDevicesTableSinceColumn(t *testing.T) {
|
||||
aliceSecret1 := "mysecret1"
|
||||
aliceSecret2 := "mysecret2"
|
||||
|
||||
t.Log("Insert two tokens for Alice.")
|
||||
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
var aliceToken, aliceToken2 *Token
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
t.Log("Insert two tokens for Alice.")
|
||||
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Add a devices row for Alice")
|
||||
err = devices.InsertDevice(alice, aliceDevice)
|
||||
t.Log("Add a devices row for Alice")
|
||||
err = devices.InsertDevice(txn, alice, aliceDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert device: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("Pretend we're about to start a poller. Fetch Alice's token along with the since value tracked by the devices table.")
|
||||
accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash)
|
||||
@ -104,39 +112,49 @@ func TestTokenForEachDevice(t *testing.T) {
|
||||
chris := "chris"
|
||||
chrisDevice := "chris_desktop"
|
||||
|
||||
t.Log("Add a device for Alice, Bob and Chris.")
|
||||
err := devices.InsertDevice(alice, aliceDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(bob, bobDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(chris, chrisDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
t.Log("Add a device for Alice, Bob and Chris.")
|
||||
err := devices.InsertDevice(txn, alice, aliceDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(txn, bob, bobDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
err = devices.InsertDevice(txn, chris, chrisDevice)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertDevice returned error: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("Mark Alice's device with a since token.")
|
||||
sinceValue := "s-1-2-3-4"
|
||||
devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
||||
err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateDeviceSince returned error: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
|
||||
aliceLastSeen1 := time.Now()
|
||||
_, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
|
||||
aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
var aliceToken2, bobToken *Token
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.")
|
||||
aliceLastSeen1 := time.Now()
|
||||
_, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute)
|
||||
aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("Fetch a token for every device")
|
||||
gotTokens, err := tokens.TokenForEachDevice(nil)
|
||||
|
@ -4,11 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
|
||||
@ -30,38 +32,48 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C
|
||||
// processing v2 data (as a sync2.V2DataReceiver) and publishing updates (pubsub.Payload to V2Listeners);
|
||||
// and receiving and processing EnsurePolling events.
|
||||
type Handler struct {
|
||||
pMap *sync2.PollerMap
|
||||
v2Store *sync2.Storage
|
||||
Store *state.Storage
|
||||
v2Pub pubsub.Notifier
|
||||
v3Sub *pubsub.V3Sub
|
||||
client sync2.Client
|
||||
unreadMap map[string]struct {
|
||||
pMap sync2.IPollerMap
|
||||
v2Store *sync2.Storage
|
||||
Store *state.Storage
|
||||
v2Pub pubsub.Notifier
|
||||
v3Sub *pubsub.V3Sub
|
||||
// user_id|room_id|event_type => fnv_hash(last_event_bytes)
|
||||
accountDataMap *sync.Map
|
||||
unreadMap map[string]struct {
|
||||
Highlight int
|
||||
Notif int
|
||||
}
|
||||
// room_id => fnv_hash([typing user ids])
|
||||
typingMap map[string]uint64
|
||||
// room_id -> PollerID, stores which Poller is allowed to update typing notifications
|
||||
typingHandler map[string]sync2.PollerID
|
||||
typingMu *sync.Mutex
|
||||
PendingTxnIDs *sync2.PendingTransactionIDs
|
||||
|
||||
deviceDataTicker *sync2.DeviceDataTicker
|
||||
e2eeWorkerPool *internal.WorkerPool
|
||||
|
||||
numPollers prometheus.Gauge
|
||||
subSystem string
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client,
|
||||
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool,
|
||||
pMap sync2.IPollerMap, v2Store *sync2.Storage, store *state.Storage,
|
||||
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration,
|
||||
) (*Handler, error) {
|
||||
h := &Handler{
|
||||
pMap: pMap,
|
||||
v2Store: v2Store,
|
||||
client: client,
|
||||
Store: store,
|
||||
subSystem: "poller",
|
||||
unreadMap: make(map[string]struct {
|
||||
Highlight int
|
||||
Notif int
|
||||
}),
|
||||
typingMap: make(map[string]uint64),
|
||||
accountDataMap: &sync.Map{},
|
||||
typingMu: &sync.Mutex{},
|
||||
typingHandler: make(map[string]sync2.PollerID),
|
||||
PendingTxnIDs: sync2.NewPendingTransactionIDs(pMap.DeviceIDs),
|
||||
deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration),
|
||||
e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded
|
||||
}
|
||||
|
||||
if enablePrometheus {
|
||||
@ -86,6 +98,9 @@ func (h *Handler) Listen() {
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
}()
|
||||
h.e2eeWorkerPool.Start()
|
||||
h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate)
|
||||
go h.deviceDataTicker.Run()
|
||||
}
|
||||
|
||||
func (h *Handler) Teardown() {
|
||||
@ -95,6 +110,7 @@ func (h *Handler) Teardown() {
|
||||
h.Store.Teardown()
|
||||
h.v2Store.Teardown()
|
||||
h.pMap.Terminate()
|
||||
h.deviceDataTicker.Stop()
|
||||
if h.numPollers != nil {
|
||||
prometheus.Unregister(h.numPollers)
|
||||
}
|
||||
@ -156,7 +172,15 @@ func (h *Handler) updateMetrics() {
|
||||
h.numPollers.Set(float64(h.pMap.NumPollers()))
|
||||
}
|
||||
|
||||
func (h *Handler) OnTerminated(ctx context.Context, userID, deviceID string) {
|
||||
func (h *Handler) OnTerminated(ctx context.Context, pollerID sync2.PollerID) {
|
||||
// Check if this device is handling any typing notifications; if so, remove it.
|
||||
h.typingMu.Lock()
|
||||
defer h.typingMu.Unlock()
|
||||
for roomID, devID := range h.typingHandler {
|
||||
if devID == pollerID {
|
||||
delete(h.typingHandler, roomID)
|
||||
}
|
||||
}
|
||||
h.updateMetrics()
|
||||
}
|
||||
|
||||
@ -193,42 +217,62 @@ func (h *Handler) UpdateDeviceSince(ctx context.Context, userID, deviceID, since
|
||||
}
|
||||
|
||||
func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
||||
// some of these fields may be set
|
||||
partialDD := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
OTKCounts: otkCounts,
|
||||
FallbackKeyTypes: fallbackKeyTypes,
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: deviceListChanges,
|
||||
},
|
||||
}
|
||||
nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
return
|
||||
}
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
Pos: nextPos,
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
h.e2eeWorkerPool.Queue(func() {
|
||||
defer wg.Done()
|
||||
// some of these fields may be set
|
||||
partialDD := internal.DeviceData{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
OTKCounts: otkCounts,
|
||||
FallbackKeyTypes: fallbackKeyTypes,
|
||||
DeviceLists: internal.DeviceLists{
|
||||
New: deviceListChanges,
|
||||
},
|
||||
}
|
||||
err := h.Store.DeviceDataTable.Upsert(&partialDD)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Msg("failed to upsert device data")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
return
|
||||
}
|
||||
// remember this to notify on pubsub later
|
||||
h.deviceDataTicker.Remember(sync2.PollerID{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
})
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Called periodically by deviceDataTicker, contains many updates
|
||||
func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) {
|
||||
h.v2Pub.Notify(pubsub.ChanV2, payload)
|
||||
}
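OnE2EEData above queues the device-data upsert onto a bounded worker pool but still waits on a WaitGroup, so each poller's write completes before it carries on while the number of concurrent upserts stays capped. A minimal sketch of that pattern with a hand-rolled pool (the real internal.WorkerPool API is not shown here, so this is illustrative only).

package main

import (
	"fmt"
	"sync"
)

// workerPool runs queued jobs on a fixed number of goroutines.
type workerPool struct{ jobs chan func() }

func newWorkerPool(n int) *workerPool {
	p := &workerPool{jobs: make(chan func())}
	for i := 0; i < n; i++ {
		go func() {
			for job := range p.jobs {
				job()
			}
		}()
	}
	return p
}

func (p *workerPool) Queue(job func()) { p.jobs <- job }

func main() {
	pool := newWorkerPool(4) // at most 4 concurrent "upserts"
	var wg sync.WaitGroup
	wg.Add(1)
	pool.Queue(func() {
		defer wg.Done()
		fmt.Println("pretend this is DeviceDataTable.Upsert")
	})
	wg.Wait() // the caller stays synchronous, like OnE2EEData
}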
|
||||
|
||||
func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
// Remember any transaction IDs that may be unique to this user
|
||||
eventIDsWithTxns := make([]string, 0, len(timeline)) // in timeline order
|
||||
eventIDToTxnID := make(map[string]string, len(timeline)) // event_id -> txn_id
|
||||
// Also remember events which were sent by this user but lack a transaction ID.
|
||||
eventIDsLackingTxns := make([]string, 0, len(timeline))
|
||||
|
||||
for _, e := range timeline {
|
||||
txnID := gjson.GetBytes(e, "unsigned.transaction_id")
|
||||
if !txnID.Exists() {
|
||||
parsed := gjson.ParseBytes(e)
|
||||
eventID := parsed.Get("event_id").Str
|
||||
|
||||
if txnID := parsed.Get("unsigned.transaction_id"); txnID.Exists() {
|
||||
eventIDsWithTxns = append(eventIDsWithTxns, eventID)
|
||||
eventIDToTxnID[eventID] = txnID.Str
|
||||
continue
|
||||
}
|
||||
eventID := gjson.GetBytes(e, "event_id").Str
|
||||
eventIDsWithTxns = append(eventIDsWithTxns, eventID)
|
||||
eventIDToTxnID[eventID] = txnID.Str
|
||||
|
||||
if sender := parsed.Get("sender"); sender.Str == userID {
|
||||
eventIDsLackingTxns = append(eventIDsLackingTxns, eventID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(eventIDToTxnID) > 0 {
|
||||
// persist the txn IDs
|
||||
err := h.Store.TransactionsTable.Insert(userID, deviceID, eventIDToTxnID)
|
||||
@ -239,62 +283,69 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prev
|
||||
}
|
||||
|
||||
// Insert new events
|
||||
numNew, latestNIDs, err := h.Store.Accumulate(roomID, prevBatch, timeline)
|
||||
numNew, latestNIDs, err := h.Store.Accumulate(userID, roomID, prevBatch, timeline)
|
||||
if err != nil {
|
||||
logger.Err(err).Int("timeline", len(timeline)).Str("room", roomID).Msg("V2: failed to accumulate room")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
return
|
||||
}
|
||||
if numNew == 0 {
|
||||
// no new events
|
||||
return
|
||||
}
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
|
||||
RoomID: roomID,
|
||||
PrevBatch: prevBatch,
|
||||
EventNIDs: latestNIDs,
|
||||
})
|
||||
|
||||
if len(eventIDToTxnID) > 0 {
|
||||
// We've updated the database. Now tell any pubsub listeners what we learned.
|
||||
if numNew != 0 {
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Accumulate{
|
||||
RoomID: roomID,
|
||||
PrevBatch: prevBatch,
|
||||
EventNIDs: latestNIDs,
|
||||
})
|
||||
}
|
||||
|
||||
if len(eventIDToTxnID) > 0 || len(eventIDsLackingTxns) > 0 {
|
||||
// The call to h.Store.Accumulate above only tells us about new events' NIDS;
|
||||
// for existing events we need to requery the database to fetch them.
|
||||
// Rather than try to reuse work, keep things simple and just fetch NIDs for
|
||||
// all events with txnIDs.
|
||||
var nidsByIDs map[string]int64
|
||||
eventIDsToFetch := append(eventIDsWithTxns, eventIDsLackingTxns...)
|
||||
err = sqlutil.WithTransaction(h.Store.DB, func(txn *sqlx.Tx) error {
|
||||
nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsWithTxns)
|
||||
nidsByIDs, err = h.Store.EventsTable.SelectNIDsByIDs(txn, eventIDsToFetch)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
logger.Err(err).
|
||||
Int("timeline", len(timeline)).
|
||||
Int("num_transaction_ids", len(eventIDsWithTxns)).
|
||||
Int("num_missing_transaction_ids", len(eventIDsLackingTxns)).
|
||||
Str("room", roomID).
|
||||
Msg("V2: failed to fetch nids for events with transaction_ids")
|
||||
Msg("V2: failed to fetch nids for event transaction_id handling")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, eventID := range eventIDsWithTxns {
|
||||
for eventID, nid := range nidsByIDs {
|
||||
txnID, ok := eventIDToTxnID[eventID]
|
||||
if !ok {
|
||||
continue
|
||||
if ok {
|
||||
h.PendingTxnIDs.SeenTxnID(eventID)
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
|
||||
EventID: eventID,
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
TransactionID: txnID,
|
||||
NID: nid,
|
||||
})
|
||||
} else {
|
||||
allClear, _ := h.PendingTxnIDs.MissingTxnID(eventID, userID, deviceID)
|
||||
if allClear {
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
|
||||
EventID: eventID,
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
TransactionID: "",
|
||||
NID: nid,
|
||||
})
|
||||
}
|
||||
}
|
||||
nid, ok := nidsByIDs[eventID]
|
||||
if !ok {
|
||||
errMsg := "V2: failed to fetch NID for txnID"
|
||||
logger.Error().Str("user", userID).Str("device", deviceID).Msg(errMsg)
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(fmt.Errorf(errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2TransactionID{
|
||||
EventID: eventID,
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
TransactionID: txnID,
|
||||
NID: nid,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
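The rewritten loop above handles two cases per event: if this poller saw a transaction ID it publishes a V2TransactionID straight away, otherwise MissingTxnID records that this device has nothing to offer, and only once every device of the sender has said so does an empty "all clear" payload go out. A stripped-down sketch of that bookkeeping, independent of the real PendingTransactionIDs type.

package main

import "fmt"

// pendingTxnIDs tracks, per event, which of the sender's devices have yet to report.
type pendingTxnIDs struct {
	pending map[string]map[string]bool // event ID -> device IDs still outstanding
	devices func(userID string) []string
}

// missingTxnID records that deviceID saw eventID without a transaction ID.
// It returns true ("all clear") once no devices remain outstanding.
func (p *pendingTxnIDs) missingTxnID(eventID, userID, deviceID string) bool {
	outstanding, ok := p.pending[eventID]
	if !ok {
		outstanding = map[string]bool{}
		for _, d := range p.devices(userID) {
			outstanding[d] = true
		}
		p.pending[eventID] = outstanding
	}
	delete(outstanding, deviceID)
	return len(outstanding) == 0
}

func main() {
	p := &pendingTxnIDs{
		pending: map[string]map[string]bool{},
		devices: func(string) []string { return []string{"laptop", "phone"} },
	}
	fmt.Println(p.missingTxnID("$event", "@alice:localhost", "laptop")) // false: phone still outstanding
	fmt.Println(p.missingTxnID("$event", "@alice:localhost", "phone"))  // true: all clear
}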
|
||||
@ -315,13 +366,20 @@ func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.Ra
|
||||
return res.PrependTimelineEvents
|
||||
}
|
||||
|
||||
func (h *Handler) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
|
||||
next := typingHash(ephEvent)
|
||||
existing := h.typingMap[roomID]
|
||||
if existing == next {
|
||||
func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) {
|
||||
h.typingMu.Lock()
|
||||
defer h.typingMu.Unlock()
|
||||
|
||||
existingDevice := h.typingHandler[roomID]
|
||||
isPollerAssigned := existingDevice.DeviceID != "" && existingDevice.UserID != ""
|
||||
if isPollerAssigned && existingDevice != pollerID {
|
||||
// A different device is already handling typing notifications for this room
|
||||
return
|
||||
} else if !isPollerAssigned {
|
||||
// We're the first to call SetTyping, assign our pollerID
|
||||
h.typingHandler[roomID] = pollerID
|
||||
}
|
||||
h.typingMap[roomID] = next
|
||||
|
||||
// we don't persist this for long term storage as typing notifs are inherently ephemeral.
|
||||
// So rather than maintaining them forever, they will naturally expire when we terminate.
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Typing{
|
||||
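The new SetTyping body gives ownership of a room's typing notifications to whichever poller reports them first, so one user running several pollers does not publish duplicate typing updates. A reduced sketch of that claim-then-check logic behind a mutex.

package main

import (
	"fmt"
	"sync"
)

type pollerID struct{ UserID, DeviceID string }

type typingRouter struct {
	mu      sync.Mutex
	handler map[string]pollerID // room ID -> poller allowed to publish typing
}

// shouldPublish reports whether this poller owns typing updates for roomID,
// claiming ownership if nobody has done so yet.
func (r *typingRouter) shouldPublish(room string, p pollerID) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	owner, ok := r.handler[room]
	if !ok {
		r.handler[room] = p // first caller wins
		return true
	}
	return owner == p
}

func main() {
	r := &typingRouter{handler: map[string]pollerID{}}
	aliceLaptop := pollerID{"@alice:localhost", "laptop"}
	alicePhone := pollerID{"@alice:localhost", "phone"}
	fmt.Println(r.shouldPublish("!room", aliceLaptop)) // true: claims the room
	fmt.Println(r.shouldPublish("!room", alicePhone))  // false: already claimed
}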
@ -398,7 +456,28 @@ func (h *Handler) UpdateUnreadCounts(ctx context.Context, roomID, userID string,
|
||||
}
|
||||
|
||||
func (h *Handler) OnAccountData(ctx context.Context, userID, roomID string, events []json.RawMessage) {
|
||||
data, err := h.Store.InsertAccountData(userID, roomID, events)
|
||||
// duplicate suppression for multiple devices on the same account.
|
||||
// We suppress by remembering a hash of the last event bytes for a given account data type, and if they match we ignore.
|
||||
dedupedEvents := make([]json.RawMessage, 0, len(events))
|
||||
for i := range events {
|
||||
evType := gjson.GetBytes(events[i], "type").Str
|
||||
key := fmt.Sprintf("%s|%s|%s", userID, roomID, evType)
|
||||
thisHash := fnvHash(events[i])
|
||||
last, _ := h.accountDataMap.Load(key)
|
||||
if last != nil {
|
||||
lastHash := last.(uint64)
|
||||
if lastHash == thisHash {
|
||||
continue // skip this event
|
||||
}
|
||||
}
|
||||
dedupedEvents = append(dedupedEvents, events[i])
|
||||
h.accountDataMap.Store(key, thisHash)
|
||||
}
|
||||
if len(dedupedEvents) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.Store.InsertAccountData(userID, roomID, dedupedEvents)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to update account data")
|
||||
sentry.CaptureException(err)
|
||||
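OnAccountData now drops account data whose raw bytes hash to the same FNV-1a value as the last event seen for that user, room and event type, so multiple pollers on one account do not re-insert identical data. Note the comparison is byte-level: a semantically identical but differently serialised event still gets through. A small self-contained sketch of the comparison (helper names are invented).

package main

import (
	"fmt"
	"hash/fnv"
	"sync"
)

var lastSeen sync.Map // "user|room|type" -> uint64 hash of the last event bytes

func fnvHash(b []byte) uint64 {
	h := fnv.New64a()
	h.Write(b)
	return h.Sum64()
}

// isDuplicate reports whether ev matches the last event stored under key,
// remembering ev's hash either way.
func isDuplicate(key string, ev []byte) bool {
	h := fnvHash(ev)
	prev, ok := lastSeen.Load(key)
	lastSeen.Store(key, h)
	return ok && prev.(uint64) == h
}

func main() {
	key := "@alice:localhost|!room|m.fully_read"
	ev := []byte(`{"type":"m.fully_read","content":{"event_id":"$abc"}}`)
	fmt.Println(isDuplicate(key, ev)) // false: first sighting
	fmt.Println(isDuplicate(key, ev)) // true: identical bytes from a second poller
}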
@ -428,16 +507,24 @@ func (h *Handler) OnInvite(ctx context.Context, userID, roomID string, inviteSta
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string) {
|
||||
func (h *Handler) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEv json.RawMessage) {
|
||||
// remove any invites for this user if they are rejecting an invite
|
||||
err := h.Store.InvitesTable.RemoveInvite(userID, roomID)
|
||||
if err != nil {
|
||||
logger.Err(err).Str("user", userID).Str("room", roomID).Msg("failed to retire invite")
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err)
|
||||
}
|
||||
|
||||
// Remove this room from the typingHandler map so that another poller can
|
||||
// take over handling typing notifications for it.
|
||||
h.typingMu.Lock()
|
||||
defer h.typingMu.Unlock()
|
||||
delete(h.typingHandler, roomID)
|
||||
|
||||
h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2LeaveRoom{
|
||||
UserID: userID,
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
RoomID: roomID,
|
||||
LeaveEvent: leaveEv,
|
||||
})
|
||||
}
|
||||
|
||||
@ -471,10 +558,8 @@ func (h *Handler) EnsurePolling(p *pubsub.V3EnsurePolling) {
|
||||
}()
|
||||
}
|
||||
|
||||
func typingHash(ephEvent json.RawMessage) uint64 {
|
||||
func fnvHash(event json.RawMessage) uint64 {
|
||||
h := fnv.New64a()
|
||||
for _, userID := range gjson.ParseBytes(ephEvent).Get("content.user_ids").Array() {
|
||||
h.Write([]byte(userID.Str))
|
||||
}
|
||||
h.Write(event)
|
||||
return h.Sum64()
|
||||
}
|
||||
|
224
sync2/handler2/handler_test.go
Normal file
@ -0,0 +1,224 @@
|
||||
package handler2_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
"github.com/matrix-org/sliding-sync/state"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/matrix-org/sliding-sync/sync2/handler2"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var postgresURI string
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
postgresURI = testutils.PrepareDBConnectionString()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
type pollInfo struct {
|
||||
pid sync2.PollerID
|
||||
accessToken string
|
||||
v2since string
|
||||
isStartup bool
|
||||
}
|
||||
|
||||
type mockPollerMap struct {
|
||||
calls []pollInfo
|
||||
}
|
||||
|
||||
func (p *mockPollerMap) NumPollers() int {
|
||||
return 0
|
||||
}
|
||||
func (p *mockPollerMap) Terminate() {}
|
||||
|
||||
func (p *mockPollerMap) DeviceIDs(userID string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) {
|
||||
p.calls = append(p.calls, pollInfo{
|
||||
pid: pid,
|
||||
accessToken: accessToken,
|
||||
v2since: v2since,
|
||||
isStartup: isStartup,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) {
|
||||
for _, c := range p.calls {
|
||||
if reflect.DeepEqual(pi, c) {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("assertCallExists: did not find %+v", pi)
|
||||
}
|
||||
|
||||
type mockPub struct {
|
||||
calls []pubsub.Payload
|
||||
mu *sync.Mutex
|
||||
waiters map[string][]chan struct{}
|
||||
}
|
||||
|
||||
func newMockPub() *mockPub {
|
||||
return &mockPub{
|
||||
mu: &sync.Mutex{},
|
||||
waiters: make(map[string][]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Notify chanName that there is a new payload p. Return an error if we failed to send the notification.
|
||||
func (p *mockPub) Notify(chanName string, payload pubsub.Payload) error {
|
||||
p.calls = append(p.calls, payload)
|
||||
p.mu.Lock()
|
||||
for _, ch := range p.waiters[payload.Type()] {
|
||||
close(ch)
|
||||
}
|
||||
p.waiters[payload.Type()] = nil // don't re-notify for 2nd+ payload
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *mockPub) WaitForPayloadType(t string) chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
p.mu.Lock()
|
||||
p.waiters[t] = append(p.waiters[t], ch)
|
||||
p.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}, wantTimeOut bool) {
|
||||
select {
|
||||
case <-ch:
|
||||
if wantTimeOut {
|
||||
t.Fatalf("expected to timeout, but received on channel")
|
||||
}
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
if !wantTimeOut {
|
||||
t.Fatalf("DoWait: timed out waiting: %s", errMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close is called when we should stop listening.
|
||||
func (p *mockPub) Close() error { return nil }
|
||||
|
||||
type mockSub struct{}
|
||||
|
||||
// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called.
|
||||
func (s *mockSub) Listen(chanName string, fn func(p pubsub.Payload)) error { return nil }
|
||||
|
||||
// Close the listener. No more callbacks should fire.
|
||||
func (s *mockSub) Close() error { return nil }
|
||||
|
||||
func assertNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("assertNoError: %v", err)
|
||||
}
|
||||
|
||||
// Test that if you call EnsurePolling you get back V2InitialSyncComplete down pubsub and the poller
|
||||
// map is called correctly
|
||||
func TestHandlerFreshEnsurePolling(t *testing.T) {
|
||||
store := state.NewStorage(postgresURI)
|
||||
v2Store := sync2.NewStore(postgresURI, "secret")
|
||||
pMap := &mockPollerMap{}
|
||||
pub := newMockPub()
|
||||
sub := &mockSub{}
|
||||
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
|
||||
assertNoError(t, err)
|
||||
alice := "@alice:localhost"
|
||||
deviceID := "ALICE"
|
||||
token := "aliceToken"
|
||||
|
||||
var tok *sync2.Token
|
||||
sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) error {
|
||||
// the device and token need to already exist prior to EnsurePolling
|
||||
err = v2Store.DevicesTable.InsertDevice(txn, alice, deviceID)
|
||||
assertNoError(t, err)
|
||||
tok, err = v2Store.TokensTable.Insert(txn, token, alice, deviceID, time.Now())
|
||||
assertNoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
payloadInitialSyncComplete := pubsub.V2InitialSyncComplete{
|
||||
UserID: alice,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
ch := pub.WaitForPayloadType(payloadInitialSyncComplete.Type())
|
||||
// ask the handler to start polling
|
||||
h.EnsurePolling(&pubsub.V3EnsurePolling{
|
||||
UserID: alice,
|
||||
DeviceID: deviceID,
|
||||
AccessTokenHash: tok.AccessTokenHash,
|
||||
})
|
||||
pub.DoWait(t, "didn't see V2InitialSyncComplete", ch, false)
|
||||
|
||||
// make sure we polled with the token i.e it did a db hit
|
||||
pMap.assertCallExists(t, pollInfo{
|
||||
pid: sync2.PollerID{
|
||||
UserID: alice,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
accessToken: token,
|
||||
v2since: "",
|
||||
isStartup: false,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestSetTypingConcurrently(t *testing.T) {
|
||||
store := state.NewStorage(postgresURI)
|
||||
v2Store := sync2.NewStore(postgresURI, "secret")
|
||||
pMap := &mockPollerMap{}
|
||||
pub := newMockPub()
|
||||
sub := &mockSub{}
|
||||
h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute)
|
||||
assertNoError(t, err)
|
||||
ctx := context.Background()
|
||||
|
||||
roomID := "!typing:localhost"
|
||||
|
||||
typingType := pubsub.V2Typing{}
|
||||
|
||||
// startSignal is used to synchronize calling SetTyping
|
||||
startSignal := make(chan struct{})
|
||||
// Call SetTyping twice, this may happen with pollers for the same user
|
||||
go func() {
|
||||
<-startSignal
|
||||
h.SetTyping(ctx, sync2.PollerID{UserID: "@alice", DeviceID: "aliceDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
|
||||
}()
|
||||
go func() {
|
||||
<-startSignal
|
||||
h.SetTyping(ctx, sync2.PollerID{UserID: "@bob", DeviceID: "bobDevice"}, roomID, json.RawMessage(`{"content":{"user_ids":["@alice:localhost"]}}`))
|
||||
}()
|
||||
|
||||
close(startSignal)
|
||||
|
||||
// Wait for the event to be published
|
||||
ch := pub.WaitForPayloadType(typingType.Type())
|
||||
pub.DoWait(t, "didn't see V2Typing", ch, false)
|
||||
ch = pub.WaitForPayloadType(typingType.Type())
|
||||
// Wait again, but this time we expect to timeout.
|
||||
pub.DoWait(t, "saw unexpected V2Typing", ch, true)
|
||||
|
||||
// We expect only one call to Notify, as the hashes should match
|
||||
if gotCalls := len(pub.calls); gotCalls != 1 {
|
||||
t.Fatalf("expected only one call to notify, got %d", gotCalls)
|
||||
}
|
||||
}
|
234
sync2/poller.go
@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/rs/zerolog"
|
||||
@ -21,8 +22,12 @@ type PollerID struct {
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// alias time.Sleep so tests can monkey patch it out
|
||||
// alias time.Sleep/time.Since so tests can monkey patch it out
|
||||
var timeSleep = time.Sleep
|
||||
var timeSince = time.Since
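Because timeSleep and timeSince are package-level variables rather than direct calls into the time package, tests can swap them out to fake the passage of time and put them back afterwards. A minimal illustration of the technique (the test itself is made up).

package sync2

import (
	"testing"
	"time"
)

func TestFakeElapsedTime(t *testing.T) {
	// Pretend two minutes have always passed, then restore the real clock.
	oldSince := timeSince
	timeSince = func(time.Time) time.Duration { return 2 * time.Minute }
	defer func() { timeSince = oldSince }()

	if timeSince(time.Now()) != 2*time.Minute {
		t.Fatalf("expected the patched timeSince to report 2 minutes")
	}
}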
|
||||
|
||||
// log at most once every duration. Always logs before terminating.
|
||||
var logInterval = 30 * time.Second
|
||||
|
||||
// V2DataReceiver is the receiver for all the v2 sync data the poller gets
|
||||
type V2DataReceiver interface {
|
||||
@ -34,7 +39,7 @@ type V2DataReceiver interface {
|
||||
// If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB.
|
||||
Initialise(ctx context.Context, roomID string, state []json.RawMessage) []json.RawMessage // snapshot ID?
|
||||
// SetTyping indicates which users are typing.
|
||||
SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage)
|
||||
SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage)
|
||||
// Sent when there is a new receipt
|
||||
OnReceipt(ctx context.Context, userID, roomID, ephEventType string, ephEvent json.RawMessage)
|
||||
// AddToDeviceMessages adds this chunk of to_device messages. Preserve the ordering.
|
||||
@ -46,25 +51,35 @@ type V2DataReceiver interface {
|
||||
// Sent when there is a room in the `invite` section of the v2 response.
|
||||
OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) // invitestate in db
|
||||
// Sent when there is a room in the `leave` section of the v2 response.
|
||||
OnLeftRoom(ctx context.Context, userID, roomID string)
|
||||
OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage)
|
||||
// Sent when there is a _change_ in E2EE data, not all the time
|
||||
OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int)
|
||||
// Sent when the poll loop terminates
|
||||
OnTerminated(ctx context.Context, userID, deviceID string)
|
||||
OnTerminated(ctx context.Context, pollerID PollerID)
|
||||
// Sent when the token gets a 401 response
|
||||
OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string)
|
||||
}
|
||||
|
||||
type IPollerMap interface {
|
||||
EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger)
|
||||
NumPollers() int
|
||||
Terminate()
|
||||
DeviceIDs(userID string) []string
|
||||
}
|
||||
|
||||
// PollerMap is a map of device ID to Poller
|
||||
type PollerMap struct {
|
||||
v2Client Client
|
||||
callbacks V2DataReceiver
|
||||
pollerMu *sync.Mutex
|
||||
Pollers map[PollerID]*poller
|
||||
executor chan func()
|
||||
executorRunning bool
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
timelineSizeHistogramVec *prometheus.HistogramVec
|
||||
v2Client Client
|
||||
callbacks V2DataReceiver
|
||||
pollerMu *sync.Mutex
|
||||
Pollers map[PollerID]*poller
|
||||
executor chan func()
|
||||
executorRunning bool
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
timelineSizeHistogramVec *prometheus.HistogramVec
|
||||
gappyStateSizeVec *prometheus.HistogramVec
|
||||
numOutstandingSyncReqsGauge prometheus.Gauge
|
||||
totalNumPollsCounter prometheus.Counter
|
||||
}
|
||||
|
||||
// NewPollerMap makes a new PollerMap. Guarantees that the V2DataReceiver will be called on the same
|
||||
@ -115,7 +130,28 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap {
|
||||
Buckets: []float64{0.0, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0},
|
||||
}, []string{"limited"})
|
||||
prometheus.MustRegister(pm.timelineSizeHistogramVec)
|
||||
|
||||
pm.gappyStateSizeVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "poller",
|
||||
Name: "gappy_state_size",
|
||||
Help: "Number of events in a state block during a sync v2 gappy sync",
|
||||
Buckets: []float64{1.0, 10.0, 100.0, 1000.0, 10000.0},
|
||||
}, nil)
|
||||
prometheus.MustRegister(pm.gappyStateSizeVec)
|
||||
pm.totalNumPollsCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "poller",
|
||||
Name: "total_num_polls",
|
||||
Help: "Total number of poll loops iterated.",
|
||||
})
|
||||
prometheus.MustRegister(pm.totalNumPollsCounter)
|
||||
pm.numOutstandingSyncReqsGauge = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "poller",
|
||||
Name: "num_outstanding_sync_v2_reqs",
|
||||
Help: "Number of sync v2 requests that have yet to return a response.",
|
||||
})
|
||||
prometheus.MustRegister(pm.numOutstandingSyncReqsGauge)
|
||||
}
|
||||
return pm
|
||||
}
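The new gappy_state_size histogram follows the same register-then-observe pattern as the other poller metrics: choose buckets that match the expected spread, MustRegister once, then call Observe per sample. A standalone example against the same client library (namespace and values here are made up).

package main

import "github.com/prometheus/client_golang/prometheus"

func main() {
	gappyStateSize := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: "example",
		Subsystem: "poller",
		Name:      "gappy_state_size",
		Help:      "Number of events in a state block during a gappy sync",
		Buckets:   []float64{1, 10, 100, 1000, 10000},
	}, nil) // nil label names: a single time series

	prometheus.MustRegister(gappyStateSize)
	defer prometheus.Unregister(gappyStateSize)

	// Record one observation; with no labels, WithLabelValues takes no arguments.
	gappyStateSize.WithLabelValues().Observe(423)
}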
|
||||
@ -137,6 +173,15 @@ func (h *PollerMap) Terminate() {
|
||||
if h.timelineSizeHistogramVec != nil {
|
||||
prometheus.Unregister(h.timelineSizeHistogramVec)
|
||||
}
|
||||
if h.gappyStateSizeVec != nil {
|
||||
prometheus.Unregister(h.gappyStateSizeVec)
|
||||
}
|
||||
if h.totalNumPollsCounter != nil {
|
||||
prometheus.Unregister(h.totalNumPollsCounter)
|
||||
}
|
||||
if h.numOutstandingSyncReqsGauge != nil {
|
||||
prometheus.Unregister(h.numOutstandingSyncReqsGauge)
|
||||
}
|
||||
close(h.executor)
|
||||
}
|
||||
|
||||
@ -151,6 +196,20 @@ func (h *PollerMap) NumPollers() (count int) {
|
||||
return
|
||||
}
|
||||
|
||||
// DeviceIDs returns the slice of all devices currently being polled for by this user.
|
||||
// The return value is brand-new and is fully owned by the caller.
|
||||
func (h *PollerMap) DeviceIDs(userID string) []string {
|
||||
h.pollerMu.Lock()
|
||||
defer h.pollerMu.Unlock()
|
||||
var devices []string
|
||||
for _, p := range h.Pollers {
|
||||
if !p.terminated.Load() && p.userID == userID {
|
||||
devices = append(devices, p.deviceID)
|
||||
}
|
||||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
// EnsurePolling makes sure there is a poller for this device, making one if need be.
|
||||
// Blocks until at least 1 sync is done if and only if the poller was just created.
|
||||
// This ensures that calls to the database will return data.
|
||||
@ -196,6 +255,9 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS
|
||||
poller = newPoller(pid, accessToken, h.v2Client, h, logger, !needToWait && !isStartup)
|
||||
poller.processHistogramVec = h.processHistogramVec
|
||||
poller.timelineSizeVec = h.timelineSizeHistogramVec
|
||||
poller.gappyStateSizeVec = h.gappyStateSizeVec
|
||||
poller.numOutstandingSyncReqs = h.numOutstandingSyncReqsGauge
|
||||
poller.totalNumPolls = h.totalNumPollsCounter
|
||||
go poller.Poll(v2since)
|
||||
h.Pollers[pid] = poller
|
||||
|
||||
@ -235,11 +297,11 @@ func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
func (h *PollerMap) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
|
||||
func (h *PollerMap) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
h.executor <- func() {
|
||||
h.callbacks.SetTyping(ctx, roomID, ephEvent)
|
||||
h.callbacks.SetTyping(ctx, pollerID, roomID, ephEvent)
|
||||
wg.Done()
|
||||
}
|
||||
wg.Wait()
|
||||
@ -254,11 +316,11 @@ func (h *PollerMap) OnInvite(ctx context.Context, userID, roomID string, inviteS
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string) {
|
||||
func (h *PollerMap) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
h.executor <- func() {
|
||||
h.callbacks.OnLeftRoom(ctx, userID, roomID)
|
||||
h.callbacks.OnLeftRoom(ctx, userID, roomID, leaveEvent)
|
||||
wg.Done()
|
||||
}
|
||||
wg.Wait()
|
||||
@ -270,8 +332,8 @@ func (h *PollerMap) AddToDeviceMessages(ctx context.Context, userID, deviceID st
|
||||
h.callbacks.AddToDeviceMessages(ctx, userID, deviceID, msgs)
|
||||
}
|
||||
|
||||
func (h *PollerMap) OnTerminated(ctx context.Context, userID, deviceID string) {
|
||||
h.callbacks.OnTerminated(ctx, userID, deviceID)
|
||||
func (h *PollerMap) OnTerminated(ctx context.Context, pollerID PollerID) {
|
||||
h.callbacks.OnTerminated(ctx, pollerID)
|
||||
}
|
||||
|
||||
func (h *PollerMap) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
|
||||
@ -335,9 +397,24 @@ type poller struct {
|
||||
terminated *atomic.Bool
|
||||
wg *sync.WaitGroup
|
||||
|
||||
pollHistogramVec *prometheus.HistogramVec
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
timelineSizeVec *prometheus.HistogramVec
|
||||
// stats about poll response data, for logging purposes
|
||||
lastLogged time.Time
|
||||
totalStateCalls int
|
||||
totalTimelineCalls int
|
||||
totalReceipts int
|
||||
totalTyping int
|
||||
totalInvites int
|
||||
totalDeviceEvents int
|
||||
totalAccountData int
|
||||
totalChangedDeviceLists int
|
||||
totalLeftDeviceLists int
|
||||
|
||||
pollHistogramVec *prometheus.HistogramVec
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
timelineSizeVec *prometheus.HistogramVec
|
||||
gappyStateSizeVec *prometheus.HistogramVec
|
||||
numOutstandingSyncReqs prometheus.Gauge
|
||||
totalNumPolls prometheus.Counter
|
||||
}
|
||||
|
||||
func newPoller(pid PollerID, accessToken string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller {
|
||||
@ -366,9 +443,10 @@ func (p *poller) Terminate() {
|
||||
}
|
||||
|
||||
type pollLoopState struct {
|
||||
firstTime bool
|
||||
failCount int
|
||||
since string
|
||||
firstTime bool
|
||||
failCount int
|
||||
since string
|
||||
lastStoredSince time.Time // The time we last stored the since token in the database
|
||||
}
|
||||
|
||||
// Poll will block forever, repeatedly calling v2 sync. Do this in a goroutine.
|
||||
@ -392,16 +470,21 @@ func (p *poller) Poll(since string) {
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
if panicErr != nil {
|
||||
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msg(string(debug.Stack()))
|
||||
logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack())
|
||||
internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr)
|
||||
}
|
||||
p.receiver.OnTerminated(ctx, p.userID, p.deviceID)
|
||||
p.receiver.OnTerminated(ctx, PollerID{
|
||||
UserID: p.userID,
|
||||
DeviceID: p.deviceID,
|
||||
})
|
||||
}()
|
||||
|
||||
state := pollLoopState{
|
||||
firstTime: true,
|
||||
failCount: 0,
|
||||
since: since,
|
||||
// Setting time.Time{} makes the first poll loop immediately store the since token.
|
||||
lastStoredSince: time.Time{},
|
||||
}
|
||||
for !p.terminated.Load() {
|
||||
ctx, task := internal.StartTask(ctx, "Poll")
|
||||
@ -411,6 +494,7 @@ func (p *poller) Poll(since string) {
|
||||
break
|
||||
}
|
||||
}
|
||||
p.maybeLogStats(true)
|
||||
// always unblock EnsurePolling else we can end up head-of-line blocking other pollers!
|
||||
if state.firstTime {
|
||||
state.firstTime = false
|
||||
@ -422,6 +506,9 @@ func (p *poller) Poll(since string) {
|
||||
// s (which is assumed to be non-nil). Returns a non-nil error iff the poller loop
|
||||
// should halt.
|
||||
func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
|
||||
if p.totalNumPolls != nil {
|
||||
p.totalNumPolls.Inc()
|
||||
}
|
||||
if s.failCount > 0 {
|
||||
// don't backoff when doing v2 syncs because the response is only in the cache for a short
|
||||
// period of time (on massive accounts on matrix.org) such that if you wait 2,4,8min between
|
||||
@ -435,15 +522,22 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
|
||||
}
|
||||
start := time.Now()
|
||||
spanCtx, region := internal.StartSpan(ctx, "DoSyncV2")
|
||||
if p.numOutstandingSyncReqs != nil {
|
||||
p.numOutstandingSyncReqs.Inc()
|
||||
}
|
||||
resp, statusCode, err := p.client.DoSyncV2(spanCtx, p.accessToken, s.since, s.firstTime, p.initialToDeviceOnly)
|
||||
if p.numOutstandingSyncReqs != nil {
|
||||
p.numOutstandingSyncReqs.Dec()
|
||||
}
|
||||
region.End()
|
||||
p.trackRequestDuration(time.Since(start), s.since == "", s.firstTime)
|
||||
p.trackRequestDuration(timeSince(start), s.since == "", s.firstTime)
|
||||
if p.terminated.Load() {
|
||||
return fmt.Errorf("poller terminated")
|
||||
}
|
||||
if err != nil {
|
||||
// check if temporary
|
||||
if statusCode != 401 {
|
||||
isFatal := statusCode == 401 || statusCode == 403
|
||||
if !isFatal {
|
||||
p.logger.Warn().Int("code", statusCode).Err(err).Msg("Poller: sync v2 poll returned temporary error")
|
||||
s.failCount += 1
|
||||
return nil
|
||||
@ -472,14 +566,19 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error {
|
||||
wasFirst := s.firstTime
|
||||
|
||||
s.since = resp.NextBatch
|
||||
// persist the since token (TODO: this could get slow if we hammer the DB too much)
|
||||
p.receiver.UpdateDeviceSince(ctx, p.userID, p.deviceID, s.since)
|
||||
// Persist the since token if more than one minute has passed since we
|
||||
// last stored it, OR if the response contains to-device messages.
|
||||
if timeSince(s.lastStoredSince) > time.Minute || len(resp.ToDevice.Events) > 0 {
|
||||
p.receiver.UpdateDeviceSince(ctx, p.userID, p.deviceID, s.since)
|
||||
s.lastStoredSince = time.Now()
|
||||
}
|
||||
|
||||
if s.firstTime {
|
||||
s.firstTime = false
|
||||
p.wg.Done()
|
||||
}
|
||||
p.trackProcessDuration(time.Since(start), wasInitial, wasFirst)
|
||||
p.trackProcessDuration(timeSince(start), wasInitial, wasFirst)
|
||||
p.maybeLogStats(false)
|
||||
return nil
|
||||
}
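The net effect of the change above is that UpdateDeviceSince, and therefore a database write, only happens when the since token is worth keeping: at least a minute has passed since the last store, or the response carried to-device messages that must not be replayed. That decision reads as a small predicate; a sketch of it under exactly those two rules (the helper name is invented).

package main

import (
	"fmt"
	"time"
)

// shouldPersistSince mirrors the rule used in poll(): store the token if the
// last store is over a minute old OR the response contained to-device events.
func shouldPersistSince(lastStored time.Time, numToDeviceEvents int) bool {
	return time.Since(lastStored) > time.Minute || numToDeviceEvents > 0
}

func main() {
	fmt.Println(shouldPersistSince(time.Time{}, 0))                    // true: never stored before
	fmt.Println(shouldPersistSince(time.Now(), 0))                     // false: stored recently, nothing to protect
	fmt.Println(shouldPersistSince(time.Now(), 3))                     // true: to-device messages must not be replayed
	fmt.Println(shouldPersistSince(time.Now().Add(-2*time.Minute), 0)) // true: over a minute old
}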
|
||||
|
||||
@ -518,6 +617,7 @@ func (p *poller) parseToDeviceMessages(ctx context.Context, res *SyncResponse) {
|
||||
if len(res.ToDevice.Events) == 0 {
|
||||
return
|
||||
}
|
||||
p.totalDeviceEvents += len(res.ToDevice.Events)
|
||||
p.receiver.AddToDeviceMessages(ctx, p.userID, p.deviceID, res.ToDevice.Events)
|
||||
}
|
||||
|
||||
@ -556,6 +656,8 @@ func (p *poller) parseE2EEData(ctx context.Context, res *SyncResponse) {
|
||||
deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
|
||||
|
||||
if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil {
|
||||
p.totalChangedDeviceLists += len(res.DeviceLists.Changed)
|
||||
p.totalLeftDeviceLists += len(res.DeviceLists.Left)
|
||||
p.receiver.OnE2EEData(ctx, p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges)
|
||||
}
|
||||
}
|
||||
@ -566,6 +668,7 @@ func (p *poller) parseGlobalAccountData(ctx context.Context, res *SyncResponse)
|
||||
if len(res.AccountData.Events) == 0 {
|
||||
return
|
||||
}
|
||||
p.totalAccountData += len(res.AccountData.Events)
|
||||
p.receiver.OnAccountData(ctx, p.userID, AccountDataGlobalRoom, res.AccountData.Events)
|
||||
}
|
||||
|
||||
@ -596,6 +699,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
|
||||
})
|
||||
hub.CaptureMessage(warnMsg)
|
||||
})
|
||||
p.trackGappyStateSize(len(prependStateEvents))
|
||||
roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...)
|
||||
}
|
||||
}
|
||||
@ -605,7 +709,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
|
||||
switch ephEventType {
|
||||
case "m.typing":
|
||||
typingCalls++
|
||||
p.receiver.SetTyping(ctx, roomID, ephEvent)
|
||||
p.receiver.SetTyping(ctx, PollerID{UserID: p.userID, DeviceID: p.deviceID}, roomID, ephEvent)
|
||||
case "m.receipt":
|
||||
receiptCalls++
|
||||
p.receiver.OnReceipt(ctx, p.userID, roomID, ephEventType, ephEvent)
|
||||
@ -630,27 +734,60 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) {
|
||||
}
|
||||
}
|
||||
for roomID, roomData := range res.Rooms.Leave {
|
||||
// TODO: do we care about state?
|
||||
if len(roomData.Timeline.Events) > 0 {
|
||||
p.trackTimelineSize(len(roomData.Timeline.Events), roomData.Timeline.Limited)
|
||||
p.receiver.Accumulate(ctx, p.userID, p.deviceID, roomID, roomData.Timeline.PrevBatch, roomData.Timeline.Events)
|
||||
}
|
||||
p.receiver.OnLeftRoom(ctx, p.userID, roomID)
|
||||
// Pass the leave event directly to OnLeftRoom. We need to do this _in addition_ to calling Accumulate to handle
|
||||
// the case where a user rejects an invite (there will be no room state, but the user still expects to see the leave event).
|
||||
var leaveEvent json.RawMessage
|
||||
for _, ev := range roomData.Timeline.Events {
|
||||
leaveEv := gjson.ParseBytes(ev)
|
||||
if leaveEv.Get("content.membership").Str == "leave" && leaveEv.Get("state_key").Str == p.userID {
|
||||
leaveEvent = ev
|
||||
break
|
||||
}
|
||||
}
|
||||
if leaveEvent != nil {
|
||||
p.receiver.OnLeftRoom(ctx, p.userID, roomID, leaveEvent)
|
||||
}
|
||||
}
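The loop above scans the leave-section timeline for the member event in which the syncing user themselves left, using gjson path lookups on the raw JSON. A tiny standalone example of those lookups (the event JSON is made up).

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

func main() {
	userID := "@alice:localhost"
	ev := []byte(`{
		"type": "m.room.member",
		"state_key": "@alice:localhost",
		"sender": "@alice:localhost",
		"content": {"membership": "leave"}
	}`)
	parsed := gjson.ParseBytes(ev)
	isOwnLeave := parsed.Get("content.membership").Str == "leave" &&
		parsed.Get("state_key").Str == userID
	fmt.Println(isOwnLeave) // true: this is the event OnLeftRoom would receive
}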
|
||||
for roomID, roomData := range res.Rooms.Invite {
|
||||
p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events)
|
||||
}
|
||||
var l *zerolog.Event
|
||||
if len(res.Rooms.Invite) > 0 || len(res.Rooms.Join) > 0 {
|
||||
l = p.logger.Info()
|
||||
} else {
|
||||
l = p.logger.Debug()
|
||||
|
||||
p.totalReceipts += receiptCalls
|
||||
p.totalStateCalls += stateCalls
|
||||
p.totalTimelineCalls += timelineCalls
|
||||
p.totalTyping += typingCalls
|
||||
p.totalInvites += len(res.Rooms.Invite)
|
||||
}
|
||||
|
||||
func (p *poller) maybeLogStats(force bool) {
|
||||
if !force && timeSince(p.lastLogged) < logInterval {
|
||||
// only log at most once every logInterval
|
||||
return
|
||||
}
|
||||
l.Ints(
|
||||
"rooms [invite,join,leave]", []int{len(res.Rooms.Invite), len(res.Rooms.Join), len(res.Rooms.Leave)},
|
||||
p.lastLogged = time.Now()
|
||||
p.logger.Info().Ints(
|
||||
"rooms [timeline,state,typing,receipts,invites]", []int{
|
||||
p.totalTimelineCalls, p.totalStateCalls, p.totalTyping, p.totalReceipts, p.totalInvites,
|
||||
},
|
||||
).Ints(
|
||||
"storage [states,timelines,typing,receipts]", []int{stateCalls, timelineCalls, typingCalls, receiptCalls},
|
||||
).Int("to_device", len(res.ToDevice.Events)).Msg("Poller: accumulated data")
|
||||
"device [events,changed,left,account]", []int{
|
||||
p.totalDeviceEvents, p.totalChangedDeviceLists, p.totalLeftDeviceLists, p.totalAccountData,
|
||||
},
|
||||
).Msg("Poller: accumulated data")
|
||||
|
||||
p.totalAccountData = 0
|
||||
p.totalChangedDeviceLists = 0
|
||||
p.totalDeviceEvents = 0
|
||||
p.totalInvites = 0
|
||||
p.totalLeftDeviceLists = 0
|
||||
p.totalReceipts = 0
|
||||
p.totalStateCalls = 0
|
||||
p.totalTimelineCalls = 0
|
||||
p.totalTyping = 0
|
||||
}
|
||||
|
||||
func (p *poller) trackTimelineSize(size int, limited bool) {
|
||||
@ -663,3 +800,10 @@ func (p *poller) trackTimelineSize(size int, limited bool) {
|
||||
}
|
||||
p.timelineSizeVec.WithLabelValues(label).Observe(float64(size))
|
||||
}
|
||||
|
||||
func (p *poller) trackGappyStateSize(size int) {
|
||||
if p.gappyStateSizeVec == nil {
|
||||
return
|
||||
}
|
||||
p.gappyStateSizeVec.WithLabelValues().Observe(float64(size))
|
||||
}
|
||||
|
@ -277,6 +277,9 @@ func TestPollerPollFromExisting(t *testing.T) {
|
||||
json.RawMessage(`{"event":10}`),
|
||||
},
|
||||
}
|
||||
toDeviceResponses := [][]json.RawMessage{
|
||||
{}, {}, {}, {json.RawMessage(`{}`)},
|
||||
}
|
||||
hasPolledSuccessfully := make(chan struct{})
|
||||
accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
|
||||
if since == "" {
|
||||
@ -295,6 +298,10 @@ func TestPollerPollFromExisting(t *testing.T) {
|
||||
var joinResp SyncV2JoinResponse
|
||||
joinResp.Timeline.Events = roomTimelineResponses[sinceInt]
|
||||
return &SyncResponse{
|
||||
// Add in dummy to-device messages so the poller actually persists the since token. (It only
|
||||
// does so on the first poll, when at least 1 minute has passed since the last store (longer than
|
||||
// this test runs), or when the response contains to-device messages.)
|
||||
ToDevice: EventsResponse{Events: toDeviceResponses[sinceInt]},
|
||||
NextBatch: fmt.Sprintf("%d", sinceInt+1),
|
||||
Rooms: struct {
|
||||
Join map[string]SyncV2JoinResponse `json:"join"`
|
||||
@ -336,6 +343,121 @@ func TestPollerPollFromExisting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the since token in the database
|
||||
// 1. is updated if it is the first iteration of poll
|
||||
// 2. is NOT updated for ordinary responses inside the 1-minute window
|
||||
// 3. is updated if the syncV2 response contains ToDevice messages
|
||||
// 4. is updated if at least 1min has passed since we last stored a token
|
||||
func TestPollerPollUpdateDeviceSincePeriodically(t *testing.T) {
|
||||
pid := PollerID{UserID: "@alice:localhost", DeviceID: "FOOBAR"}
|
||||
|
||||
syncResponses := make(chan *SyncResponse, 1)
|
||||
syncCalledWithSince := make(chan string)
|
||||
accumulator, client := newMocks(func(authHeader, since string) (*SyncResponse, int, error) {
|
||||
if since != "" {
|
||||
syncCalledWithSince <- since
|
||||
}
|
||||
return <-syncResponses, 200, nil
|
||||
})
|
||||
accumulator.updateSinceCalled = make(chan struct{}, 1)
|
||||
poller := newPoller(pid, "Authorization: hello world", client, accumulator, zerolog.New(os.Stderr), false)
|
||||
defer poller.Terminate()
|
||||
go func() {
|
||||
poller.Poll("0")
|
||||
}()
|
||||
|
||||
hasPolledSuccessfully := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
poller.WaitUntilInitialSync()
|
||||
close(hasPolledSuccessfully)
|
||||
}()
|
||||
|
||||
// 1. Initial poll updates the database
|
||||
next := "1"
|
||||
syncResponses <- &SyncResponse{NextBatch: next}
|
||||
mustEqualSince(t, <-syncCalledWithSince, "0")
|
||||
|
||||
select {
|
||||
case <-hasPolledSuccessfully:
|
||||
break
|
||||
case <-time.After(time.Second):
|
||||
t.Errorf("WaitUntilInitialSync failed to fire")
|
||||
}
|
||||
// Also check that UpdateDeviceSince was called
|
||||
select {
|
||||
case <-accumulator.updateSinceCalled:
|
||||
case <-time.After(time.Millisecond * 100): // give the Poller some time to process the response
|
||||
t.Fatalf("did not receive call to UpdateDeviceSince in time")
|
||||
}
|
||||
|
||||
if got := accumulator.pollerIDToSince[pid]; got != next {
|
||||
t.Fatalf("expected since to be updated to %s, but got %s", next, got)
|
||||
}
|
||||
|
||||
// The since token used by calls to doSyncV2
|
||||
wantSinceFromSync := next
|
||||
|
||||
// 2. Second request updates the state but NOT the database
|
||||
syncResponses <- &SyncResponse{NextBatch: "2"}
|
||||
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
|
||||
|
||||
select {
|
||||
case <-accumulator.updateSinceCalled:
|
||||
t.Fatalf("unexpected call to UpdateDeviceSince")
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
}
|
||||
|
||||
if got := accumulator.pollerIDToSince[pid]; got != next {
|
||||
t.Fatalf("expected since to be updated to %s, but got %s", next, got)
|
||||
}
|
||||
|
||||
// 3. Sync response contains a toDevice message and should be stored in the database
|
||||
wantSinceFromSync = "2"
|
||||
next = "3"
|
||||
syncResponses <- &SyncResponse{
|
||||
NextBatch: next,
|
||||
ToDevice: EventsResponse{Events: []json.RawMessage{{}}},
|
||||
}
|
||||
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
|
||||
|
||||
select {
|
||||
case <-accumulator.updateSinceCalled:
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
t.Fatalf("did not receive call to UpdateDeviceSince in time")
|
||||
}
|
||||
|
||||
if got := accumulator.pollerIDToSince[pid]; got != next {
|
||||
t.Fatalf("expected since to be updated to %s, but got %s", wantSinceFromSync, got)
|
||||
}
|
||||
wantSinceFromSync = next
|
||||
|
||||
// 4. ... some time has passed, this triggers the 1min limit
|
||||
timeSince = func(d time.Time) time.Duration {
|
||||
return time.Minute * 2
|
||||
}
|
||||
next = "10"
|
||||
syncResponses <- &SyncResponse{NextBatch: next}
|
||||
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
|
||||
|
||||
select {
|
||||
case <-accumulator.updateSinceCalled:
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
t.Fatalf("did not receive call to UpdateDeviceSince in time")
|
||||
}
|
||||
|
||||
if got := accumulator.pollerIDToSince[pid]; got != next {
|
||||
t.Fatalf("expected since to be updated to %s, but got %s", wantSinceFromSync, got)
|
||||
}
|
||||
}
|
||||
|
||||
func mustEqualSince(t *testing.T, gotSince, expectedSince string) {
|
||||
t.Helper()
|
||||
if gotSince != expectedSince {
|
||||
t.Fatalf("client.DoSyncV2 using unexpected since token: %s, want %s", gotSince, expectedSince)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that the poller backs off in 2,4,8,etc second increments to a variety of errors
|
||||
func TestPollerBackoff(t *testing.T) {
|
||||
deviceID := "FOOBAR"
|
||||
@ -453,11 +575,12 @@ func (c *mockClient) WhoAmI(authHeader string) (string, string, error) {
|
||||
}
|
||||
|
||||
type mockDataReceiver struct {
|
||||
states map[string][]json.RawMessage
|
||||
timelines map[string][]json.RawMessage
|
||||
pollerIDToSince map[PollerID]string
|
||||
incomingProcess chan struct{}
|
||||
unblockProcess chan struct{}
|
||||
states map[string][]json.RawMessage
|
||||
timelines map[string][]json.RawMessage
|
||||
pollerIDToSince map[PollerID]string
|
||||
incomingProcess chan struct{}
|
||||
unblockProcess chan struct{}
|
||||
updateSinceCalled chan struct{}
|
||||
}
|
||||
|
||||
func (a *mockDataReceiver) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) {
|
||||
@ -475,10 +598,13 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state
|
||||
// timeline. Untested here---return nil for now.
|
||||
return nil
|
||||
}
|
||||
func (a *mockDataReceiver) SetTyping(ctx context.Context, roomID string, ephEvent json.RawMessage) {
|
||||
func (a *mockDataReceiver) SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) {
|
||||
}
|
||||
func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) {
|
||||
s.pollerIDToSince[PollerID{UserID: userID, DeviceID: deviceID}] = since
|
||||
if s.updateSinceCalled != nil {
|
||||
s.updateSinceCalled <- struct{}{}
|
||||
}
|
||||
}
|
||||
func (s *mockDataReceiver) AddToDeviceMessages(ctx context.Context, userID, deviceID string, msgs []json.RawMessage) {
|
||||
}
|
||||
@ -491,10 +617,11 @@ func (s *mockDataReceiver) OnReceipt(ctx context.Context, userID, roomID, ephEve
|
||||
}
|
||||
func (s *mockDataReceiver) OnInvite(ctx context.Context, userID, roomID string, inviteState []json.RawMessage) {
|
||||
}
|
||||
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string) {}
|
||||
func (s *mockDataReceiver) OnLeftRoom(ctx context.Context, userID, roomID string, leaveEvent json.RawMessage) {
|
||||
}
|
||||
func (s *mockDataReceiver) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) {
|
||||
}
|
||||
func (s *mockDataReceiver) OnTerminated(ctx context.Context, userID, deviceID string) {}
|
||||
func (s *mockDataReceiver) OnTerminated(ctx context.Context, pollerID PollerID) {}
|
||||
func (s *mockDataReceiver) OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) {
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package sync2
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/jmoiron/sqlx"
|
||||
@ -27,9 +26,10 @@ func NewStore(postgresURI, secret string) *Storage {
|
||||
// TODO: if we panic(), will sentry have a chance to flush the event?
|
||||
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
|
||||
}
|
||||
db.SetMaxOpenConns(100)
|
||||
db.SetMaxIdleConns(80)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
return NewStoreWithDB(db, secret)
|
||||
}
|
||||
|
||||
func NewStoreWithDB(db *sqlx.DB, secret string) *Storage {
|
||||
return &Storage{
|
||||
DevicesTable: NewDevicesTable(db),
|
||||
TokensTable: NewTokensTable(db, secret),
|
||||
|
@ -171,10 +171,10 @@ func (t *TokensTable) TokenForEachDevice(txn *sqlx.Tx) (tokens []TokenForPoller,
|
||||
}
|
||||
|
||||
// Insert a new token into the table.
|
||||
func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
|
||||
func (t *TokensTable) Insert(txn *sqlx.Tx, plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) {
|
||||
hashedToken := hashToken(plaintextToken)
|
||||
encToken := t.encrypt(plaintextToken)
|
||||
_, err := t.db.Exec(
|
||||
_, err := txn.Exec(
|
||||
`INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (token_hash) DO NOTHING;`,
|
||||
|
@ -1,6 +1,8 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/matrix-org/sliding-sync/sqlutil"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -26,27 +28,31 @@ func TestTokensTable(t *testing.T) {
|
||||
aliceSecret1 := "mysecret1"
|
||||
aliceToken1FirstSeen := time.Now()
|
||||
|
||||
// Test a single token
|
||||
t.Log("Insert a new token from Alice.")
|
||||
aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
var aliceToken, reinsertedToken *Token
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
// Test a single token
|
||||
t.Log("Insert a new token from Alice.")
|
||||
aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
|
||||
t.Log("Reinsert the same token.")
|
||||
reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
t.Log("Reinsert the same token.")
|
||||
reinsertedToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("This should yield an equal Token struct.")
|
||||
assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen)
|
||||
|
||||
t.Log("Try to mark Alice's token as being used after an hour.")
|
||||
err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
|
||||
err := tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update last seen: %s", err)
|
||||
}
|
||||
@ -74,17 +80,20 @@ func TestTokensTable(t *testing.T) {
|
||||
}
|
||||
assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen)
|
||||
|
||||
// Test a second token for Alice
|
||||
t.Log("Insert a second token for Alice.")
|
||||
aliceSecret2 := "mysecret2"
|
||||
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
|
||||
aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
_ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
|
||||
// Test a second token for Alice
|
||||
t.Log("Insert a second token for Alice.")
|
||||
aliceSecret2 := "mysecret2"
|
||||
aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute)
|
||||
aliceToken2, err := tokens.Insert(txn, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
t.Log("The returned Token struct should have been populated correctly.")
|
||||
assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeletingTokens(t *testing.T) {
|
||||
@ -94,11 +103,15 @@ func TestDeletingTokens(t *testing.T) {
|
||||
|
||||
t.Log("Insert a new token from Alice.")
|
||||
accessToken := "mytoken"
|
||||
token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
|
||||
var token *Token
|
||||
err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) {
|
||||
token, err = tokens.Insert(txn, accessToken, "@bob:builders.com", "device", time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to Insert token: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
t.Log("We should be able to fetch this token without error.")
|
||||
_, err = tokens.Token(accessToken)
|
||||
if err != nil {
|
||||
|
117
sync2/txnid.go
@ -1,38 +1,121 @@
|
||||
package sync2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ReneKroon/ttlcache/v2"
|
||||
)
|
||||
|
||||
type TransactionIDCache struct {
|
||||
cache *ttlcache.Cache
|
||||
type loaderFunc func(userID string) (deviceIDs []string)
|
||||
|
||||
// PendingTransactionIDs is (conceptually) a map from event IDs to a list of device IDs.
|
||||
// Its keys are the IDs of events we've seen which a) lack a transaction ID, and b) were
|
||||
// sent by one of the users we are polling for. The values are the list of the sender's
|
||||
// devices whose pollers are yet to see a transaction ID.
|
||||
//
|
||||
// If another poller sees the same event
|
||||
//
|
||||
// - with a transaction ID, it emits a V2TransactionID payload with that ID and
|
||||
// removes the event ID from this map.
|
||||
//
|
||||
// - without a transaction ID, it removes the polling device ID from the values
|
||||
// list. If the device ID list is now empty, the poller emits an "all clear"
|
||||
// V2TransactionID payload.
|
||||
//
|
||||
// This is a best-effort affair to ensure that the rest of the proxy can wait for
|
||||
// transaction IDs to appear before transmitting an event down /sync to its sender.
|
||||
//
|
||||
// It's possible that we add an entry to this map and then the list of remaining
|
||||
// device IDs becomes out of date, either due to a new device creation or an
|
||||
// existing device expiring. We choose not to handle this case, because it is relatively
|
||||
// rare.
|
||||
//
|
||||
// To avoid the map growing without bound, we use a ttlcache and drop entries
|
||||
// after a short period of time.
|
||||
type PendingTransactionIDs struct {
|
||||
// mu guards the pending field. See MissingTxnID for rationale.
|
||||
mu sync.Mutex
|
||||
pending *ttlcache.Cache
|
||||
// loader should provide the list of device IDs
|
||||
loader loaderFunc
|
||||
}
|
||||
|
||||
func NewTransactionIDCache() *TransactionIDCache {
|
||||
func NewPendingTransactionIDs(loader loaderFunc) *PendingTransactionIDs {
|
||||
c := ttlcache.NewCache()
|
||||
c.SetTTL(5 * time.Minute) // keep transaction IDs for 5 minutes before forgetting about them
|
||||
c.SkipTTLExtensionOnHit(true) // we don't care how many times they ask for the item, 5min is the limit.
|
||||
return &TransactionIDCache{
|
||||
cache: c,
|
||||
return &PendingTransactionIDs{
|
||||
mu: sync.Mutex{},
|
||||
pending: c,
|
||||
loader: loader,
|
||||
}
|
||||
}
|
||||
|
||||
// Store a new transaction ID received via v2 /sync
|
||||
func (c *TransactionIDCache) Store(userID, eventID, txnID string) {
|
||||
c.cache.Set(cacheKey(userID, eventID), txnID)
|
||||
}
|
||||
// MissingTxnID should be called to report that this device ID did not see a
|
||||
// transaction ID for this event ID. Returns true if this is the first time we know
|
||||
// for sure that we'll never see a txn ID for this event.
|
||||
func (c *PendingTransactionIDs) MissingTxnID(eventID, userID, myDeviceID string) (bool, error) {
|
||||
// While ttlcache is threadsafe, it does not provide a way to atomically update
|
||||
// (get+set) a value, which means we are still open to races. For example:
|
||||
//
|
||||
// - We have three pollers A, B, C.
|
||||
// - Poller A sees an event without txn id and calls MissingTxnID.
|
||||
// - `c.pending.Get()` fails, so we load up all device IDs: [A, B, C].
|
||||
// - Then `c.pending.Set()` with [B, C].
|
||||
// - Poller B sees the same event, also missing txn ID and calls MissingTxnID.
|
||||
// - Poller C does the same concurrently.
|
||||
//
|
||||
// If the Get+Set isn't atomic, then we might do e.g.
|
||||
// - B gets [B, C] and prepares to write [C].
|
||||
// - C gets [B, C] and prepares to write [B].
|
||||
// - Last writer wins. Either way, we never write [] and so never return true
|
||||
// (the all-clear signal.)
|
||||
//
|
||||
// This wouldn't be the end of the world (the API process has a maximum delay, and
|
||||
// the ttlcache will expire the entry), but it would still be nice to avoid it.
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Get a transaction ID previously stored.
|
||||
func (c *TransactionIDCache) Get(userID, eventID string) string {
|
||||
val, _ := c.cache.Get(cacheKey(userID, eventID))
|
||||
if val != nil {
|
||||
return val.(string)
|
||||
data, err := c.pending.Get(eventID)
|
||||
if err == ttlcache.ErrNotFound {
|
||||
data = c.loader(userID)
|
||||
} else if err != nil {
|
||||
return false, fmt.Errorf("PendingTransactionIDs: failed to get device ids: %w", err)
|
||||
}
|
||||
return ""
|
||||
|
||||
deviceIDs, ok := data.([]string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("PendingTransactionIDs: failed to cast device IDs")
|
||||
}
|
||||
|
||||
deviceIDs, changed := removeDevice(myDeviceID, deviceIDs)
|
||||
if changed {
|
||||
err = c.pending.Set(eventID, deviceIDs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("PendingTransactionIDs: failed to set device IDs: %w", err)
|
||||
}
|
||||
}
|
||||
return changed && len(deviceIDs) == 0, nil
|
||||
}
|
||||
|
||||
func cacheKey(userID, eventID string) string {
|
||||
return userID + " " + eventID
|
||||
// SeenTxnID should be called to report that this device saw a transaction ID
|
||||
// for this event.
|
||||
func (c *PendingTransactionIDs) SeenTxnID(eventID string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.pending.Set(eventID, []string{})
|
||||
}
|
||||
|
||||
// removeDevice takes a device ID slice and returns a device ID slice with one
|
||||
// particular string removed. Assumes that the given slice has no duplicates.
|
||||
// Note that it may reuse (and so overwrite) the given slice's backing array.
|
||||
func removeDevice(device string, devices []string) ([]string, bool) {
|
||||
for i, otherDevice := range devices {
|
||||
if otherDevice == device {
|
||||
return append(devices[:i], devices[i+1:]...), true
|
||||
}
|
||||
}
|
||||
return devices, false
|
||||
}
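For orientation, here is a rough sketch of how a poller might drive PendingTransactionIDs when it sees an event sent by the user it polls for. The emit callbacks stand in for publishing V2TransactionID payloads; they and the function itself are illustrative assumptions, not the proxy's actual poller code.

// Illustrative only: decide whether to emit a txn ID or an all-clear signal.
func handleOwnEvent(pending *PendingTransactionIDs, eventID, userID, deviceID, txnID string,
	emitTxnID func(eventID, txnID string), emitAllClear func(eventID string)) error {
	if txnID != "" {
		// This poller saw the transaction ID: publish it and stop tracking the event.
		emitTxnID(eventID, txnID)
		return pending.SeenTxnID(eventID)
	}
	// No transaction ID on this poller: record that. If every polled device for
	// this user has now reported in without one, send the all-clear.
	allClear, err := pending.MissingTxnID(eventID, userID, deviceID)
	if err != nil {
		return err
	}
	if allClear {
		emitAllClear(eventID)
	}
	return nil
}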
|
||||
|
@ -2,54 +2,107 @@ package sync2
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTransactionIDCache(t *testing.T) {
|
||||
alice := "@alice:localhost"
|
||||
bob := "@bob:localhost"
|
||||
eventA := "$a:localhost"
|
||||
eventB := "$b:localhost"
|
||||
eventC := "$c:localhost"
|
||||
txn1 := "1"
|
||||
txn2 := "2"
|
||||
cache := NewTransactionIDCache()
|
||||
cache.Store(alice, eventA, txn1)
|
||||
cache.Store(bob, eventB, txn1) // different users can use same txn ID
|
||||
cache.Store(alice, eventC, txn2)
|
||||
|
||||
testCases := []struct {
|
||||
eventID string
|
||||
userID string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
eventID: eventA,
|
||||
userID: alice,
|
||||
want: txn1,
|
||||
},
|
||||
{
|
||||
eventID: eventB,
|
||||
userID: bob,
|
||||
want: txn1,
|
||||
},
|
||||
{
|
||||
eventID: eventC,
|
||||
userID: alice,
|
||||
want: txn2,
|
||||
},
|
||||
{
|
||||
eventID: "$invalid",
|
||||
userID: alice,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
eventID: eventA,
|
||||
userID: "@invalid",
|
||||
want: "",
|
||||
},
|
||||
func TestPendingTransactionIDs(t *testing.T) {
|
||||
pollingDevicesByUser := map[string][]string{
|
||||
"alice": {"A1", "A2"},
|
||||
"bob": {"B1"},
|
||||
"chris": {},
|
||||
"delia": {"D1", "D2", "D3", "D4"},
|
||||
"enid": {"E1", "E2"},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
txnID := cache.Get(tc.userID, tc.eventID)
|
||||
if txnID != tc.want {
|
||||
t.Errorf("%+v: got %v want %v", tc, txnID, tc.want)
|
||||
mockLoad := func(userID string) (deviceIDs []string) {
|
||||
devices, ok := pollingDevicesByUser[userID]
|
||||
if !ok {
|
||||
t.Fatalf("Mock didn't have devices for %s", userID)
|
||||
}
|
||||
newDevices := make([]string, len(devices))
|
||||
copy(newDevices, devices)
|
||||
return newDevices
|
||||
}
|
||||
|
||||
pending := NewPendingTransactionIDs(mockLoad)
|
||||
|
||||
// Alice.
|
||||
// We're tracking two of Alice's devices.
|
||||
allClear, err := pending.MissingTxnID("event1", "alice", "A1")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false) // waiting on A2
|
||||
|
||||
// If for some reason the poller sees the same event for the same device, we should
|
||||
// still be waiting for A2.
|
||||
allClear, err = pending.MissingTxnID("event1", "alice", "A1")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
|
||||
// If for some reason Alice spun up a new device, we are still going to be waiting
|
||||
// for A2.
|
||||
allClear, err = pending.MissingTxnID("event1", "alice", "A_unknown_device")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
|
||||
// If A2 sees the event without a txnID, we should emit the all clear signal.
|
||||
allClear, err = pending.MissingTxnID("event1", "alice", "A2")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, true)
|
||||
|
||||
// If for some reason A2 sees the event a second time, we shouldn't re-emit the
|
||||
// all clear signal.
|
||||
allClear, err = pending.MissingTxnID("event1", "alice", "A2")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
|
||||
// Bob.
|
||||
// We're only tracking one device for Bob
|
||||
allClear, err = pending.MissingTxnID("event2", "bob", "B1")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, true) // not waiting on any devices
|
||||
|
||||
// Chris.
|
||||
// We're not tracking any devices for Chris. A MissingTxnID call for him shouldn't
|
||||
// cause anything to explode.
|
||||
allClear, err = pending.MissingTxnID("event3", "chris", "C_unknown_device")
|
||||
assertNoError(t, err)
|
||||
|
||||
// Delia.
|
||||
// Delia is tracking four devices.
|
||||
allClear, err = pending.MissingTxnID("event4", "delia", "D1")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false) // waiting on D2, D3 and D4
|
||||
|
||||
// One of Delia's devices, say D2, sees a txn ID for event 4.
|
||||
err = pending.SeenTxnID("event4")
|
||||
assertNoError(t, err)
|
||||
|
||||
// The other devices see the event. Neither should emit all clear.
|
||||
allClear, err = pending.MissingTxnID("event4", "delia", "D3")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
|
||||
allClear, err = pending.MissingTxnID("event4", "delia", "D4")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
|
||||
// Enid.
|
||||
// Enid has two devices. Her first poller (E1) is lucky and sees the transaction ID.
|
||||
err = pending.SeenTxnID("event5")
|
||||
assertNoError(t, err)
|
||||
|
||||
// Her second poller misses the transaction ID, but this shouldn't cause an all clear.
|
||||
allClear, err = pending.MissingTxnID("event4", "delia", "E2")
|
||||
assertNoError(t, err)
|
||||
assertAllClear(t, allClear, false)
|
||||
}
|
||||
|
||||
func assertAllClear(t *testing.T, got bool, want bool) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Errorf("Expected allClear=%t, got %t", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %s", err)
|
||||
}
|
||||
}
|
||||
|
42
sync3/avatar.go
Normal file
@ -0,0 +1,42 @@
package sync3

import (
	"bytes"
	"encoding/json"
)

// An AvatarChange represents a change to a room's avatar. There are three cases:
//   - an empty string represents no change, and should be omitted when JSON-serialised;
//   - the sentinel `<no-avatar>` represents a room that has never had an avatar,
//     or a room whose avatar has been removed. It is JSON-serialised as null.
//   - All other strings represent the current avatar of the room and JSON-serialise as
//     normal.
type AvatarChange string

const DeletedAvatar = AvatarChange("<no-avatar>")
const UnchangedAvatar AvatarChange = ""

// NewAvatarChange interprets an optional avatar string as an AvatarChange.
func NewAvatarChange(avatar string) AvatarChange {
	if avatar == "" {
		return DeletedAvatar
	}
	return AvatarChange(avatar)
}

func (a AvatarChange) MarshalJSON() ([]byte, error) {
	if a == DeletedAvatar {
		return []byte(`null`), nil
	} else {
		return json.Marshal(string(a))
	}
}

// Note: the unmarshalling is only used in tests.
func (a *AvatarChange) UnmarshalJSON(data []byte) error {
	if bytes.Equal(data, []byte("null")) {
		*a = DeletedAvatar
		return nil
	}
	return json.Unmarshal(data, (*string)(a))
}
|
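A small worked example of the three cases above; this is illustrative only and assumes it lives in the sync3 package with "encoding/json" and "fmt" imported (roomStub and its JSON tag are made up for the demonstration):

// Illustrative example: how the three AvatarChange cases serialise.
func ExampleAvatarChange() {
	type roomStub struct {
		AvatarChange AvatarChange `json:"avatar_change,omitempty"`
	}
	unchanged, _ := json.Marshal(roomStub{UnchangedAvatar}) // omitted entirely
	deleted, _ := json.Marshal(DeletedAvatar)               // becomes null
	set, _ := json.Marshal(NewAvatarChange("mxc://example.org/abc"))
	fmt.Println(string(unchanged), string(deleted), string(set))
	// Output: {} null "mxc://example.org/abc"
}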
@ -21,6 +21,15 @@ type EventData struct {
|
||||
Content gjson.Result
|
||||
Timestamp uint64
|
||||
Sender string
|
||||
// TransactionID is the unsigned.transaction_id field in the event as stored in the
|
||||
// syncv3_events table, or the empty string if there is no such field.
|
||||
//
|
||||
// We may see the event on poller A without a transaction_id, and then later on
|
||||
// poller B with a transaction_id. If this happens, we make a temporary note of the
|
||||
// transaction_id in the syncv3_txns table, but do not edit the persisted event.
|
||||
// This means that this field is not authoritative; we only include it here as a
|
||||
// hint to avoid unnecessary waits for V2TransactionID payloads.
|
||||
TransactionID string
|
||||
|
||||
// the number of joined users in this room. Use this value and don't try to work it out as you
|
||||
// may get it wrong due to Synapse sending duplicate join events(!) This value has them de-duped
|
||||
@ -118,13 +127,7 @@ func (c *GlobalCache) copyRoom(roomID string) *internal.RoomMetadata {
|
||||
logger.Warn().Str("room", roomID).Msg("GlobalCache.LoadRoom: no metadata for this room, returning stub")
|
||||
return internal.NewRoomMetadata(roomID)
|
||||
}
|
||||
srCopy := *sr
|
||||
// copy the heroes or else we may modify the same slice which would be bad :(
|
||||
srCopy.Heroes = make([]internal.Hero, len(sr.Heroes))
|
||||
for i := range sr.Heroes {
|
||||
srCopy.Heroes[i] = sr.Heroes[i]
|
||||
}
|
||||
return &srCopy
|
||||
return sr.CopyHeroes()
|
||||
}
|
||||
|
||||
// LoadJoinedRooms loads all current joined room metadata for the user given, together
|
||||
@ -158,7 +161,7 @@ func (c *GlobalCache) LoadJoinedRooms(ctx context.Context, userID string) (
|
||||
i++
|
||||
}
|
||||
|
||||
latestNIDs, err = c.store.EventsTable.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
|
||||
latestNIDs, err = c.store.LatestEventNIDInRooms(roomIDs, initialLoadPosition)
|
||||
if err != nil {
|
||||
return 0, nil, nil, nil, err
|
||||
}
|
||||
@ -239,8 +242,13 @@ func (c *GlobalCache) Startup(roomIDToMetadata map[string]internal.RoomMetadata)
|
||||
sort.Strings(roomIDs)
|
||||
for _, roomID := range roomIDs {
|
||||
metadata := roomIDToMetadata[roomID]
|
||||
internal.Assert("room ID is set", metadata.RoomID != "")
|
||||
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1)
|
||||
debugContext := map[string]interface{}{
|
||||
"room_id": roomID,
|
||||
"metadata.RoomID": metadata.RoomID,
|
||||
"metadata.LastMessageTimeStamp": metadata.LastMessageTimestamp,
|
||||
}
|
||||
internal.Assert("room ID is set", metadata.RoomID != "", debugContext)
|
||||
internal.Assert("last message timestamp exists", metadata.LastMessageTimestamp > 1, debugContext)
|
||||
c.roomIDToMetadata[roomID] = &metadata
|
||||
}
|
||||
return nil
|
||||
@ -285,6 +293,10 @@ func (c *GlobalCache) OnNewEvent(
|
||||
if ed.StateKey != nil && *ed.StateKey == "" {
|
||||
metadata.NameEvent = ed.Content.Get("name").Str
|
||||
}
|
||||
case "m.room.avatar":
|
||||
if ed.StateKey != nil && *ed.StateKey == "" {
|
||||
metadata.AvatarEvent = ed.Content.Get("url").Str
|
||||
}
|
||||
case "m.room.encryption":
|
||||
if ed.StateKey != nil && *ed.StateKey == "" {
|
||||
metadata.Encrypted = true
|
||||
@ -349,14 +361,16 @@ func (c *GlobalCache) OnNewEvent(
|
||||
for i := range metadata.Heroes {
|
||||
if metadata.Heroes[i].ID == *ed.StateKey {
|
||||
metadata.Heroes[i].Name = ed.Content.Get("displayname").Str
|
||||
metadata.Heroes[i].Avatar = ed.Content.Get("avatar_url").Str
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
metadata.Heroes = append(metadata.Heroes, internal.Hero{
|
||||
ID: *ed.StateKey,
|
||||
Name: ed.Content.Get("displayname").Str,
|
||||
ID: *ed.StateKey,
|
||||
Name: ed.Content.Get("displayname").Str,
|
||||
Avatar: ed.Content.Get("avatar_url").Str,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -38,12 +38,12 @@ func TestGlobalCacheLoadState(t *testing.T) {
|
||||
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Room Name"}),
|
||||
testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]interface{}{"name": "The Updated Room Name"}),
|
||||
}
|
||||
_, _, err := store.Accumulate(roomID2, "", eventsRoom2)
|
||||
_, _, err := store.Accumulate(alice, roomID2, "", eventsRoom2)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate: %s", err)
|
||||
}
|
||||
|
||||
_, latestNIDs, err := store.Accumulate(roomID, "", events)
|
||||
_, latestNIDs, err := store.Accumulate(alice, roomID, "", events)
|
||||
if err != nil {
|
||||
t.Fatalf("Accumulate: %s", err)
|
||||
}
|
||||
|
@ -38,15 +38,6 @@ func (u *InviteUpdate) Type() string {
|
||||
return fmt.Sprintf("InviteUpdate[%s]", u.RoomID())
|
||||
}
|
||||
|
||||
// LeftRoomUpdate corresponds to a key-value pair from a v2 sync's `leave` section.
|
||||
type LeftRoomUpdate struct {
|
||||
RoomUpdate
|
||||
}
|
||||
|
||||
func (u *LeftRoomUpdate) Type() string {
|
||||
return fmt.Sprintf("LeftRoomUpdate[%s]", u.RoomID())
|
||||
}
|
||||
|
||||
// TypingEdu corresponds to a typing EDU in the `ephemeral` section of a joined room's v2 sync response.
|
||||
type TypingUpdate struct {
|
||||
RoomUpdate
|
||||
|
@ -46,10 +46,18 @@ type UserRoomData struct {
|
||||
// The zero value of this safe to use (0 latest nid, no prev batch, no timeline).
|
||||
RequestedLatestEvents state.LatestEvents
|
||||
|
||||
// TODO: should Canonicalised really be in RoomConMetadata? It's only set in SetRoom AFAICS
|
||||
// TODO: should CanonicalisedName really be in RoomConMetadata? It's only set in SetRoom AFAICS
|
||||
CanonicalisedName string // stripped leading symbols like #, all in lower case
|
||||
// Set of spaces this room is a part of, from the perspective of this user. This is NOT global room data
|
||||
// as the set of spaces may be different for different users.
|
||||
|
||||
// ResolvedAvatarURL is the avatar that should be displayed to this user to
|
||||
// represent this room. The empty string means that this room has no avatar.
|
||||
// Avatars set in m.room.avatar take precedence; if this is missing and the room is
|
||||
// a DM with one other user joined or invited, we fall back to that user's
|
||||
// avatar (if any) as specified in their membership event in that room.
|
||||
ResolvedAvatarURL string
|
||||
|
||||
Spaces map[string]struct{}
|
||||
// Map of tag to order float.
|
||||
// See https://spec.matrix.org/latest/client-server-api/#room-tagging
|
||||
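The fallback described for ResolvedAvatarURL can be pictured as a small pure function. This is a hypothetical sketch of the stated rule only, not the proxy's actual implementation (which is computed elsewhere, via internal.CalculateAvatar, and may differ in detail):

// Hypothetical sketch: an explicit m.room.avatar wins; otherwise, in a DM with
// exactly one other joined/invited member, fall back to that member's avatar.
func resolveRoomAvatar(roomAvatarEvent string, otherMemberAvatars []string) string {
	if roomAvatarEvent != "" {
		return roomAvatarEvent
	}
	if len(otherMemberAvatars) == 1 {
		return otherMemberAvatars[0]
	}
	return "" // no avatar to show
}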
@ -73,6 +81,7 @@ type InviteData struct {
|
||||
Heroes []internal.Hero
|
||||
InviteEvent *EventData
|
||||
NameEvent string // the content of m.room.name, NOT the calculated name
|
||||
AvatarEvent string // the content of m.room.avatar, NOT the calculated avatar
|
||||
CanonicalAlias string
|
||||
LastMessageTimestamp uint64
|
||||
Encrypted bool
|
||||
@ -108,12 +117,15 @@ func NewInviteData(ctx context.Context, userID, roomID string, inviteState []jso
|
||||
id.IsDM = j.Get("is_direct").Bool()
|
||||
} else if target == j.Get("sender").Str {
|
||||
id.Heroes = append(id.Heroes, internal.Hero{
|
||||
ID: target,
|
||||
Name: j.Get("content.displayname").Str,
|
||||
ID: target,
|
||||
Name: j.Get("content.displayname").Str,
|
||||
Avatar: j.Get("content.avatar_url").Str,
|
||||
})
|
||||
}
|
||||
case "m.room.name":
|
||||
id.NameEvent = j.Get("content.name").Str
|
||||
case "m.room.avatar":
|
||||
id.AvatarEvent = j.Get("content.url").Str
|
||||
case "m.room.canonical_alias":
|
||||
id.CanonicalAlias = j.Get("content.alias").Str
|
||||
case "m.room.encryption":
|
||||
@ -147,6 +159,7 @@ func (i *InviteData) RoomMetadata() *internal.RoomMetadata {
|
||||
metadata := internal.NewRoomMetadata(i.roomID)
|
||||
metadata.Heroes = i.Heroes
|
||||
metadata.NameEvent = i.NameEvent
|
||||
metadata.AvatarEvent = i.AvatarEvent
|
||||
metadata.CanonicalAlias = i.CanonicalAlias
|
||||
metadata.InviteCount = 1
|
||||
metadata.JoinCount = 1
|
||||
@ -178,18 +191,22 @@ type UserCache struct {
|
||||
store *state.Storage
|
||||
globalCache *GlobalCache
|
||||
txnIDs TransactionIDFetcher
|
||||
ignoredUsers map[string]struct{}
|
||||
ignoredUsersMu *sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUserCache(userID string, globalCache *GlobalCache, store *state.Storage, txnIDs TransactionIDFetcher) *UserCache {
|
||||
uc := &UserCache{
|
||||
UserID: userID,
|
||||
roomToDataMu: &sync.RWMutex{},
|
||||
roomToData: make(map[string]UserRoomData),
|
||||
listeners: make(map[int]UserCacheListener),
|
||||
listenersMu: &sync.RWMutex{},
|
||||
store: store,
|
||||
globalCache: globalCache,
|
||||
txnIDs: txnIDs,
|
||||
UserID: userID,
|
||||
roomToDataMu: &sync.RWMutex{},
|
||||
roomToData: make(map[string]UserRoomData),
|
||||
listeners: make(map[int]UserCacheListener),
|
||||
listenersMu: &sync.RWMutex{},
|
||||
store: store,
|
||||
globalCache: globalCache,
|
||||
txnIDs: txnIDs,
|
||||
ignoredUsers: make(map[string]struct{}),
|
||||
ignoredUsersMu: &sync.RWMutex{},
|
||||
}
|
||||
return uc
|
||||
}
|
||||
@ -212,7 +229,7 @@ func (c *UserCache) Unsubscribe(id int) {
|
||||
// OnRegistered is called after the sync3.Dispatcher has successfully registered this
|
||||
// cache to receive updates. We use this to run some final initialisation logic that
|
||||
// is sensitive to race conditions; confusingly, most of the initialisation is driven
|
||||
// externally by sync3.SyncLiveHandler.userCache. It's importatn that we don't spend too
|
||||
// externally by sync3.SyncLiveHandler.userCaches. It's important that we don't spend too
|
||||
// long inside this function, because it is called within a global lock on the
|
||||
// sync3.Dispatcher (see sync3.Dispatcher.Register).
|
||||
func (c *UserCache) OnRegistered(ctx context.Context) error {
|
||||
@ -309,6 +326,7 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID
|
||||
urd = NewUserRoomData()
|
||||
}
|
||||
if latestEvents != nil {
|
||||
latestEvents.DiscardIgnoredMessages(c.ShouldIgnore)
|
||||
urd.RequestedLatestEvents = *latestEvents
|
||||
}
|
||||
result[requestedRoomID] = urd
|
||||
@ -328,7 +346,10 @@ func (c *UserCache) LoadRoomData(roomID string) UserRoomData {
|
||||
}
|
||||
|
||||
type roomUpdateCache struct {
|
||||
roomID string
|
||||
roomID string
|
||||
// globalRoomData is a snapshot of the global metadata for this room immediately
|
||||
// after this update. It is a copy, specific to the given user whose Heroes
|
||||
// field can be freely modified.
|
||||
globalRoomData *internal.RoomMetadata
|
||||
userRoomData *UserRoomData
|
||||
}
|
||||
@ -393,8 +414,16 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin
|
||||
i int
|
||||
})
|
||||
for roomID, events := range roomIDToEvents {
|
||||
for i, ev := range events {
|
||||
evID := gjson.GetBytes(ev, "event_id").Str
|
||||
for i, evJSON := range events {
|
||||
ev := gjson.ParseBytes(evJSON)
|
||||
evID := ev.Get("event_id").Str
|
||||
sender := ev.Get("sender").Str
|
||||
if sender != userID {
|
||||
// don't ask for txn IDs for events which weren't sent by us.
|
||||
// If we do, we'll needlessly hit the database, increasing latencies when
|
||||
// catching up from the live buffer.
|
||||
continue
|
||||
}
|
||||
eventIDs = append(eventIDs, evID)
|
||||
eventIDToEvent[evID] = struct {
|
||||
roomID string
|
||||
@ -405,6 +434,10 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(eventIDs) == 0 {
|
||||
// don't do any work if we have no events
|
||||
return roomIDToEvents
|
||||
}
|
||||
eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs)
|
||||
for eventID, txnID := range eventIDToTxnID {
|
||||
data, ok := eventIDToEvent[eventID]
|
||||
@ -578,7 +611,7 @@ func (c *UserCache) OnInvite(ctx context.Context, roomID string, inviteStateEven
|
||||
c.emitOnRoomUpdate(ctx, up)
|
||||
}
|
||||
|
||||
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
|
||||
func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string, leaveEvent json.RawMessage) {
|
||||
urd := c.LoadRoomData(roomID)
|
||||
urd.IsInvite = false
|
||||
urd.HasLeft = true
|
||||
@ -588,7 +621,10 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
|
||||
c.roomToData[roomID] = urd
|
||||
c.roomToDataMu.Unlock()
|
||||
|
||||
up := &LeftRoomUpdate{
|
||||
ev := gjson.ParseBytes(leaveEvent)
|
||||
stateKey := ev.Get("state_key").Str
|
||||
|
||||
up := &RoomEventUpdate{
|
||||
RoomUpdate: &roomUpdateCache{
|
||||
roomID: roomID,
|
||||
// do NOT pull from the global cache as it is a snapshot of the room at the point of
|
||||
@ -596,6 +632,18 @@ func (c *UserCache) OnLeftRoom(ctx context.Context, roomID string) {
|
||||
globalRoomData: internal.NewRoomMetadata(roomID),
|
||||
userRoomData: &urd,
|
||||
},
|
||||
EventData: &EventData{
|
||||
Event: leaveEvent,
|
||||
RoomID: roomID,
|
||||
EventType: ev.Get("type").Str,
|
||||
StateKey: &stateKey,
|
||||
Content: ev.Get("content"),
|
||||
Timestamp: ev.Get("origin_server_ts").Uint(),
|
||||
Sender: ev.Get("sender").Str,
|
||||
// if this is an invite rejection we need to make sure we tell the client, and not
|
||||
// skip it because of the lack of a NID (this event may not be in the events table)
|
||||
AlwaysProcess: true,
|
||||
},
|
||||
}
|
||||
c.emitOnRoomUpdate(ctx, up)
|
||||
}
|
||||
@ -608,7 +656,8 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
|
||||
up := roomUpdates[d.RoomID]
|
||||
up = append(up, d)
|
||||
roomUpdates[d.RoomID] = up
|
||||
if d.Type == "m.direct" {
|
||||
switch d.Type {
|
||||
case "m.direct":
|
||||
dmRoomSet := make(map[string]struct{})
|
||||
// pull out rooms and mark them as DMs
|
||||
content := gjson.ParseBytes(d.Data).Get("content")
|
||||
@ -633,7 +682,7 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
|
||||
c.roomToData[dmRoomID] = u
|
||||
}
|
||||
c.roomToDataMu.Unlock()
|
||||
} else if d.Type == "m.tag" {
|
||||
case "m.tag":
|
||||
content := gjson.ParseBytes(d.Data).Get("content.tags")
|
||||
if tagUpdates[d.RoomID] == nil {
|
||||
tagUpdates[d.RoomID] = make(map[string]float64)
|
||||
@ -642,6 +691,22 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
|
||||
tagUpdates[d.RoomID][k.Str] = v.Get("order").Float()
|
||||
return true
|
||||
})
|
||||
case "m.ignored_user_list":
|
||||
if d.RoomID != state.AccountDataGlobalRoom {
|
||||
continue
|
||||
}
|
||||
content := gjson.ParseBytes(d.Data).Get("content.ignored_users")
|
||||
if !content.IsObject() {
|
||||
continue
|
||||
}
|
||||
ignoredUsers := make(map[string]struct{})
|
||||
content.ForEach(func(k, v gjson.Result) bool {
|
||||
ignoredUsers[k.Str] = struct{}{}
|
||||
return true
|
||||
})
|
||||
c.ignoredUsersMu.Lock()
|
||||
c.ignoredUsers = ignoredUsers
|
||||
c.ignoredUsersMu.Unlock()
|
||||
}
|
||||
}
|
||||
if len(tagUpdates) > 0 {
|
||||
@ -674,3 +739,10 @@ func (c *UserCache) OnAccountData(ctx context.Context, datas []state.AccountData
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (u *UserCache) ShouldIgnore(userID string) bool {
|
||||
u.ignoredUsersMu.RLock()
|
||||
defer u.ignoredUsersMu.RUnlock()
|
||||
_, ignored := u.ignoredUsers[userID]
|
||||
return ignored
|
||||
}
|
||||
|
@ -83,8 +83,8 @@ func TestAnnotateWithTransactionIDs(t *testing.T) {
|
||||
data: tc.eventIDToTxnIDs,
|
||||
}
|
||||
uc := caches.NewUserCache(userID, nil, nil, fetcher)
|
||||
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents))
|
||||
want := convertIDTxnToEventStub(tc.wantRoomIDToEvents)
|
||||
got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(userID, tc.roomIDToEvents))
|
||||
want := convertIDTxnToEventStub(userID, tc.wantRoomIDToEvents)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("%s : got %v want %v", tc.name, js(got), js(want))
|
||||
}
|
||||
@ -96,27 +96,27 @@ func js(in interface{}) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func convertIDToEventStub(roomToEventIDs map[string][]string) map[string][]json.RawMessage {
|
||||
func convertIDToEventStub(sender string, roomToEventIDs map[string][]string) map[string][]json.RawMessage {
|
||||
result := make(map[string][]json.RawMessage)
|
||||
for roomID, eventIDs := range roomToEventIDs {
|
||||
events := make([]json.RawMessage, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i]))
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i], sender))
|
||||
}
|
||||
result[roomID] = events
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertIDTxnToEventStub(roomToEventIDs map[string][][2]string) map[string][]json.RawMessage {
|
||||
func convertIDTxnToEventStub(sender string, roomToEventIDs map[string][][2]string) map[string][]json.RawMessage {
|
||||
result := make(map[string][]json.RawMessage)
|
||||
for roomID, eventIDs := range roomToEventIDs {
|
||||
events := make([]json.RawMessage, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
if eventIDs[i][1] == "" {
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i][0]))
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i][0], sender))
|
||||
} else {
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], eventIDs[i][1]))
|
||||
events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], sender, eventIDs[i][1]))
|
||||
}
|
||||
}
|
||||
result[roomID] = events
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
// The amount of time to artificially wait if the server detects spamming clients. This time will
|
||||
// be added to responses when the server detects the same request being sent over and over e.g
|
||||
// /sync?pos=5 then /sync?pos=5 over and over. Likewise /sync without a ?pos=.
|
||||
var SpamProtectionInterval = time.Second
|
||||
var SpamProtectionInterval = 10 * time.Millisecond
|
||||
|
||||
type ConnID struct {
|
||||
UserID string
|
||||
@ -30,8 +30,9 @@ type ConnHandler interface {
|
||||
// Callback which is allowed to block as long as the context is active. Return the response
|
||||
// to send back or an error. Errors of type *internal.HandlerError are inspected for the correct
|
||||
// status code to send back.
|
||||
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
|
||||
OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error)
|
||||
OnUpdate(ctx context.Context, update caches.Update)
|
||||
PublishEventsUpTo(roomID string, nid int64)
|
||||
Destroy()
|
||||
Alive() bool
|
||||
}
|
||||
@ -88,7 +89,7 @@ func (c *Conn) OnUpdate(ctx context.Context, update caches.Update) {
|
||||
// upwards but will NOT be logged to Sentry (neither here nor by the caller). Errors
|
||||
// should be reported to Sentry as close as possible to the point of creating the error,
|
||||
// to provide the best possible Sentry traceback.
|
||||
func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) {
|
||||
func (c *Conn) tryRequest(ctx context.Context, req *Request, start time.Time) (res *Response, err error) {
|
||||
// TODO: include useful information from the request in the sentry hub/context
|
||||
// Might be better done in the caller though?
|
||||
defer func() {
|
||||
@ -116,7 +117,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err
|
||||
ctx, task := internal.StartTask(ctx, taskType)
|
||||
defer task.End()
|
||||
internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos)
|
||||
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0)
|
||||
return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0, start)
|
||||
}
|
||||
|
||||
func (c *Conn) isOutstanding(pos int64) bool {
|
||||
@ -132,7 +133,7 @@ func (c *Conn) isOutstanding(pos int64) bool {
|
||||
// If an error is returned, it will be logged by the caller and transmitted to the
|
||||
// client. It will NOT be reported to Sentry---this should happen as close as possible
|
||||
// to the creation of the error (or else Sentry cannot provide a meaningful traceback.)
|
||||
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) {
|
||||
func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.Time) (resp *Response, herr *internal.HandlerError) {
|
||||
c.cancelOutstandingRequestMu.Lock()
|
||||
if c.cancelOutstandingRequest != nil {
|
||||
c.cancelOutstandingRequest()
|
||||
@ -217,7 +218,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo
|
||||
req.SetTimeoutMSecs(1)
|
||||
}
|
||||
|
||||
resp, err := c.tryRequest(ctx, req)
|
||||
resp, err := c.tryRequest(ctx, req, start)
|
||||
if err != nil {
|
||||
herr, ok := err.(*internal.HandlerError)
|
||||
if !ok {
|
||||
|
@ -16,7 +16,7 @@ type connHandlerMock struct {
|
||||
fn func(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error)
|
||||
}
|
||||
|
||||
func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool) (*Response, error) {
|
||||
func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool, start time.Time) (*Response, error) {
|
||||
return c.fn(ctx, cid, req, init)
|
||||
}
|
||||
func (c *connHandlerMock) UserID() string {
|
||||
@ -25,6 +25,7 @@ func (c *connHandlerMock) UserID() string {
|
||||
func (c *connHandlerMock) Destroy() {}
|
||||
func (c *connHandlerMock) Alive() bool { return true }
|
||||
func (c *connHandlerMock) OnUpdate(ctx context.Context, update caches.Update) {}
|
||||
func (c *connHandlerMock) PublishEventsUpTo(roomID string, nid int64) {}
|
||||
|
||||
// Test that Conn can send and receive requests based on positions
|
||||
func TestConn(t *testing.T) {
|
||||
@ -47,7 +48,7 @@ func TestConn(t *testing.T) {
|
||||
// initial request
|
||||
resp, err := c.OnIncomingRequest(ctx, &Request{
|
||||
pos: 0,
|
||||
})
|
||||
}, time.Now())
|
||||
assertNoError(t, err)
|
||||
assertPos(t, resp.Pos, 1)
|
||||
assertInt(t, resp.Lists["a"].Count, 101)
|
||||
@ -55,14 +56,14 @@ func TestConn(t *testing.T) {
|
||||
// happy case, pos=1
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{
|
||||
pos: 1,
|
||||
})
|
||||
}, time.Now())
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 102)
|
||||
assertNoError(t, err)
|
||||
// bogus position returns a 400
|
||||
_, err = c.OnIncomingRequest(ctx, &Request{
|
||||
pos: 31415,
|
||||
})
|
||||
}, time.Now())
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got none")
|
||||
}
|
||||
@ -106,7 +107,7 @@ func TestConnBlocking(t *testing.T) {
|
||||
Sort: []string{"hi"},
|
||||
},
|
||||
},
|
||||
})
|
||||
}, time.Now())
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
@ -118,7 +119,7 @@ func TestConnBlocking(t *testing.T) {
|
||||
Sort: []string{"hi2"},
|
||||
},
|
||||
},
|
||||
})
|
||||
}, time.Now())
|
||||
}()
|
||||
go func() {
|
||||
wg.Wait()
|
||||
@ -148,18 +149,18 @@ func TestConnRetries(t *testing.T) {
|
||||
},
|
||||
}}, nil
|
||||
}})
|
||||
resp, err := c.OnIncomingRequest(ctx, &Request{})
|
||||
resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now())
|
||||
assertPos(t, resp.Pos, 1)
|
||||
assertInt(t, resp.Lists["a"].Count, 20)
|
||||
assertInt(t, callCount, 1)
|
||||
assertNoError(t, err)
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 20)
|
||||
assertInt(t, callCount, 2)
|
||||
assertNoError(t, err)
|
||||
// retry! Shouldn't invoke handler again
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 20)
|
||||
assertInt(t, callCount, 2) // this doesn't increment
|
||||
@ -170,7 +171,7 @@ func TestConnRetries(t *testing.T) {
|
||||
"a": {
|
||||
Sort: []string{SortByName},
|
||||
},
|
||||
}})
|
||||
}}, time.Now())
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 20)
|
||||
assertInt(t, callCount, 3) // this doesn't increment
|
||||
@ -191,25 +192,25 @@ func TestConnBufferRes(t *testing.T) {
|
||||
},
|
||||
}}, nil
|
||||
}})
|
||||
resp, err := c.OnIncomingRequest(ctx, &Request{})
|
||||
resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now())
|
||||
assertNoError(t, err)
|
||||
assertPos(t, resp.Pos, 1)
|
||||
assertInt(t, resp.Lists["a"].Count, 1)
|
||||
assertInt(t, callCount, 1)
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1})
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now())
|
||||
assertNoError(t, err)
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 2)
|
||||
assertInt(t, callCount, 2)
|
||||
// retry with modified request data that shouldn't prompt data to be returned.
|
||||
// should invoke handler again!
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}})
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}}, time.Now())
|
||||
assertNoError(t, err)
|
||||
assertPos(t, resp.Pos, 2)
|
||||
assertInt(t, resp.Lists["a"].Count, 2)
|
||||
assertInt(t, callCount, 3) // this DOES increment, the response is buffered and not returned yet.
|
||||
// retry with same request body, so should NOT invoke handler again and return buffered response
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}})
|
||||
resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}}, time.Now())
|
||||
assertNoError(t, err)
|
||||
assertPos(t, resp.Pos, 3)
|
||||
assertInt(t, resp.Lists["a"].Count, 3)
|
||||
@ -228,7 +229,7 @@ func TestConnErrors(t *testing.T) {
|
||||
|
||||
// random errors = 500
|
||||
errCh <- errors.New("oops")
|
||||
_, herr := c.OnIncomingRequest(ctx, &Request{})
|
||||
_, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now())
|
||||
if herr.StatusCode != 500 {
|
||||
t.Fatalf("random errors should be status 500, got %d", herr.StatusCode)
|
||||
}
|
||||
@ -237,7 +238,7 @@ func TestConnErrors(t *testing.T) {
|
||||
StatusCode: 400,
|
||||
Err: errors.New("no way!"),
|
||||
}
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{})
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{}, time.Now())
|
||||
if herr.StatusCode != 400 {
|
||||
t.Fatalf("expected status 400, got %d", herr.StatusCode)
|
||||
}
|
||||
@ -258,7 +259,7 @@ func TestConnErrorsNoCache(t *testing.T) {
|
||||
}
|
||||
}})
|
||||
// errors should not be cached
|
||||
resp, herr := c.OnIncomingRequest(ctx, &Request{})
|
||||
resp, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now())
|
||||
if herr != nil {
|
||||
t.Fatalf("expected no error, got %+v", herr)
|
||||
}
|
||||
@ -267,12 +268,12 @@ func TestConnErrorsNoCache(t *testing.T) {
|
||||
StatusCode: 400,
|
||||
Err: errors.New("no way!"),
|
||||
}
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()})
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now())
|
||||
if herr.StatusCode != 400 {
|
||||
t.Fatalf("expected status 400, got %d", herr.StatusCode)
|
||||
}
|
||||
// but doing the exact same request should now work
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()})
|
||||
_, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now())
|
||||
if herr != nil {
|
||||
t.Fatalf("expected no error, got %+v", herr)
|
||||
}
|
||||
@ -361,7 +362,7 @@ func TestConnBufferRememberInflight(t *testing.T) {
|
||||
var err *internal.HandlerError
|
||||
for i, step := range steps {
|
||||
t.Logf("Executing step %d", i)
|
||||
resp, err = c.OnIncomingRequest(ctx, step.req)
|
||||
resp, err = c.OnIncomingRequest(ctx, step.req, time.Now())
|
||||
if !step.wantErr {
|
||||
assertNoError(t, err)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ReneKroon/ttlcache/v2"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// ConnMap stores a collection of Conns.
|
||||
@ -15,10 +16,15 @@ type ConnMap struct {
|
||||
userIDToConn map[string][]*Conn
|
||||
connIDToConn map[string]*Conn
|
||||
|
||||
numConns prometheus.Gauge
|
||||
// counters for reasons why connections have expired
|
||||
expiryTimedOutCounter prometheus.Counter
|
||||
expiryBufferFullCounter prometheus.Counter
|
||||
|
||||
mu *sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnMap() *ConnMap {
|
||||
func NewConnMap(enablePrometheus bool) *ConnMap {
|
||||
cm := &ConnMap{
|
||||
userIDToConn: make(map[string][]*Conn),
|
||||
connIDToConn: make(map[string]*Conn),
|
||||
@ -27,17 +33,61 @@ func NewConnMap() *ConnMap {
|
||||
}
|
||||
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
|
||||
cm.cache.SetExpirationCallback(cm.closeConnExpires)
|
||||
|
||||
if enablePrometheus {
|
||||
cm.expiryTimedOutCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "expiry_conn_timed_out",
|
||||
Help: "Counter of expired API connections due to reaching TTL limit",
|
||||
})
|
||||
prometheus.MustRegister(cm.expiryTimedOutCounter)
|
||||
cm.expiryBufferFullCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "expiry_conn_buffer_full",
|
||||
Help: "Counter of expired API connections due to reaching buffer update limit",
|
||||
})
|
||||
prometheus.MustRegister(cm.expiryBufferFullCounter)
|
||||
cm.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "num_active_conns",
|
||||
Help: "Number of active sliding sync connections.",
|
||||
})
|
||||
prometheus.MustRegister(cm.numConns)
|
||||
}
|
||||
return cm
|
||||
}
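Following Prometheus' namespace_subsystem_name convention, the registrations above surface as sliding_sync_api_num_active_conns, sliding_sync_api_expiry_conn_timed_out and sliding_sync_api_expiry_conn_buffer_full. A rough in-process way to read the gauge from a test in this package, assuming the client_golang testutil helper, would be:

// Illustrative helper: read the current value of the active-connections gauge.
// import "github.com/prometheus/client_golang/prometheus/testutil"
func numActiveConns(cm *ConnMap) float64 {
	return testutil.ToFloat64(cm.numConns)
}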
|
||||
|
||||
func (m *ConnMap) Teardown() {
|
||||
m.cache.Close()
|
||||
|
||||
if m.numConns != nil {
|
||||
prometheus.Unregister(m.numConns)
|
||||
}
|
||||
if m.expiryBufferFullCounter != nil {
|
||||
prometheus.Unregister(m.expiryBufferFullCounter)
|
||||
}
|
||||
if m.expiryTimedOutCounter != nil {
|
||||
prometheus.Unregister(m.expiryTimedOutCounter)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnMap) Len() int {
|
||||
// UpdateMetrics recalculates the number of active connections. Do this when you think there is a change.
|
||||
func (m *ConnMap) UpdateMetrics() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return len(m.connIDToConn)
|
||||
m.updateMetrics(len(m.connIDToConn))
|
||||
}
|
||||
|
||||
// updateMetrics is like UpdateMetrics but doesn't touch connIDToConn and hence doesn't need to lock. We use this internally
|
||||
// when we need to update the metric and already have the lock held, as calling UpdateMetrics would deadlock.
|
||||
func (m *ConnMap) updateMetrics(numConns int) {
|
||||
if m.numConns == nil {
|
||||
return
|
||||
}
|
||||
m.numConns.Set(float64(numConns))
|
||||
}
|
||||
|
||||
// Conns return all connections for this user|device
|
||||
@ -55,6 +105,14 @@ func (m *ConnMap) Conns(userID, deviceID string) []*Conn {
|
||||
|
||||
// Conn returns a connection with this ConnID. Returns nil if no connection exists.
|
||||
func (m *ConnMap) Conn(cid ConnID) *Conn {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.getConn(cid)
|
||||
}
|
||||
|
||||
// getConn returns a connection with this ConnID. Returns nil if no connection exists. Expires connections if the buffer is full.
|
||||
// Must hold mu.
|
||||
func (m *ConnMap) getConn(cid ConnID) *Conn {
|
||||
cint, _ := m.cache.Get(cid.String())
|
||||
if cint == nil {
|
||||
return nil
|
||||
@ -64,8 +122,11 @@ func (m *ConnMap) Conn(cid ConnID) *Conn {
|
||||
return conn
|
||||
}
|
||||
// e.g buffer exceeded, close it and remove it from the cache
|
||||
logger.Trace().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
|
||||
logger.Info().Str("conn", cid.String()).Msg("closing connection due to dead connection (buffer full)")
|
||||
m.closeConn(conn)
|
||||
if m.expiryBufferFullCounter != nil {
|
||||
m.expiryBufferFullCounter.Inc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -74,7 +135,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
// atomically check if a conn exists already and nuke it if it exists
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn := m.Conn(cid)
|
||||
conn := m.getConn(cid)
|
||||
if conn != nil {
|
||||
// tear down this connection and fallthrough
|
||||
isSpamming := conn.lastPos <= 1
|
||||
@ -92,6 +153,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
|
||||
m.cache.Set(cid.String(), conn)
|
||||
m.connIDToConn[cid.String()] = conn
|
||||
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
|
||||
m.updateMetrics(len(m.connIDToConn))
|
||||
return conn, true
|
||||
}
|
||||
|
||||
@ -121,7 +183,10 @@ func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn := value.(*Conn)
|
||||
logger.Trace().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
|
||||
logger.Info().Str("conn", connID).Msg("closing connection due to expired TTL in cache")
|
||||
if m.expiryTimedOutCounter != nil {
|
||||
m.expiryTimedOutCounter.Inc()
|
||||
}
|
||||
m.closeConn(conn)
|
||||
}
|
||||
|
||||
@ -147,4 +212,14 @@ func (m *ConnMap) closeConn(conn *Conn) {
|
||||
m.userIDToConn[conn.UserID] = conns
|
||||
// remove user cache listeners etc
|
||||
h.Destroy()
|
||||
m.updateMetrics(len(m.connIDToConn))
|
||||
}
|
||||
|
||||
func (m *ConnMap) ClearUpdateQueues(userID, roomID string, nid int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, conn := range m.userIDToConn[userID] {
|
||||
conn.handler.PublishEventsUpTo(roomID, nid)
|
||||
}
|
||||
}
|
||||
|
@ -87,14 +87,15 @@ func (d *Dispatcher) newEventData(event json.RawMessage, roomID string, latestPo
|
||||
eventType := ev.Get("type").Str
|
||||
|
||||
return &caches.EventData{
|
||||
Event: event,
|
||||
RoomID: roomID,
|
||||
EventType: eventType,
|
||||
StateKey: stateKey,
|
||||
Content: ev.Get("content"),
|
||||
NID: latestPos,
|
||||
Timestamp: ev.Get("origin_server_ts").Uint(),
|
||||
Sender: ev.Get("sender").Str,
|
||||
Event: event,
|
||||
RoomID: roomID,
|
||||
EventType: eventType,
|
||||
StateKey: stateKey,
|
||||
Content: ev.Get("content"),
|
||||
NID: latestPos,
|
||||
Timestamp: ev.Get("origin_server_ts").Uint(),
|
||||
Sender: ev.Get("sender").Str,
|
||||
TransactionID: ev.Get("unsigned.transaction_id").Str,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,8 @@ type ConnState struct {
|
||||
// roomID -> latest load pos
|
||||
loadPositions map[string]int64
|
||||
|
||||
live *connStateLive
|
||||
txnIDWaiter *TxnIDWaiter
|
||||
live *connStateLive
|
||||
|
||||
globalCache *caches.GlobalCache
|
||||
userCache *caches.UserCache
|
||||
@ -52,13 +53,14 @@ type ConnState struct {
|
||||
joinChecker JoinChecker
|
||||
|
||||
extensionsHandler extensions.HandlerInterface
|
||||
setupHistogramVec *prometheus.HistogramVec
|
||||
processHistogramVec *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
func NewConnState(
|
||||
userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache,
|
||||
ex extensions.HandlerInterface, joinChecker JoinChecker, histVec *prometheus.HistogramVec,
|
||||
maxPendingEventUpdates int,
|
||||
ex extensions.HandlerInterface, joinChecker JoinChecker, setupHistVec *prometheus.HistogramVec, histVec *prometheus.HistogramVec,
|
||||
maxPendingEventUpdates int, maxTransactionIDDelay time.Duration,
|
||||
) *ConnState {
|
||||
cs := &ConnState{
|
||||
globalCache: globalCache,
|
||||
@ -72,12 +74,20 @@ func NewConnState(
|
||||
extensionsHandler: ex,
|
||||
joinChecker: joinChecker,
|
||||
lazyCache: NewLazyCache(),
|
||||
setupHistogramVec: setupHistVec,
|
||||
processHistogramVec: histVec,
|
||||
}
|
||||
cs.live = &connStateLive{
|
||||
ConnState: cs,
|
||||
updates: make(chan caches.Update, maxPendingEventUpdates),
|
||||
}
|
||||
cs.txnIDWaiter = NewTxnIDWaiter(
|
||||
userID,
|
||||
maxTransactionIDDelay,
|
||||
func(delayed bool, update caches.Update) {
|
||||
cs.live.onUpdate(update)
|
||||
},
|
||||
)
|
||||
// subscribe for updates before loading. We risk seeing dupes but that's fine as load positions
|
||||
// will stop us double-processing.
|
||||
cs.userCacheID = cs.userCache.Subsribe(cs)
|
||||
@ -160,7 +170,7 @@ func (s *ConnState) load(ctx context.Context, req *sync3.Request) error {
|
||||
}
|
||||
|
||||
// OnIncomingRequest is guaranteed to be called sequentially (it's protected by a mutex in conn.go)
|
||||
func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
|
||||
func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool, start time.Time) (*sync3.Response, error) {
|
||||
if s.anchorLoadPosition <= 0 {
|
||||
// load() needs no ctx so drop it
|
||||
_, region := internal.StartSpan(ctx, "load")
|
||||
@ -172,45 +182,50 @@ func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req
|
||||
}
|
||||
region.End()
|
||||
}
|
||||
setupTime := time.Since(start)
|
||||
s.trackSetupDuration(setupTime, isInitial)
|
||||
return s.onIncomingRequest(ctx, req, isInitial)
|
||||
}
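trackSetupDuration itself is not shown in this hunk; a plausible shape, assuming it simply feeds setupHistogramVec with an initial/non-initial label, is:

// Hypothetical sketch of trackSetupDuration: record how long request setup
// took, split by whether this was an initial (pos=0) request.
func (s *ConnState) trackSetupDuration(dur time.Duration, isInitial bool) {
	if s.setupHistogramVec == nil {
		return // metrics disabled
	}
	val := "0"
	if isInitial {
		val = "1"
	}
	s.setupHistogramVec.WithLabelValues(val).Observe(dur.Seconds())
}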
|
||||
|
||||
// onIncomingRequest is a callback which fires when the client makes a request to the server. Whilst each request may
|
||||
// be on their own goroutine, the requests are linearised for us by Conn so it is safe to modify ConnState without
|
||||
// additional locking mechanisms.
|
||||
func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
|
||||
func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) {
|
||||
start := time.Now()
|
||||
// ApplyDelta works fine if s.muxedReq is nil
|
||||
var delta *sync3.RequestDelta
|
||||
s.muxedReq, delta = s.muxedReq.ApplyDelta(req)
|
||||
internal.Logf(ctx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
|
||||
internal.Logf(reqCtx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists))
|
||||
for key, l := range delta.Lists {
|
||||
listData := ""
|
||||
if l.Curr != nil {
|
||||
listDataBytes, _ := json.Marshal(l.Curr)
|
||||
listData = string(listDataBytes)
|
||||
}
|
||||
internal.Logf(ctx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
|
||||
internal.Logf(reqCtx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData)
|
||||
}
|
||||
for roomID, sub := range s.muxedReq.RoomSubscriptions {
|
||||
internal.Logf(reqCtx, "connstate", "room sub[%v] %v", roomID, sub)
|
||||
}
|
||||
|
||||
// work out which rooms we'll return data for and add their relevant subscriptions to the builder
|
||||
// for it to mix together
|
||||
builder := NewRoomsBuilder()
|
||||
// works out which rooms are subscribed to but doesn't pull room data
|
||||
s.buildRoomSubscriptions(ctx, builder, delta.Subs, delta.Unsubs)
|
||||
s.buildRoomSubscriptions(reqCtx, builder, delta.Subs, delta.Unsubs)
|
||||
// works out how rooms get moved about but doesn't pull room data
|
||||
respLists := s.buildListSubscriptions(ctx, builder, delta.Lists)
|
||||
respLists := s.buildListSubscriptions(reqCtx, builder, delta.Lists)
|
||||
|
||||
// pull room data and set changes on the response
|
||||
response := &sync3.Response{
|
||||
Rooms: s.buildRooms(ctx, builder.BuildSubscriptions()), // pull room data
|
||||
Rooms: s.buildRooms(reqCtx, builder.BuildSubscriptions()), // pull room data
|
||||
Lists: respLists,
|
||||
}
|
||||
|
||||
// Handle extensions AFTER processing lists as extensions may need to know which rooms the client
|
||||
// is being notified about (e.g. for room account data)
|
||||
ctx, region := internal.StartSpan(ctx, "extensions")
|
||||
response.Extensions = s.extensionsHandler.Handle(ctx, s.muxedReq.Extensions, extensions.Context{
|
||||
extCtx, region := internal.StartSpan(reqCtx, "extensions")
|
||||
response.Extensions = s.extensionsHandler.Handle(extCtx, s.muxedReq.Extensions, extensions.Context{
|
||||
UserID: s.userID,
|
||||
DeviceID: s.deviceID,
|
||||
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
|
||||
@ -231,8 +246,8 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
|
||||
}
|
||||
|
||||
// do live tracking if we have nothing to tell the client yet
|
||||
ctx, region = internal.StartSpan(ctx, "liveUpdate")
|
||||
s.live.liveUpdate(ctx, req, s.muxedReq.Extensions, isInitial, response)
|
||||
updateCtx, region := internal.StartSpan(reqCtx, "liveUpdate")
|
||||
s.live.liveUpdate(updateCtx, req, s.muxedReq.Extensions, isInitial, response)
|
||||
region.End()
|
||||
|
||||
// counts are AFTER events are applied, hence after liveUpdate
|
||||
@ -245,7 +260,7 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i
|
||||
// Add membership events for users sending typing notifications. We do this after live update
|
||||
// and initial room loading code so we LL room members in all cases.
|
||||
if response.Extensions.Typing != nil && response.Extensions.Typing.HasData(isInitial) {
|
||||
s.lazyLoadTypingMembers(ctx, response)
|
||||
s.lazyLoadTypingMembers(reqCtx, response)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
@ -453,6 +468,12 @@ func (s *ConnState) buildRooms(ctx context.Context, builtSubs []BuiltSubscriptio
|
||||
ctx, span := internal.StartSpan(ctx, "buildRooms")
|
||||
defer span.End()
|
||||
result := make(map[string]sync3.Room)
|
||||
|
||||
var bumpEventTypes []string
|
||||
for _, x := range s.muxedReq.Lists {
|
||||
bumpEventTypes = append(bumpEventTypes, x.BumpEventTypes...)
|
||||
}
|
||||
|
||||
for _, bs := range builtSubs {
|
||||
roomIDs := bs.RoomIDs
|
||||
if bs.RoomSubscription.IncludeOldRooms != nil {
|
||||
@ -477,14 +498,23 @@ func (s *ConnState) buildRooms(ctx context.Context, builtSubs []BuiltSubscriptio
|
||||
}
|
||||
}
|
||||
}
|
||||
// old rooms use a different subscription
|
||||
oldRooms := s.getInitialRoomData(ctx, *bs.RoomSubscription.IncludeOldRooms, oldRoomIDs...)
|
||||
for oldRoomID, oldRoom := range oldRooms {
|
||||
result[oldRoomID] = oldRoom
|
||||
|
||||
// If we have old rooms to fetch, do so.
|
||||
if len(oldRoomIDs) > 0 {
|
||||
// old rooms use a different subscription
|
||||
oldRooms := s.getInitialRoomData(ctx, *bs.RoomSubscription.IncludeOldRooms, bumpEventTypes, oldRoomIDs...)
|
||||
for oldRoomID, oldRoom := range oldRooms {
|
||||
result[oldRoomID] = oldRoom
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rooms := s.getInitialRoomData(ctx, bs.RoomSubscription, roomIDs...)
|
||||
// There won't be anything to fetch, try the next subscription.
|
||||
if len(roomIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rooms := s.getInitialRoomData(ctx, bs.RoomSubscription, bumpEventTypes, roomIDs...)
|
||||
for roomID, room := range rooms {
|
||||
result[roomID] = room
|
||||
}
|
||||
@ -521,7 +551,7 @@ func (s *ConnState) lazyLoadTypingMembers(ctx context.Context, response *sync3.R
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSubscription, roomIDs ...string) map[string]sync3.Room {
|
||||
func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSubscription, bumpEventTypes []string, roomIDs ...string) map[string]sync3.Room {
|
||||
ctx, span := internal.StartSpan(ctx, "getInitialRoomData")
|
||||
defer span.End()
|
||||
rooms := make(map[string]sync3.Room, len(roomIDs))
|
||||
@ -590,8 +620,40 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
requiredState = make([]json.RawMessage, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the highest timestamp, determined by bumpEventTypes,
|
||||
// for this room
|
||||
roomListsMeta := s.lists.ReadOnlyRoom(roomID)
|
||||
var maxTs uint64
|
||||
for _, t := range bumpEventTypes {
|
||||
if roomListsMeta == nil {
|
||||
break
|
||||
}
|
||||
|
||||
evMeta := roomListsMeta.LatestEventsByType[t]
|
||||
if evMeta.Timestamp > maxTs {
|
||||
maxTs = evMeta.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't find any events which would update the timestamp
|
||||
// use the join event timestamp instead. Also don't leak
|
||||
// timestamp from before we joined.
|
||||
if maxTs == 0 || maxTs < roomListsMeta.JoinTiming.Timestamp {
|
||||
if roomListsMeta != nil {
|
||||
maxTs = roomListsMeta.JoinTiming.Timestamp
|
||||
// If no bumpEventTypes are specified, use the
|
||||
// LastMessageTimestamp so clients are still able
|
||||
// to correctly sort on it.
|
||||
if len(bumpEventTypes) == 0 {
|
||||
maxTs = roomListsMeta.LastMessageTimestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rooms[roomID] = sync3.Room{
|
||||
Name: internal.CalculateRoomName(metadata, 5), // TODO: customisable?
|
||||
AvatarChange: sync3.NewAvatarChange(internal.CalculateAvatar(metadata)),
|
||||
NotificationCount: int64(userRoomData.NotificationCount),
|
||||
HighlightCount: int64(userRoomData.HighlightCount),
|
||||
Timeline: roomToTimeline[roomID],
|
||||
@ -600,8 +662,9 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
Initial: true,
|
||||
IsDM: userRoomData.IsDM,
|
||||
JoinedCount: metadata.JoinCount,
|
||||
InvitedCount: metadata.InviteCount,
|
||||
InvitedCount: &metadata.InviteCount,
|
||||
PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch,
|
||||
Timestamp: maxTs,
|
||||
}
|
||||
}
|
||||
|
||||
@ -613,6 +676,17 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu
|
||||
return rooms
|
||||
}
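
The timestamp logic added above reads more easily in isolation: take the newest event of any requested bump type, never report a time earlier than our join (so we don't leak pre-join activity), and fall back to the room's last message timestamp when no bump types were requested at all. A minimal standalone sketch of that rule, using simplified stand-in types rather than the real cache structs:

```go
package main

import "fmt"

// Simplified stand-ins for the real cache metadata; the actual sliding-sync
// structs have more fields and different names.
type eventMeta struct{ Timestamp uint64 }

type roomMeta struct {
	LatestEventsByType   map[string]eventMeta
	JoinTimestamp        uint64
	LastMessageTimestamp uint64
}

// bumpTimestamp mirrors the rule in getInitialRoomData: newest bump-type event,
// clamped to the join timestamp, falling back to the last message timestamp
// when no bump event types were requested.
func bumpTimestamp(room *roomMeta, bumpEventTypes []string) uint64 {
	if room == nil {
		return 0
	}
	var maxTs uint64
	for _, t := range bumpEventTypes {
		if ev, ok := room.LatestEventsByType[t]; ok && ev.Timestamp > maxTs {
			maxTs = ev.Timestamp
		}
	}
	if maxTs == 0 || maxTs < room.JoinTimestamp {
		maxTs = room.JoinTimestamp
		if len(bumpEventTypes) == 0 {
			maxTs = room.LastMessageTimestamp
		}
	}
	return maxTs
}

func main() {
	room := &roomMeta{
		LatestEventsByType:   map[string]eventMeta{"m.room.message": {Timestamp: 300}},
		JoinTimestamp:        100,
		LastMessageTimestamp: 300,
	}
	fmt.Println(bumpTimestamp(room, []string{"m.room.message"})) // 300: newest bump event
	fmt.Println(bumpTimestamp(room, []string{"m.reaction"}))     // 100: nothing matched, use join time
	fmt.Println(bumpTimestamp(room, nil))                        // 300: no bump types, use last message
}
```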
|
||||
|
||||
func (s *ConnState) trackSetupDuration(dur time.Duration, isInitial bool) {
|
||||
if s.setupHistogramVec == nil {
|
||||
return
|
||||
}
|
||||
val := "0"
|
||||
if isInitial {
|
||||
val = "1"
|
||||
}
|
||||
s.setupHistogramVec.WithLabelValues(val).Observe(float64(dur.Seconds()))
|
||||
}
|
||||
|
||||
func (s *ConnState) trackProcessDuration(dur time.Duration, isInitial bool) {
|
||||
if s.processHistogramVec == nil {
|
||||
return
|
||||
@ -638,7 +712,8 @@ func (s *ConnState) UserID() string {
|
||||
}
|
||||
|
||||
func (s *ConnState) OnUpdate(ctx context.Context, up caches.Update) {
|
||||
s.live.onUpdate(up)
|
||||
// will eventually call s.live.onUpdate
|
||||
s.txnIDWaiter.Ingest(up)
|
||||
}
|
||||
|
||||
// Called by the user cache when updates arrive
|
||||
@ -654,15 +729,19 @@ func (s *ConnState) OnRoomUpdate(ctx context.Context, up caches.RoomUpdate) {
|
||||
}
|
||||
internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil)
|
||||
internal.Logf(ctx, "connstate", "queued update %d", update.EventData.NID)
|
||||
s.live.onUpdate(update)
|
||||
s.OnUpdate(ctx, update)
|
||||
case caches.RoomUpdate:
|
||||
internal.AssertWithContext(ctx, "missing global room metadata", update.GlobalRoomMetadata() != nil)
|
||||
s.live.onUpdate(update)
|
||||
s.OnUpdate(ctx, update)
|
||||
default:
|
||||
logger.Warn().Str("room_id", up.RoomID()).Msg("OnRoomUpdate unknown update type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) {
|
||||
s.txnIDWaiter.PublishUpToNID(roomID, nid)
|
||||
}
|
||||
|
||||
// clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges.
|
||||
//
|
||||
// Suppose the client asks for a window on positions [10, 19]. If the list
|
||||
|
@ -37,7 +37,7 @@ func (s *connStateLive) onUpdate(up caches.Update) {
|
||||
select {
|
||||
case s.updates <- up:
|
||||
case <-time.After(BufferWaitTime):
|
||||
logger.Warn().Interface("update", up).Str("user", s.userID).Msg(
|
||||
logger.Warn().Interface("update", up).Str("user", s.userID).Str("device", s.deviceID).Msg(
|
||||
"cannot send update to connection, buffer exceeded. Destroying connection.",
|
||||
)
|
||||
s.bufferFull = true
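
The send above is a bounded-buffer guard: try to enqueue the update, but give up after BufferWaitTime instead of blocking the producer forever, and tear the connection down if the buffer stayed full. A standalone sketch of the same pattern (the helper and names here are illustrative, not the real connStateLive fields):

```go
package main

import (
	"fmt"
	"time"
)

// trySend attempts to enqueue a value, but gives up after wait rather than
// blocking the producer indefinitely (requires Go 1.18+ for generics).
func trySend[T any](ch chan T, v T, wait time.Duration) bool {
	select {
	case ch <- v:
		return true
	case <-time.After(wait):
		return false
	}
}

func main() {
	updates := make(chan string, 1)
	updates <- "first" // fill the buffer
	if !trySend(updates, "second", 10*time.Millisecond) {
		fmt.Println("buffer exceeded; the caller would flag the connection for destruction")
	}
}
```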
|
||||
@ -57,6 +57,7 @@ func (s *connStateLive) liveUpdate(
|
||||
if req.TimeoutMSecs() < 100 {
|
||||
req.SetTimeoutMSecs(100)
|
||||
}
|
||||
startBufferSize := len(s.updates)
|
||||
// block until we get a new event, with appropriate timeout
|
||||
startTime := time.Now()
|
||||
hasLiveStreamed := false
|
||||
@ -81,43 +82,75 @@ func (s *connStateLive) liveUpdate(
|
||||
return
|
||||
case update := <-s.updates:
|
||||
internal.Logf(ctx, "liveUpdate", "process live update")
|
||||
|
||||
s.processLiveUpdate(ctx, update, response)
|
||||
// pass event to extensions AFTER processing
|
||||
extCtx := extensions.Context{
|
||||
IsInitial: false,
|
||||
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
|
||||
UserID: s.userID,
|
||||
DeviceID: s.deviceID,
|
||||
RoomIDsToLists: s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists),
|
||||
AllSubscribedRooms: s.muxedReq.SubscribedRoomIDs(),
|
||||
AllLists: s.muxedReq.ListKeys(),
|
||||
}
|
||||
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extCtx)
|
||||
s.processUpdate(ctx, update, response, ex)
|
||||
// if there's more updates and we don't have lots stacked up already, go ahead and process another
|
||||
for len(s.updates) > 0 && response.ListOps() < 50 {
|
||||
update = <-s.updates
|
||||
s.processLiveUpdate(ctx, update, response)
|
||||
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extCtx)
|
||||
s.processUpdate(ctx, update, response, ex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a client constantly changes their request params in every request they make, we will never consume from
|
||||
// the update channel as the response will always have data already. In an effort to prevent starvation of new
|
||||
// data, we will process some updates even though we have data already, but only if A) we didn't live stream
|
||||
// due to natural circumstances, B) it isn't an initial request and C) there is in fact some data there.
|
||||
numQueuedUpdates := len(s.updates)
|
||||
if !hasLiveStreamed && !isInitial && numQueuedUpdates > 0 {
|
||||
for i := 0; i < numQueuedUpdates; i++ {
|
||||
update := <-s.updates
|
||||
s.processUpdate(ctx, update, response, ex)
|
||||
}
|
||||
log.Debug().Int("num_queued", numQueuedUpdates).Msg("liveUpdate: caught up")
|
||||
internal.Logf(ctx, "connstate", "liveUpdate caught up %d updates", numQueuedUpdates)
|
||||
}
|
||||
|
||||
log.Trace().Bool("live_streamed", hasLiveStreamed).Msg("liveUpdate: returning")
|
||||
|
||||
internal.SetConnBufferInfo(ctx, startBufferSize, len(s.updates), cap(s.updates))
|
||||
|
||||
// TODO: op consolidation
|
||||
}
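
Both drain loops above depend on snapshotting the queue length before receiving: with a single consumer goroutine, a buffered channel is guaranteed to hand back at least that many values without blocking, and anything that arrives later waits for the next pass. A small illustration of the snapshot-then-drain idea on a plain buffered channel:

```go
package main

import "fmt"

// drainQueued consumes exactly the number of items that were buffered when it
// was called. With a single consumer none of these receives can block, and
// items arriving afterwards are left for the next pass.
func drainQueued(ch chan int, handle func(int)) int {
	queued := len(ch) // snapshot before receiving
	for i := 0; i < queued; i++ {
		handle(<-ch)
	}
	return queued
}

func main() {
	ch := make(chan int, 10)
	ch <- 1
	ch <- 2
	n := drainQueued(ch, func(v int) { fmt.Println("processed", v) })
	fmt.Println("caught up on", n, "updates")
}
```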
|
||||
|
||||
func (s *connStateLive) processUpdate(ctx context.Context, update caches.Update, response *sync3.Response, ex extensions.Request) {
|
||||
internal.Logf(ctx, "liveUpdate", "process live update %s", update.Type())
|
||||
s.processLiveUpdate(ctx, update, response)
|
||||
// pass event to extensions AFTER processing
|
||||
roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists)
|
||||
s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{
|
||||
IsInitial: false,
|
||||
RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(),
|
||||
UserID: s.userID,
|
||||
DeviceID: s.deviceID,
|
||||
RoomIDsToLists: roomIDsToLists,
|
||||
AllSubscribedRooms: s.muxedReq.SubscribedRoomIDs(),
|
||||
AllLists: s.muxedReq.ListKeys(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, response *sync3.Response) bool {
|
||||
internal.AssertWithContext(ctx, "processLiveUpdate: response list length != internal list length", s.lists.Len() == len(response.Lists))
|
||||
internal.AssertWithContext(ctx, "processLiveUpdate: request list length != internal list length", s.lists.Len() == len(s.muxedReq.Lists))
|
||||
roomUpdate, _ := up.(caches.RoomUpdate)
|
||||
roomEventUpdate, _ := up.(*caches.RoomEventUpdate)
|
||||
// if this is a room event update we may not want to process this if the event nid is < loadPos,
|
||||
// as that means we have already taken it into account
|
||||
if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess {
|
||||
// check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event
|
||||
// behind what we've processed before?
|
||||
lp := s.loadPositions[roomEventUpdate.RoomID()]
|
||||
if lp > 0 && roomEventUpdate.EventData.NID < lp {
|
||||
if roomEventUpdate != nil {
|
||||
// if this is a room event update we may not want to process this event, for a few reasons.
|
||||
if !roomEventUpdate.EventData.AlwaysProcess {
|
||||
// check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event
|
||||
// behind what we've processed before?
|
||||
lp := s.loadPositions[roomEventUpdate.RoomID()]
|
||||
if lp > 0 && roomEventUpdate.EventData.NID < lp {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Skip message events from ignored users.
|
||||
if roomEventUpdate.EventData.StateKey == nil && s.userCache.ShouldIgnore(roomEventUpdate.EventData.Sender) {
|
||||
logger.Trace().
|
||||
Str("user", s.userID).
|
||||
Str("type", roomEventUpdate.EventData.EventType).
|
||||
Str("sender", roomEventUpdate.EventData.Sender).
|
||||
Msg("ignoring event update")
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -157,16 +190,42 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
|
||||
// include this update in the rooms response TODO: filters on event type?
|
||||
userRoomData := roomUpdate.UserRoomMetadata()
|
||||
r := response.Rooms[roomUpdate.RoomID()]
|
||||
|
||||
// Get the highest timestamp, determined by bumpEventTypes,
|
||||
// for this room
|
||||
roomListsMeta := s.lists.ReadOnlyRoom(roomUpdate.RoomID())
|
||||
var bumpEventTypes []string
|
||||
for _, list := range s.muxedReq.Lists {
|
||||
bumpEventTypes = append(bumpEventTypes, list.BumpEventTypes...)
|
||||
}
|
||||
for _, t := range bumpEventTypes {
|
||||
evMeta := roomListsMeta.LatestEventsByType[t]
|
||||
if evMeta.Timestamp > r.Timestamp {
|
||||
r.Timestamp = evMeta.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no bumpEventTypes defined, use the last message timestamp
|
||||
if r.Timestamp == 0 && len(bumpEventTypes) == 0 {
|
||||
r.Timestamp = roomUpdate.GlobalRoomMetadata().LastMessageTimestamp
|
||||
}
|
||||
// Make sure we don't leak a timestamp from before we joined
|
||||
if r.Timestamp < roomListsMeta.JoinTiming.Timestamp {
|
||||
r.Timestamp = roomListsMeta.JoinTiming.Timestamp
|
||||
}
|
||||
|
||||
r.HighlightCount = int64(userRoomData.HighlightCount)
|
||||
r.NotificationCount = int64(userRoomData.NotificationCount)
|
||||
if roomEventUpdate != nil && roomEventUpdate.EventData.Event != nil {
|
||||
r.NumLive++
|
||||
advancedPastEvent := false
|
||||
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
|
||||
// this update has been accounted for by the initial:true room snapshot
|
||||
advancedPastEvent = true
|
||||
if !roomEventUpdate.EventData.AlwaysProcess {
|
||||
if roomEventUpdate.EventData.NID <= s.loadPositions[roomEventUpdate.RoomID()] {
|
||||
// this update has been accounted for by the initial:true room snapshot
|
||||
advancedPastEvent = true
|
||||
}
|
||||
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
|
||||
}
|
||||
s.loadPositions[roomEventUpdate.RoomID()] = roomEventUpdate.EventData.NID
|
||||
// we only append to the timeline if we haven't already got this event. This can happen when:
|
||||
// - 2 live events for a room mid-connection
|
||||
// - next request bumps a room from outside to inside the window
|
||||
@ -203,8 +262,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
|
||||
metadata.RemoveHero(s.userID)
|
||||
thisRoom.Name = internal.CalculateRoomName(metadata, 5) // TODO: customisable?
|
||||
}
|
||||
if delta.RoomAvatarChanged {
|
||||
metadata := roomUpdate.GlobalRoomMetadata()
|
||||
metadata.RemoveHero(s.userID)
|
||||
thisRoom.AvatarChange = sync3.NewAvatarChange(internal.CalculateAvatar(metadata))
|
||||
}
|
||||
if delta.InviteCountChanged {
|
||||
thisRoom.InvitedCount = roomUpdate.GlobalRoomMetadata().InviteCount
|
||||
thisRoom.InvitedCount = &roomUpdate.GlobalRoomMetadata().InviteCount
|
||||
}
|
||||
if delta.JoinCountChanged {
|
||||
thisRoom.JoinedCount = roomUpdate.GlobalRoomMetadata().JoinCount
|
||||
@ -275,8 +339,17 @@ func (s *connStateLive) processGlobalUpdates(ctx context.Context, builder *Rooms
|
||||
}
|
||||
}
|
||||
|
||||
metadata := rup.GlobalRoomMetadata().CopyHeroes()
|
||||
metadata.RemoveHero(s.userID)
|
||||
// TODO: if we change a room from being a DM to not being a DM, we should call
|
||||
// SetRoom and recalculate avatars. To do that we'd need to
|
||||
// - listen to m.direct global account data events
|
||||
// - compute the symmetric difference between old and new
|
||||
// - call SetRooms for each room in the difference.
|
||||
// I'm assuming this happens so rarely that we can ignore this for now. PRs
|
||||
// welcome if you a strong opinion to the contrary.
|
||||
delta = s.lists.SetRoom(sync3.RoomConnMetadata{
|
||||
RoomMetadata: *rup.GlobalRoomMetadata(),
|
||||
RoomMetadata: *metadata,
|
||||
UserRoomData: *rup.UserRoomMetadata(),
|
||||
LastInterestedEventTimestamps: bumpTimestampInList,
|
||||
})
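
The CopyHeroes call above is load-bearing: the global room metadata is shared between connections, and a plain struct copy still shares the backing array of any slice field, so calling RemoveHero on an uncopied value would mutate the cache for every user. A tiny demonstration of the hazard with a hypothetical struct:

```go
package main

import "fmt"

// metadata is a hypothetical struct standing in for the real RoomMetadata.
type metadata struct {
	Heroes []string
}

func main() {
	shared := metadata{Heroes: []string{"@alice:x", "@bob:x"}}

	copied := shared              // struct copy, but Heroes still shares its backing array
	copied.Heroes[0] = "@me:x"    // also visible through the "shared" value
	fmt.Println(shared.Heroes[0]) // "@me:x": this is why an explicit hero copy is needed

	safe := shared
	safe.Heroes = append([]string(nil), shared.Heroes...) // CopyHeroes-style deep copy
	safe.Heroes[0] = "@other:x"
	fmt.Println(shared.Heroes[0]) // still "@me:x"; the cache is untouched this time
}
```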
|
||||
|
@ -107,7 +107,7 @@ func TestConnStateInitial(t *testing.T) {
|
||||
}
|
||||
return result
|
||||
}
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
|
||||
if userID != cs.UserID() {
|
||||
t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID)
|
||||
}
|
||||
@ -118,7 +118,7 @@ func TestConnStateInitial(t *testing.T) {
|
||||
{0, 9},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -168,7 +168,7 @@ func TestConnStateInitial(t *testing.T) {
|
||||
{0, 9},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -206,7 +206,7 @@ func TestConnStateInitial(t *testing.T) {
|
||||
{0, 9},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -272,7 +272,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
||||
userCache.LazyRoomDataOverride = mockLazyRoomOverride
|
||||
dispatcher.Register(context.Background(), userCache.UserID, userCache)
|
||||
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
|
||||
|
||||
// request first page
|
||||
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
|
||||
@ -282,7 +282,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
||||
{0, 2},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -308,7 +308,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
||||
{0, 2}, {4, 6},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -343,7 +343,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
||||
{0, 2}, {4, 6},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -383,7 +383,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
|
||||
{0, 2}, {4, 6},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -451,7 +451,7 @@ func TestBumpToOutsideRange(t *testing.T) {
|
||||
userCache.LazyRoomDataOverride = mockLazyRoomOverride
|
||||
dispatcher.Register(context.Background(), userCache.UserID, userCache)
|
||||
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
|
||||
// Ask for A,B
|
||||
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
@ -460,7 +460,7 @@ func TestBumpToOutsideRange(t *testing.T) {
|
||||
{0, 1},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -495,7 +495,7 @@ func TestBumpToOutsideRange(t *testing.T) {
|
||||
{0, 1},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -562,7 +562,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
|
||||
}
|
||||
dispatcher.Register(context.Background(), userCache.UserID, userCache)
|
||||
dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000)
|
||||
cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000, 0)
|
||||
// subscribe to room D
|
||||
res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
@ -576,7 +576,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
|
||||
{0, 1},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -630,7 +630,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
|
||||
{0, 1},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
@ -664,7 +664,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
|
||||
{0, 1},
|
||||
}),
|
||||
}},
|
||||
}, false)
|
||||
}, false, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("OnIncomingRequest returned error : %s", err)
|
||||
}
|
||||
|
@ -1,9 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/pubsub"
|
||||
)
|
||||
|
||||
@ -28,23 +32,38 @@ type EnsurePoller struct {
|
||||
// pendingPolls tracks the status of pollers that we are waiting to start.
|
||||
pendingPolls map[sync2.PollerID]pendingInfo
|
||||
notifier pubsub.Notifier
|
||||
// the total number of outstanding ensurepolling requests.
|
||||
numPendingEnsurePolling prometheus.Gauge
|
||||
}
|
||||
|
||||
func NewEnsurePoller(notifier pubsub.Notifier) *EnsurePoller {
|
||||
return &EnsurePoller{
|
||||
func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePoller {
|
||||
p := &EnsurePoller{
|
||||
chanName: pubsub.ChanV3,
|
||||
mu: &sync.Mutex{},
|
||||
pendingPolls: make(map[sync2.PollerID]pendingInfo),
|
||||
notifier: notifier,
|
||||
}
|
||||
if enablePrometheus {
|
||||
p.numPendingEnsurePolling = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "num_devices_pending_ensure_polling",
|
||||
Help: "Number of devices blocked on EnsurePolling returning.",
|
||||
})
|
||||
prometheus.MustRegister(p.numPendingEnsurePolling)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
|
||||
// the caller's responsibility to call OnInitialSyncComplete when new events arrive.
|
||||
func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
|
||||
func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, tokenHash string) {
|
||||
ctx, region := internal.StartSpan(ctx, "EnsurePolling")
|
||||
defer region.End()
|
||||
p.mu.Lock()
|
||||
// do we need to wait?
|
||||
if p.pendingPolls[pid].done {
|
||||
internal.Logf(ctx, "EnsurePolling", "user %s device %s already done", pid.UserID, pid.DeviceID)
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
@ -56,7 +75,10 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
|
||||
// TODO: several times there have been problems getting the response back from the poller
|
||||
// we should time out here after 100s and return an error or something to kick conns into
|
||||
// trying again
|
||||
internal.Logf(ctx, "EnsurePolling", "user %s device %s channel exits, listening for channel close", pid.UserID, pid.DeviceID)
|
||||
_, r2 := internal.StartSpan(ctx, "waitForExistingChannelClose")
|
||||
<-ch
|
||||
r2.End()
|
||||
return
|
||||
}
|
||||
// Make a channel to wait until we have done an initial sync
|
||||
@ -65,6 +87,7 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
|
||||
done: false,
|
||||
ch: ch,
|
||||
}
|
||||
p.calculateNumOutstanding() // increment total
|
||||
p.mu.Unlock()
|
||||
// ask the pollers to poll for this device
|
||||
p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{
|
||||
@ -74,10 +97,15 @@ func (p *EnsurePoller) EnsurePolling(pid sync2.PollerID, tokenHash string) {
|
||||
})
|
||||
// if by some miracle the notify AND sync completes before we receive on ch then this is
|
||||
// still fine as recv on a closed channel will return immediately.
|
||||
internal.Logf(ctx, "EnsurePolling", "user %s device %s just made channel, listening for channel close", pid.UserID, pid.DeviceID)
|
||||
_, r2 := internal.StartSpan(ctx, "waitForNewChannelClose")
|
||||
<-ch
|
||||
r2.End()
|
||||
}
|
||||
|
||||
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
|
||||
log := logger.With().Str("user", payload.UserID).Str("device", payload.DeviceID).Logger()
|
||||
log.Trace().Msg("OnInitialSyncComplete: got payload")
|
||||
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@ -86,12 +114,14 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
|
||||
if !ok {
|
||||
// This can happen when the v2 poller spontaneously starts polling even without us asking it to
|
||||
// e.g from the database
|
||||
log.Trace().Msg("OnInitialSyncComplete: we weren't waiting for this")
|
||||
p.pendingPolls[pid] = pendingInfo{
|
||||
done: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
if pending.done {
|
||||
log.Trace().Msg("OnInitialSyncComplete: already done")
|
||||
// nothing to do, we just got OnInitialSyncComplete called twice
|
||||
return
|
||||
}
|
||||
@ -101,6 +131,8 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple
|
||||
pending.done = true
|
||||
pending.ch = nil
|
||||
p.pendingPolls[pid] = pending
|
||||
p.calculateNumOutstanding() // decrement total
|
||||
log.Trace().Msg("OnInitialSyncComplete: closing channel")
|
||||
close(ch)
|
||||
}
|
||||
|
||||
@ -121,4 +153,21 @@ func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) {
|
||||
|
||||
func (p *EnsurePoller) Teardown() {
|
||||
p.notifier.Close()
|
||||
if p.numPendingEnsurePolling != nil {
|
||||
prometheus.Unregister(p.numPendingEnsurePolling)
|
||||
}
|
||||
}
|
||||
|
||||
// must hold p.mu
|
||||
func (p *EnsurePoller) calculateNumOutstanding() {
|
||||
if p.numPendingEnsurePolling == nil {
|
||||
return
|
||||
}
|
||||
var total int
|
||||
for _, pi := range p.pendingPolls {
|
||||
if !pi.done {
|
||||
total++
|
||||
}
|
||||
}
|
||||
p.numPendingEnsurePolling.Set(float64(total))
|
||||
}
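
The TODO in EnsurePolling above notes that the wait on the channel can block indefinitely if the poller never reports back. One way to bound that wait, sketched with an assumed timeout value (this is not part of the commit, only an illustration of the suggested fix):

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// waitWithTimeout waits for the "initial sync complete" channel to close, but
// gives up after timeout so a stuck poller cannot wedge the request forever.
func waitWithTimeout(ch <-chan struct{}, timeout time.Duration) error {
	select {
	case <-ch:
		return nil
	case <-time.After(timeout):
		return errors.New("timed out waiting for initial sync to complete")
	}
}

func main() {
	ch := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(ch) // simulates OnInitialSyncComplete closing the channel
	}()
	fmt.Println(waitWithTimeout(ch, time.Second)) // <nil>
}
```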
|
||||
|
@ -1,6 +1,5 @@
|
||||
package handler
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
@ -59,25 +58,29 @@ type SyncLiveHandler struct {
|
||||
|
||||
GlobalCache *caches.GlobalCache
|
||||
maxPendingEventUpdates int
|
||||
maxTransactionIDDelay time.Duration
|
||||
|
||||
numConns prometheus.Gauge
|
||||
histVec *prometheus.HistogramVec
|
||||
setupHistVec *prometheus.HistogramVec
|
||||
histVec *prometheus.HistogramVec
|
||||
slowReqs prometheus.Counter
|
||||
}
|
||||
|
||||
func NewSync3Handler(
|
||||
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
|
||||
store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string,
|
||||
pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
|
||||
maxTransactionIDDelay time.Duration,
|
||||
) (*SyncLiveHandler, error) {
|
||||
logger.Info().Msg("creating handler")
|
||||
sh := &SyncLiveHandler{
|
||||
V2: v2Client,
|
||||
Storage: store,
|
||||
V2Store: storev2,
|
||||
ConnMap: sync3.NewConnMap(),
|
||||
ConnMap: sync3.NewConnMap(enablePrometheus),
|
||||
userCaches: &sync.Map{},
|
||||
Dispatcher: sync3.NewDispatcher(),
|
||||
GlobalCache: caches.NewGlobalCache(store),
|
||||
maxPendingEventUpdates: maxPendingEventUpdates,
|
||||
maxTransactionIDDelay: maxTransactionIDDelay,
|
||||
}
|
||||
sh.Extensions = &extensions.Handler{
|
||||
Store: store,
|
||||
@ -91,7 +94,7 @@ func NewSync3Handler(
|
||||
}
|
||||
|
||||
// set up pubsub mechanism to start from this point
|
||||
sh.EnsurePoller = NewEnsurePoller(pub)
|
||||
sh.EnsurePoller = NewEnsurePoller(pub, enablePrometheus)
|
||||
sh.V2Sub = pubsub.NewV2Sub(sub, sh)
|
||||
|
||||
return sh, nil
|
||||
@ -127,37 +130,41 @@ func (h *SyncLiveHandler) Teardown() {
|
||||
h.V2Sub.Teardown()
|
||||
h.EnsurePoller.Teardown()
|
||||
h.ConnMap.Teardown()
|
||||
if h.numConns != nil {
|
||||
prometheus.Unregister(h.numConns)
|
||||
if h.setupHistVec != nil {
|
||||
prometheus.Unregister(h.setupHistVec)
|
||||
}
|
||||
if h.histVec != nil {
|
||||
prometheus.Unregister(h.histVec)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) updateMetrics() {
|
||||
if h.numConns == nil {
|
||||
return
|
||||
if h.slowReqs != nil {
|
||||
prometheus.Unregister(h.slowReqs)
|
||||
}
|
||||
h.numConns.Set(float64(h.ConnMap.Len()))
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) addPrometheusMetrics() {
|
||||
h.numConns = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
h.setupHistVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "num_active_conns",
|
||||
Help: "Number of active sliding sync connections.",
|
||||
})
|
||||
Name: "setup_duration_secs",
|
||||
Help: "Time taken in seconds after receiving a request before we start calculating a sliding sync response.",
|
||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"initial"})
|
||||
h.histVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "process_duration_secs",
|
||||
Help: "Time taken in seconds for the sliding sync response to calculated, excludes long polling",
|
||||
Help: "Time taken in seconds for the sliding sync response to be calculated, excludes long polling",
|
||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"initial"})
|
||||
prometheus.MustRegister(h.numConns)
|
||||
h.slowReqs = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "sliding_sync",
|
||||
Subsystem: "api",
|
||||
Name: "slow_requests",
|
||||
Help: "Counter of slow (>=50s) requests, initial or otherwise.",
|
||||
})
|
||||
prometheus.MustRegister(h.setupHistVec)
|
||||
prometheus.MustRegister(h.histVec)
|
||||
prometheus.MustRegister(h.slowReqs)
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@ -174,9 +181,13 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
// artificially wait a bit before sending back the error
|
||||
// this guards against tightlooping when the client hammers the server with invalid requests
|
||||
time.Sleep(time.Second)
|
||||
if herr.ErrCode != "M_UNKNOWN_POS" {
|
||||
// artificially wait a bit before sending back the error
|
||||
// this guards against tightlooping when the client hammers the server with invalid requests,
|
||||
// but not for M_UNKNOWN_POS which we expect to send back after expiring a client's connection.
|
||||
// We want to recover rapidly in that scenario, hence not sleeping.
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
w.WriteHeader(herr.StatusCode)
|
||||
w.Write(herr.JSON())
|
||||
}
|
||||
@ -184,6 +195,16 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// Entry point for sync v3
|
||||
func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
dur := time.Since(start)
|
||||
if dur > 50*time.Second {
|
||||
if h.slowReqs != nil {
|
||||
h.slowReqs.Add(1.0)
|
||||
}
|
||||
internal.DecorateLogger(req.Context(), log.Warn()).Dur("duration", dur).Msg("slow request")
|
||||
}
|
||||
}()
|
||||
var requestBody sync3.Request
|
||||
if req.ContentLength != 0 {
|
||||
defer req.Body.Close()
|
||||
@ -233,7 +254,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
return herr
|
||||
}
|
||||
requestBody.SetPos(cpos)
|
||||
internal.SetRequestContextUserID(req.Context(), conn.UserID)
|
||||
log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger()
|
||||
|
||||
var timeout int
|
||||
@ -250,7 +270,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
requestBody.SetTimeoutMSecs(timeout)
|
||||
log.Trace().Int("timeout", timeout).Msg("recv")
|
||||
|
||||
resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody)
|
||||
resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody, start)
|
||||
if herr != nil {
|
||||
logErrorOrWarning("failed to OnIncomingRequest", herr)
|
||||
return herr
|
||||
@ -271,7 +291,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
}
|
||||
internal.SetRequestContextResponseInfo(
|
||||
req.Context(), cpos, resp.PosInt(), len(resp.Rooms), requestBody.TxnID, numToDeviceEvents, numGlobalAccountData,
|
||||
numChangedDevices, numLeftDevices,
|
||||
numChangedDevices, numLeftDevices, requestBody.ConnID, len(requestBody.Lists), len(requestBody.RoomSubscriptions), len(requestBody.UnsubscribeRooms),
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@ -300,8 +320,9 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
|
||||
// setupConnection associates this request with an existing connection or makes a new connection.
|
||||
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
|
||||
// When this function returns, the connection is alive and active.
|
||||
|
||||
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
|
||||
taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
|
||||
defer task.End()
|
||||
var conn *sync3.Conn
|
||||
// Extract an access token
|
||||
accessToken, err := internal.ExtractAccessToken(req)
|
||||
@ -333,6 +354,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
}
|
||||
}
|
||||
log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
|
||||
internal.SetRequestContextUserID(req.Context(), token.UserID, token.DeviceID)
|
||||
internal.Logf(taskCtx, "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID)
|
||||
|
||||
// Record the fact that we've recieved a request from this token
|
||||
err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
|
||||
@ -358,9 +381,9 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
return nil, internal.ExpiredSessionError()
|
||||
}
|
||||
|
||||
log.Trace().Msg("checking poller exists and is running")
|
||||
pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
|
||||
h.EnsurePoller.EnsurePolling(pid, token.AccessTokenHash)
|
||||
log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
|
||||
h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
|
||||
log.Trace().Msg("poller exists and is running")
|
||||
// this may take a while so if the client has given up (e.g timed out) by this point, just stop.
|
||||
// We'll be quicker next time as the poller will already exist.
|
||||
@ -382,7 +405,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
}
|
||||
|
||||
// once we have the conn, make sure our metrics are correct
|
||||
defer h.updateMetrics()
|
||||
defer h.ConnMap.UpdateMetrics()
|
||||
|
||||
// Now the v2 side of things are running, we can make a v3 live sync conn
|
||||
// NB: this isn't inherently racey (we did the check for an existing conn before EnsurePolling)
|
||||
@ -390,7 +413,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
|
||||
// to check for an existing connection though, as it's possible for the client to call /sync
|
||||
// twice for a new connection.
|
||||
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
|
||||
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates)
|
||||
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
|
||||
})
|
||||
if created {
|
||||
log.Info().Msg("created new connection")
|
||||
@ -408,6 +431,7 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
|
||||
return nil, &internal.HandlerError{
|
||||
StatusCode: 401,
|
||||
Err: fmt.Errorf("/whoami returned HTTP 401"),
|
||||
ErrCode: "M_UNKNOWN_TOKEN",
|
||||
}
|
||||
}
|
||||
log.Warn().Err(err).Msg("failed to get user ID from device ID")
|
||||
@ -420,14 +444,14 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
|
||||
var token *sync2.Token
|
||||
err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
|
||||
// Create a brand-new row for this token.
|
||||
token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
|
||||
token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now())
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure we have a device row for this token.
|
||||
err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
|
||||
err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
|
||||
return err
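
Both the token insert and the device insert now run inside one transaction via sqlutil.WithTransaction, so the rows are committed or rolled back together. For readers unfamiliar with that helper, a wrapper of this kind usually looks roughly like the sketch below; this is an assumption for illustration, the real implementation lives in the sqlutil package and may differ:

```go
package sketch

import (
	"fmt"

	"github.com/jmoiron/sqlx"
)

// withTransaction sketches the usual begin/rollback/commit wrapper: the
// callback gets a *sqlx.Tx, any error or panic rolls everything back, and a
// nil error commits.
func withTransaction(db *sqlx.DB, fn func(txn *sqlx.Tx) error) (err error) {
	txn, err := db.Beginx()
	if err != nil {
		return fmt.Errorf("begin: %w", err)
	}
	defer func() {
		if p := recover(); p != nil {
			_ = txn.Rollback()
			panic(p) // re-throw after cleaning up
		} else if err != nil {
			_ = txn.Rollback()
		} else {
			err = txn.Commit()
		}
	}()
	return fn(txn)
}
```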
|
||||
@ -484,6 +508,15 @@ func (h *SyncLiveHandler) userCache(userID string) (*caches.UserCache, error) {
|
||||
uc.OnAccountData(context.Background(), []state.AccountData{directEvent[0]})
|
||||
}
|
||||
|
||||
// select the ignored users account data event and set ignored user list
|
||||
ignoreEvent, err := h.Storage.AccountData(userID, sync2.AccountDataGlobalRoom, []string{"m.ignored_user_list"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load ignored user list for user %s: %w", userID, err)
|
||||
}
|
||||
if len(ignoreEvent) == 1 {
|
||||
uc.OnAccountData(context.Background(), []state.AccountData{ignoreEvent[0]})
|
||||
}
|
||||
|
||||
// select all room tag account data and set it
|
||||
tagEvents, err := h.Storage.RoomAccountDatasWithType(userID, "m.tag")
|
||||
if err != nil {
|
||||
@ -601,7 +634,11 @@ func (h *SyncLiveHandler) Accumulate(p *pubsub.V2Accumulate) {
|
||||
func (h *SyncLiveHandler) OnTransactionID(p *pubsub.V2TransactionID) {
|
||||
_, task := internal.StartTask(context.Background(), "TransactionID")
|
||||
defer task.End()
|
||||
// TODO implement me
|
||||
|
||||
// There is some event E for which we now have a transaction ID, or else now know
|
||||
// that we will never get a transaction ID. In either case, tell the sender's
|
||||
// connections to unblock that event in the transaction ID waiter.
|
||||
h.ConnMap.ClearUpdateQueues(p.UserID, p.RoomID, p.NID)
|
||||
}
|
||||
|
||||
// Called from the v2 poller, implements V2DataReceiver
|
||||
@ -632,9 +669,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) {
|
||||
func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) {
|
||||
ctx, task := internal.StartTask(context.Background(), "OnDeviceData")
|
||||
defer task.End()
|
||||
conns := h.ConnMap.Conns(p.UserID, p.DeviceID)
|
||||
for _, conn := range conns {
|
||||
conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
|
||||
internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs)))
|
||||
for userID, deviceIDs := range p.UserIDToDeviceIDs {
|
||||
for _, deviceID := range deviceIDs {
|
||||
conns := h.ConnMap.Conns(userID, deviceID)
|
||||
for _, conn := range conns {
|
||||
conn.OnUpdate(ctx, caches.DeviceDataUpdate{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -671,7 +713,7 @@ func (h *SyncLiveHandler) OnLeftRoom(p *pubsub.V2LeaveRoom) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID)
|
||||
userCache.(*caches.UserCache).OnLeftRoom(ctx, p.RoomID, p.LeaveEvent)
|
||||
}
|
||||
|
||||
func (h *SyncLiveHandler) OnReceipt(p *pubsub.V2Receipt) {
|
||||
|
94
sync3/handler/txn_id_waiter.go
Normal file
@ -0,0 +1,94 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TxnIDWaiter struct {
|
||||
userID string
|
||||
publish func(delayed bool, update caches.Update)
|
||||
// mu guards the queues map.
|
||||
mu sync.Mutex
|
||||
queues map[string][]*caches.RoomEventUpdate
|
||||
maxDelay time.Duration
|
||||
}
|
||||
|
||||
func NewTxnIDWaiter(userID string, maxDelay time.Duration, publish func(bool, caches.Update)) *TxnIDWaiter {
|
||||
return &TxnIDWaiter{
|
||||
userID: userID,
|
||||
publish: publish,
|
||||
mu: sync.Mutex{},
|
||||
queues: make(map[string][]*caches.RoomEventUpdate),
|
||||
maxDelay: maxDelay,
|
||||
// TODO: metric that tracks how long events were queued for.
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TxnIDWaiter) Ingest(up caches.Update) {
|
||||
if t.maxDelay <= 0 {
|
||||
t.publish(false, up)
|
||||
return
|
||||
}
|
||||
|
||||
eventUpdate, isEventUpdate := up.(*caches.RoomEventUpdate)
|
||||
if !isEventUpdate {
|
||||
t.publish(false, up)
|
||||
return
|
||||
}
|
||||
|
||||
ed := eventUpdate.EventData
|
||||
|
||||
// An event should be queued if
|
||||
// - it's a non-state event that our user sent, lacking a txn_id; OR
|
||||
// - the room already has queued events.
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
_, roomQueued := t.queues[ed.RoomID]
|
||||
missingTxnID := ed.StateKey == nil && ed.Sender == t.userID && ed.TransactionID == ""
|
||||
if !(missingTxnID || roomQueued) {
|
||||
t.publish(false, up)
|
||||
return
|
||||
}
|
||||
|
||||
// We've decided to queue the event.
|
||||
queue, exists := t.queues[ed.RoomID]
|
||||
if !exists {
|
||||
queue = make([]*caches.RoomEventUpdate, 0, 10)
|
||||
}
|
||||
// TODO: bound the queue size?
|
||||
t.queues[ed.RoomID] = append(queue, eventUpdate)
|
||||
|
||||
time.AfterFunc(t.maxDelay, func() { t.PublishUpToNID(ed.RoomID, ed.NID) })
|
||||
}
|
||||
|
||||
func (t *TxnIDWaiter) PublishUpToNID(roomID string, publishNID int64) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
queue, exists := t.queues[roomID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
var i int
|
||||
for i = 0; i < len(queue); i++ {
|
||||
// Scan forwards through the queue until we find an event with nid > publishNID.
|
||||
if queue[i].EventData.NID > publishNID {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Now queue[:i] has events with nid <= publishNID, and queue[i:] has nids > publishNID.
|
||||
// strip off the first i events from the slice and publish them.
|
||||
toPublish, queue := queue[:i], queue[i:]
|
||||
if len(queue) == 0 {
|
||||
delete(t.queues, roomID)
|
||||
} else {
|
||||
t.queues[roomID] = queue
|
||||
}
|
||||
|
||||
for _, eventUpdate := range toPublish {
|
||||
t.publish(true, eventUpdate)
|
||||
}
|
||||
}
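
Tying the new type back to the ConnState changes earlier in this diff: OnUpdate feeds every cache update through Ingest, and PublishEventsUpTo (driven by OnTransactionID) flushes queued events by NID. A short usage sketch of the wiring; the callback body is illustrative, in the real code it ends up in connStateLive.onUpdate, and the 100ms delay is an arbitrary example value:

```go
package handler

import (
	"time"

	"github.com/matrix-org/sliding-sync/sync3/caches"
)

// exampleWireUp is illustrative only (not part of the commit): every cache
// update goes through Ingest, and the publish callback receives it either
// immediately or once the per-room queue is flushed by PublishUpToNID or the
// maxDelay timer.
func exampleWireUp(userID string, forward func(up caches.Update)) *TxnIDWaiter {
	w := NewTxnIDWaiter(userID, 100*time.Millisecond, func(delayed bool, up caches.Update) {
		// delayed reports whether the update sat in a queue waiting for its
		// transaction ID before being released.
		forward(up)
	})
	return w
}
```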
|
390
sync3/handler/txn_id_waiter_test.go
Normal file
@ -0,0 +1,390 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
"github.com/tidwall/gjson"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type publishArg struct {
|
||||
delayed bool
|
||||
update caches.Update
|
||||
}
|
||||
|
||||
// Test that
|
||||
// - events are (reported as being) delayed when we expect them to be
|
||||
// - delayed events are automatically published after the maximum delay period
|
||||
func TestTxnIDWaiter_QueuingLogic(t *testing.T) {
|
||||
const alice = "alice"
|
||||
const bob = "bob"
|
||||
const room1 = "!theroom"
|
||||
const room2 = "!daszimmer"
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Ingest []caches.Update
|
||||
WaitForUpdate int
|
||||
ExpectDelayed bool
|
||||
}{
|
||||
{
|
||||
Name: "empty queue, non-event update",
|
||||
Ingest: []caches.Update{&caches.AccountDataUpdate{}},
|
||||
WaitForUpdate: 0,
|
||||
ExpectDelayed: false,
|
||||
},
|
||||
{
|
||||
Name: "empty queue, event update, another sender",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: bob,
|
||||
},
|
||||
}},
|
||||
WaitForUpdate: 0,
|
||||
ExpectDelayed: false,
|
||||
},
|
||||
{
|
||||
Name: "empty queue, event update, has txn_id",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "txntxntxn",
|
||||
},
|
||||
}},
|
||||
WaitForUpdate: 0,
|
||||
ExpectDelayed: false,
|
||||
},
|
||||
{
|
||||
Name: "empty queue, event update, no txn_id",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
},
|
||||
}},
|
||||
WaitForUpdate: 0,
|
||||
ExpectDelayed: true,
|
||||
},
|
||||
{
|
||||
Name: "nonempty queue, non-event update",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
},
|
||||
&caches.AccountDataUpdate{},
|
||||
},
|
||||
WaitForUpdate: 1,
|
||||
ExpectDelayed: false, // not a room event, no need to queue behind alice's event
|
||||
},
|
||||
{
|
||||
Name: "empty queue, join event for sender",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
EventType: "m.room.member",
|
||||
StateKey: ptr(alice),
|
||||
Content: gjson.Parse(`{"membership": "join"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
WaitForUpdate: 0,
|
||||
ExpectDelayed: false,
|
||||
},
|
||||
{
|
||||
Name: "nonempty queue, join event for sender",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
},
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 2,
|
||||
EventType: "m.room.member",
|
||||
StateKey: ptr(alice),
|
||||
Content: gjson.Parse(`{"membership": "join"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
WaitForUpdate: 1,
|
||||
ExpectDelayed: true,
|
||||
},
|
||||
|
||||
{
|
||||
Name: "nonempty queue, event update, different sender",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
},
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: bob,
|
||||
NID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
WaitForUpdate: 1,
|
||||
ExpectDelayed: true, // should be queued behind alice's event
|
||||
},
|
||||
{
|
||||
Name: "nonempty queue, event update, has txn_id",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
},
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
NID: 2,
|
||||
TransactionID: "I have a txn",
|
||||
},
|
||||
},
|
||||
},
|
||||
WaitForUpdate: 1,
|
||||
ExpectDelayed: true, // should still be queued behind alice's first event
|
||||
},
|
||||
{
|
||||
Name: "existence of queue only matters per-room",
|
||||
Ingest: []caches.Update{
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room1,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
},
|
||||
&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room2,
|
||||
Sender: alice,
|
||||
NID: 2,
|
||||
TransactionID: "I have a txn",
|
||||
},
|
||||
},
|
||||
},
|
||||
WaitForUpdate: 1,
|
||||
ExpectDelayed: false, // queue only tracks room1
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
updates := make(chan publishArg, 100)
|
||||
publish := func(delayed bool, update caches.Update) {
|
||||
updates <- publishArg{delayed, update}
|
||||
}
|
||||
|
||||
w := NewTxnIDWaiter(alice, time.Millisecond, publish)
|
||||
|
||||
for _, up := range tc.Ingest {
|
||||
w.Ingest(up)
|
||||
}
|
||||
|
||||
wantedUpdate := tc.Ingest[tc.WaitForUpdate]
|
||||
var got publishArg
|
||||
WaitForSelectedUpdate:
|
||||
for {
|
||||
select {
|
||||
case got = <-updates:
|
||||
t.Logf("Got update %v", got.update)
|
||||
if got.update == wantedUpdate {
|
||||
break WaitForSelectedUpdate
|
||||
}
|
||||
case <-time.After(5 * time.Millisecond):
|
||||
t.Fatalf("Did not see update %v published", wantedUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
if got.delayed != tc.ExpectDelayed {
|
||||
t.Errorf("Got delayed=%t want delayed=%t", got.delayed, tc.ExpectDelayed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that PublishUpToNID
|
||||
// - correctly pops off the start of the queue
|
||||
// - is idempotent
|
||||
// - deletes map entry if queue is empty (so that roomQueued is set correctly)
|
||||
func TestTxnIDWaiter_PublishUpToNID(t *testing.T) {
|
||||
const alice = "@alice:example.com"
|
||||
const room = "!unimportant"
|
||||
var published []publishArg
|
||||
publish := func(delayed bool, update caches.Update) {
|
||||
published = append(published, publishArg{delayed, update})
|
||||
}
|
||||
// Use an hour's expiry to effectively disable expiry.
|
||||
w := NewTxnIDWaiter(alice, time.Hour, publish)
|
||||
// Ingest 5 events, each of which would be queued by themselves.
|
||||
for i := int64(2); i <= 6; i++ {
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room,
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: i,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("Queue has nids [2,3,4,5,6]")
|
||||
t.Log("Publishing up to 1 should do nothing")
|
||||
w.PublishUpToNID(room, 1)
|
||||
assertNIDs(t, published, nil)
|
||||
|
||||
t.Log("Publishing up to 3 should yield nids [2, 3] in that order")
|
||||
w.PublishUpToNID(room, 3)
|
||||
assertNIDs(t, published, []int64{2, 3})
|
||||
assertDelayed(t, published[:2])
|
||||
|
||||
t.Log("Publishing up to 3 a second time should do nothing")
|
||||
w.PublishUpToNID(room, 3)
|
||||
assertNIDs(t, published, []int64{2, 3})
|
||||
|
||||
t.Log("Publishing up to 2 at this point should do nothing.")
|
||||
w.PublishUpToNID(room, 2)
|
||||
assertNIDs(t, published, []int64{2, 3})
|
||||
|
||||
t.Log("Publishing up to 6 should yield nids [4, 5, 6] in that order")
|
||||
w.PublishUpToNID(room, 6)
|
||||
assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
|
||||
assertDelayed(t, published[2:5])
|
||||
|
||||
t.Log("Publishing up to 6 a second time should do nothing")
|
||||
w.PublishUpToNID(room, 6)
|
||||
assertNIDs(t, published, []int64{2, 3, 4, 5, 6})
|
||||
|
||||
t.Log("Ingesting another event that doesn't need to be queueing should be published immediately")
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: room,
|
||||
Sender: "@notalice:example.com",
|
||||
TransactionID: "",
|
||||
NID: 7,
|
||||
},
|
||||
})
|
||||
assertNIDs(t, published, []int64{2, 3, 4, 5, 6, 7})
|
||||
if published[len(published)-1].delayed {
|
||||
t.Errorf("Final event was delayed, but should have been published immediately")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that PublishUpToNID only publishes in the given room
|
||||
func TestTxnIDWaiter_PublishUpToNID_MultipleRooms(t *testing.T) {
|
||||
const alice = "@alice:example.com"
|
||||
var published []publishArg
|
||||
publish := func(delayed bool, update caches.Update) {
|
||||
published = append(published, publishArg{delayed, update})
|
||||
}
|
||||
// Use an hour's expiry to effectively disable expiry.
|
||||
w := NewTxnIDWaiter(alice, time.Hour, publish)
|
||||
// Ingest four queueable events across two rooms.
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: "!room1",
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 1,
|
||||
},
|
||||
})
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: "!room2",
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 2,
|
||||
},
|
||||
})
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: "!room2",
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 3,
|
||||
},
|
||||
})
|
||||
w.Ingest(&caches.RoomEventUpdate{
|
||||
EventData: &caches.EventData{
|
||||
RoomID: "!room1",
|
||||
Sender: alice,
|
||||
TransactionID: "",
|
||||
NID: 4,
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Queues are [1, 4] and [2, 3]")
|
||||
t.Log("Publish up to NID 4 in room 1 should yield nids [1, 4]")
|
||||
w.PublishUpToNID("!room1", 4)
|
||||
assertNIDs(t, published, []int64{1, 4})
|
||||
assertDelayed(t, published)
|
||||
|
||||
t.Log("Queues are [1, 4] and [2, 3]")
|
||||
t.Log("Publish up to NID 3 in room 2 should yield nids [2, 3]")
|
||||
w.PublishUpToNID("!room2", 3)
|
||||
assertNIDs(t, published, []int64{1, 4, 2, 3})
|
||||
assertDelayed(t, published)
|
||||
}
|
||||
|
||||
func assertDelayed(t *testing.T, published []publishArg) {
|
||||
t.Helper()
|
||||
for _, p := range published {
|
||||
if !p.delayed {
|
||||
t.Errorf("published arg with NID %d was not delayed, but we expected it to be", p.update.(*caches.RoomEventUpdate).EventData.NID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertNIDs(t *testing.T, published []publishArg, expectedNIDs []int64) {
|
||||
t.Helper()
|
||||
if len(published) != len(expectedNIDs) {
|
||||
t.Errorf("Got %d nids, but expected %d", len(published), len(expectedNIDs))
|
||||
}
|
||||
for i := range published {
|
||||
rup, ok := published[i].update.(*caches.RoomEventUpdate)
|
||||
if !ok {
|
||||
t.Errorf("Update %d (%v) was not a RoomEventUpdate", i, published[i].update)
|
||||
}
|
||||
if rup.EventData.NID != expectedNIDs[i] {
|
||||
t.Errorf("Update %d (%v) got nid %d, expected %d", i, *rup, rup.EventData.NID, expectedNIDs[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
@ -33,6 +33,7 @@ type RoomListDelta struct {
|
||||
|
||||
type RoomDelta struct {
|
||||
RoomNameChanged bool
|
||||
RoomAvatarChanged bool
|
||||
JoinCountChanged bool
|
||||
InviteCountChanged bool
|
||||
NotificationCountChanged bool
|
||||
@ -73,8 +74,20 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
|
||||
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
|
||||
)
|
||||
} else {
|
||||
// XXX: during TestConnectionTimeoutNotReset there is some situation where
|
||||
// r.CanonicalisedName is the empty string. Looking at the SetRoom
|
||||
// call in connstate_live.go, this is because the UserRoomMetadata on
|
||||
// the RoomUpdate has an empty CanonicalisedName. Either
|
||||
// a) that is expected, in which case we should _always_ write to
|
||||
// r.CanonicalisedName here; or
|
||||
// b) that is not expected, in which case... erm, I don't know what
|
||||
// to conclude.
|
||||
r.CanonicalisedName = existing.CanonicalisedName
|
||||
}
|
||||
delta.RoomAvatarChanged = !existing.SameRoomAvatar(&r.RoomMetadata)
|
||||
if delta.RoomAvatarChanged {
|
||||
r.ResolvedAvatarURL = internal.CalculateAvatar(&r.RoomMetadata)
|
||||
}
|
||||
|
||||
// Interpret the timestamp map on r as the changes we should apply atop the
|
||||
// existing timestamps.
|
||||
@ -99,6 +112,7 @@ func (s *InternalRequestLists) SetRoom(r RoomConnMetadata) (delta RoomDelta) {
|
||||
r.CanonicalisedName = strings.ToLower(
|
||||
strings.Trim(internal.CalculateRoomName(&r.RoomMetadata, 5), "#!():_@"),
|
||||
)
|
||||
r.ResolvedAvatarURL = internal.CalculateAvatar(&r.RoomMetadata)
|
||||
// We'll automatically use the LastInterestedEventTimestamps provided by the
|
||||
// caller, so that recency sorts work.
|
||||
}
|
||||
|
@ -21,9 +21,8 @@ type Response struct {
|
||||
Rooms map[string]Room `json:"rooms"`
|
||||
Extensions extensions.Response `json:"extensions"`
|
||||
|
||||
Pos string `json:"pos"`
|
||||
TxnID string `json:"txn_id,omitempty"`
|
||||
Session string `json:"session_id,omitempty"`
|
||||
Pos string `json:"pos"`
|
||||
TxnID string `json:"txn_id,omitempty"`
|
||||
}
|
||||
|
||||
type ResponseList struct {
|
||||
@ -68,9 +67,8 @@ func (r *Response) UnmarshalJSON(b []byte) error {
|
||||
} `json:"lists"`
|
||||
Extensions extensions.Response `json:"extensions"`
|
||||
|
||||
Pos string `json:"pos"`
|
||||
TxnID string `json:"txn_id,omitempty"`
|
||||
Session string `json:"session_id,omitempty"`
|
||||
Pos string `json:"pos"`
|
||||
TxnID string `json:"txn_id,omitempty"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &temporary); err != nil {
|
||||
return err
|
||||
@ -78,7 +76,6 @@ func (r *Response) UnmarshalJSON(b []byte) error {
|
||||
r.Rooms = temporary.Rooms
|
||||
r.Pos = temporary.Pos
|
||||
r.TxnID = temporary.TxnID
|
||||
r.Session = temporary.Session
|
||||
r.Extensions = temporary.Extensions
|
||||
r.Lists = make(map[string]ResponseList, len(temporary.Lists))
|
||||
|
||||
|
@ -2,6 +2,7 @@ package sync3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/internal"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3/caches"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
|
||||
type Room struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
AvatarChange AvatarChange `json:"avatar,omitempty"`
|
||||
RequiredState []json.RawMessage `json:"required_state,omitempty"`
|
||||
Timeline []json.RawMessage `json:"timeline,omitempty"`
|
||||
InviteState []json.RawMessage `json:"invite_state,omitempty"`
|
||||
@ -17,9 +19,10 @@ type Room struct {
|
||||
Initial bool `json:"initial,omitempty"`
|
||||
IsDM bool `json:"is_dm,omitempty"`
|
||||
JoinedCount int `json:"joined_count,omitempty"`
|
||||
InvitedCount int `json:"invited_count,omitempty"`
|
||||
InvitedCount *int `json:"invited_count,omitempty"`
|
||||
PrevBatch string `json:"prev_batch,omitempty"`
|
||||
NumLive int `json:"num_live,omitempty"`
|
||||
Timestamp uint64 `json:"timestamp,omitempty"`
|
||||
}
|
||||
|
||||
// RoomConnMetadata represents a room as seen by one specific connection (hence one
|
||||
|
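A note on the InvitedCount change in the hunk above (int becomes *int): with omitempty, an int of 0 is dropped from the JSON entirely, so clients cannot tell "zero invited users" apart from "no count sent in this response"; a pointer keeps that distinction (nil means omitted), which is also why later hunks swap m.MatchInviteCount(0) for m.MatchNoInviteCount(). A minimal standalone illustration, not repo code:

package main

import (
	"encoding/json"
	"fmt"
)

type withInt struct {
	InvitedCount int `json:"invited_count,omitempty"`
}

type withPtr struct {
	InvitedCount *int `json:"invited_count,omitempty"`
}

func main() {
	zero := 0
	a, _ := json.Marshal(withInt{InvitedCount: 0})     // {} -- the zero count is lost
	b, _ := json.Marshal(withPtr{InvitedCount: nil})   // {} -- genuinely omitted
	c, _ := json.Marshal(withPtr{InvitedCount: &zero}) // {"invited_count":0}
	fmt.Println(string(a), string(b), string(c))
}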
74
sync3/room_test.go
Normal file
74
sync3/room_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package sync3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/tidwall/gjson"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAvatarChangeMarshalling(t *testing.T) {
|
||||
var url = "mxc://..."
|
||||
testCases := []struct {
|
||||
Name string
|
||||
AvatarChange AvatarChange
|
||||
Check func(avatar gjson.Result) error
|
||||
}{
|
||||
{
|
||||
Name: "Avatar exists",
|
||||
AvatarChange: NewAvatarChange(url),
|
||||
Check: func(avatar gjson.Result) error {
|
||||
if !(avatar.Exists() && avatar.Type == gjson.String && avatar.Str == url) {
|
||||
return fmt.Errorf("unexpected marshalled avatar: got %#v want %s", avatar, url)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Avatar doesn't exist",
|
||||
AvatarChange: DeletedAvatar,
|
||||
Check: func(avatar gjson.Result) error {
|
||||
if !(avatar.Exists() && avatar.Type == gjson.Null) {
|
||||
return fmt.Errorf("unexpected marshalled Avatar: got %#v want null", avatar)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Avatar unchanged",
|
||||
AvatarChange: UnchangedAvatar,
|
||||
Check: func(avatar gjson.Result) error {
|
||||
if avatar.Exists() {
|
||||
return fmt.Errorf("unexpected marshalled Avatar: got %#v want omitted", avatar)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
room := Room{AvatarChange: tc.AvatarChange}
|
||||
marshalled, err := json.Marshal(room)
|
||||
t.Logf("Marshalled to %s", string(marshalled))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
avatar := gjson.GetBytes(marshalled, "avatar")
|
||||
if err = tc.Check(avatar); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var unmarshalled Room
|
||||
err = json.Unmarshal(marshalled, &unmarshalled)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Unmarshalled to %#v", unmarshalled.AvatarChange)
|
||||
if !reflect.DeepEqual(unmarshalled, room) {
|
||||
t.Fatalf("Unmarshalled struct is different from original")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
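The test above relies on a tri-state JSON field: omitted means unchanged, JSON null means the avatar was removed, and a string means a new URL. Below is a standalone sketch of one way to express that with encoding/json, using assumed names; it is an illustration, not the repo's actual AvatarChange implementation.

package main

import (
	"encoding/json"
	"fmt"
)

// avatarValue marshals to null when deleted, otherwise to the URL string.
type avatarValue struct {
	deleted bool
	url     string
}

func (a avatarValue) MarshalJSON() ([]byte, error) {
	if a.deleted {
		return []byte("null"), nil
	}
	return json.Marshal(a.url)
}

type room struct {
	// A nil pointer is omitted entirely, which encodes "unchanged".
	Avatar *avatarValue `json:"avatar,omitempty"`
}

func main() {
	unchanged, _ := json.Marshal(room{})
	deleted, _ := json.Marshal(room{Avatar: &avatarValue{deleted: true}})
	set, _ := json.Marshal(room{Avatar: &avatarValue{url: "mxc://example.org/abc"}})
	fmt.Println(string(unchanged)) // {}
	fmt.Println(string(deleted))   // {"avatar":null}
	fmt.Println(string(set))       // {"avatar":"mxc://example.org/abc"}
}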
@ -2,11 +2,13 @@ package syncv3_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/sync3/extensions"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAccountDataRespectsExtensionScope(t *testing.T) {
|
||||
@ -99,14 +101,14 @@ func TestAccountDataRespectsExtensionScope(t *testing.T) {
|
||||
alice,
|
||||
room1,
|
||||
"com.example.room",
|
||||
map[string]interface{}{"room": 1, "version": 1},
|
||||
map[string]interface{}{"room": 1, "version": 11},
|
||||
)
|
||||
room2AccountDataEvent := putRoomAccountData(
|
||||
t,
|
||||
alice,
|
||||
room2,
|
||||
"com.example.room",
|
||||
map[string]interface{}{"room": 2, "version": 2},
|
||||
map[string]interface{}{"room": 2, "version": 22},
|
||||
)
|
||||
|
||||
t.Log("Alice syncs until she sees the account data for room 2. She shouldn't see account data for room 1")
|
||||
@ -122,7 +124,90 @@ func TestAccountDataRespectsExtensionScope(t *testing.T) {
|
||||
return m.MatchAccountData(nil, map[string][]json.RawMessage{room2: {room2AccountDataEvent}})(response)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Regression test for https://github.com/matrix-org/sliding-sync/issues/189
|
||||
func TestAccountDataDoesntDupe(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
alice2 := *alice
|
||||
alice2.Login(t, "password", "device2")
|
||||
|
||||
// send some initial account data
|
||||
putGlobalAccountData(t, alice, "initial", map[string]interface{}{"foo": "bar"})
|
||||
|
||||
// no devices are polling.
|
||||
// syncing with both devices => only shows 1 copy of this event per connection
|
||||
for _, client := range []*CSAPI{alice, &alice2} {
|
||||
res := client.SlidingSync(t, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Core: extensions.Core{
|
||||
Enabled: &boolTrue,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
|
||||
{
|
||||
Type: "m.push_rules",
|
||||
},
|
||||
{
|
||||
Type: "initial",
|
||||
Content: map[string]interface{}{"foo": "bar"},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// now both devices are polling, we're going to do the same thing to make sure we still only see 1 copy.
|
||||
putGlobalAccountData(t, alice, "initial2", map[string]interface{}{"foo2": "bar2"})
|
||||
time.Sleep(time.Second) // TODO: we need to make sure the pollers have seen this and explicitly don't want to use SlidingSyncUntil...
|
||||
var responses []*sync3.Response
|
||||
for _, client := range []*CSAPI{alice, &alice2} {
|
||||
res := client.SlidingSync(t, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Core: extensions.Core{
|
||||
Enabled: &boolTrue,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
|
||||
{
|
||||
Type: "m.push_rules",
|
||||
},
|
||||
{
|
||||
Type: "initial",
|
||||
Content: map[string]interface{}{"foo": "bar"},
|
||||
},
|
||||
{
|
||||
Type: "initial2",
|
||||
Content: map[string]interface{}{"foo2": "bar2"},
|
||||
},
|
||||
}))
|
||||
responses = append(responses, res) // we need the pos values later
|
||||
}
|
||||
|
||||
// now we're going to do an incremental sync with account data to make sure we don't see dupes either.
|
||||
putGlobalAccountData(t, alice, "incremental", map[string]interface{}{"foo3": "bar3"})
|
||||
time.Sleep(time.Second) // TODO: we need to make sure the pollers have seen this and explicitly don't want to use SlidingSyncUntil...
|
||||
for i, client := range []*CSAPI{alice, &alice2} {
|
||||
res := client.SlidingSync(t, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
Core: extensions.Core{
|
||||
Enabled: &boolTrue,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, WithPos(responses[i].Pos))
|
||||
m.MatchResponse(t, res, MatchGlobalAccountData([]Event{
|
||||
{
|
||||
Type: "incremental",
|
||||
Content: map[string]interface{}{"foo3": "bar3"},
|
||||
},
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// putAccountData is a wrapper around SetGlobalAccountData. It returns the account data
|
||||
|
@ -134,6 +134,7 @@ type CSAPI struct {
|
||||
Localpart string
|
||||
AccessToken string
|
||||
DeviceID string
|
||||
AvatarURL string
|
||||
BaseURL string
|
||||
Client *http.Client
|
||||
// how long are we willing to wait for MustSyncUntil.... calls
|
||||
@ -159,6 +160,16 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
|
||||
return GetJSONFieldStr(t, body, "content_uri")
|
||||
}
|
||||
|
||||
// SetAvatar sets the user's avatar URL. Use an empty string to remove your avatar.
|
||||
func (c *CSAPI) SetAvatar(t *testing.T, avatarURL string) {
|
||||
t.Helper()
|
||||
reqBody := map[string]interface{}{
|
||||
"avatar_url": avatarURL,
|
||||
}
|
||||
c.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "profile", c.UserID, "avatar_url"}, WithJSONBody(t, reqBody))
|
||||
c.AvatarURL = avatarURL
|
||||
}
|
||||
|
||||
// DownloadContent downloads media from the server, returning the raw bytes and the Content-Type. Fails the test on error.
|
||||
func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
|
||||
t.Helper()
|
||||
@ -678,16 +689,32 @@ func (c *CSAPI) SlidingSyncUntilMembership(t *testing.T, pos string, roomID stri
|
||||
})
|
||||
}
|
||||
|
||||
return c.SlidingSyncUntilEvent(t, pos, sync3.Request{
|
||||
return c.SlidingSyncUntil(t, pos, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
},
|
||||
}, roomID, Event{
|
||||
Type: "m.room.member",
|
||||
StateKey: &target.UserID,
|
||||
Content: content,
|
||||
}, func(r *sync3.Response) error {
|
||||
room, ok := r.Rooms[roomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing room %s", roomID)
|
||||
}
|
||||
for _, got := range room.Timeline {
|
||||
wantEvent := Event{
|
||||
Type: "m.room.member",
|
||||
StateKey: &target.UserID,
|
||||
}
|
||||
if err := eventsEqual([]Event{wantEvent}, []json.RawMessage{got}); err == nil {
|
||||
gotMembership := gjson.GetBytes(got, "content.membership")
|
||||
if gotMembership.Exists() && gotMembership.Type == gjson.String && gotMembership.Str == membership {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("found room %s but missing event", roomID)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,86 @@
|
||||
package syncv3_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
)
|
||||
|
||||
func TestInvalidTokenReturnsMUnknownTokenError(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
roomID := alice.CreateRoom(t, map[string]interface{}{})
|
||||
// normal sliding sync
|
||||
alice.SlidingSync(t, sync3.Request{
|
||||
ConnID: "A",
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
// invalidate the access token
|
||||
alice.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "logout"})
|
||||
// let the proxy realise the token is expired and tell downstream
|
||||
time.Sleep(time.Second)
|
||||
|
||||
var invalidResponses []*http.Response
|
||||
// using the same token now returns a 401 with M_UNKNOWN_TOKEN
|
||||
httpRes := alice.DoFunc(t, "POST", []string{"_matrix", "client", "unstable", "org.matrix.msc3575", "sync"}, WithQueries(url.Values{
|
||||
"timeout": []string{"500"},
|
||||
}), WithJSONBody(t, sync3.Request{
|
||||
ConnID: "A",
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
}))
|
||||
invalidResponses = append(invalidResponses, httpRes)
|
||||
// using a bogus access token returns a 401 with M_UNKNOWN_TOKEN
|
||||
alice.AccessToken = "flibble_wibble"
|
||||
httpRes = alice.DoFunc(t, "POST", []string{"_matrix", "client", "unstable", "org.matrix.msc3575", "sync"}, WithQueries(url.Values{
|
||||
"timeout": []string{"500"},
|
||||
}), WithJSONBody(t, sync3.Request{
|
||||
ConnID: "A",
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
}))
|
||||
invalidResponses = append(invalidResponses, httpRes)
|
||||
|
||||
for i, httpRes := range invalidResponses {
|
||||
body, err := io.ReadAll(httpRes.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("[%d] failed to read response body: %v", i, err)
|
||||
}
|
||||
|
||||
if httpRes.StatusCode != 401 {
|
||||
t.Errorf("[%d] got HTTP %v want 401: %v", i, httpRes.StatusCode, string(body))
|
||||
}
|
||||
var jsonError struct {
|
||||
Err string `json:"error"`
|
||||
ErrCode string `json:"errcode"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &jsonError); err != nil {
|
||||
t.Fatalf("[%d] failed to unmarshal error response into JSON: %v", i, string(body))
|
||||
}
|
||||
wantErrCode := "M_UNKNOWN_TOKEN"
|
||||
if jsonError.ErrCode != wantErrCode {
|
||||
t.Errorf("[%d] errcode: got %v want %v", i, jsonError.ErrCode, wantErrCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that you can have multiple connections with the same device, and they work independently.
|
||||
func TestMultipleConns(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
|
92
tests-e2e/ignored_user_test.go
Normal file
92
tests-e2e/ignored_user_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package syncv3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInvitesFromIgnoredUsersOmitted(t *testing.T) {
|
||||
alice := registerNamedUser(t, "alice")
|
||||
bob := registerNamedUser(t, "bob")
|
||||
nigel := registerNamedUser(t, "nigel")
|
||||
|
||||
t.Log("Nigel create two public rooms. Bob joins both.")
|
||||
room1 := nigel.CreateRoom(t, map[string]any{"preset": "public_chat", "name": "room 1"})
|
||||
room2 := nigel.CreateRoom(t, map[string]any{"preset": "public_chat", "name": "room 2"})
|
||||
bob.JoinRoom(t, room1, nil)
|
||||
bob.JoinRoom(t, room2, nil)
|
||||
|
||||
t.Log("Alice makes a room for dumping sentinel messages.")
|
||||
aliceRoom := alice.CreateRoom(t, map[string]any{"preset": "private_chat"})
|
||||
|
||||
t.Log("Alice ignores Nigel.")
|
||||
alice.SetGlobalAccountData(t, "m.ignored_user_list", map[string]any{
|
||||
"ignored_users": map[string]any{
|
||||
nigel.UserID: map[string]any{},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Nigel invites Alice to room 1.")
|
||||
nigel.InviteRoom(t, room1, alice.UserID)
|
||||
|
||||
t.Log("Bob sliding syncs until he sees that invite.")
|
||||
bob.SlidingSyncUntilMembership(t, "", room1, alice, "invite")
|
||||
|
||||
t.Log("Alice sends a sentinel message in her private room.")
|
||||
sentinel := alice.SendEventSynced(t, aliceRoom, Event{
|
||||
Type: "m.room.message",
|
||||
Content: map[string]any{
|
||||
"body": "Hello, world!",
|
||||
"msgtype": "m.text",
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice does an initial sliding sync.")
|
||||
res := alice.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 20,
|
||||
},
|
||||
Ranges: sync3.SliceRanges{{0, 20}},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice should see her sentinel, but not Nigel's invite.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(
|
||||
map[string][]m.RoomMatcher{
|
||||
aliceRoom: {MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}})},
|
||||
},
|
||||
))
|
||||
|
||||
t.Log("Nigel invites Alice to room 2.")
|
||||
nigel.InviteRoom(t, room2, alice.UserID)
|
||||
|
||||
t.Log("Bob sliding syncs until he sees that invite.")
|
||||
bob.SlidingSyncUntilMembership(t, "", room1, alice, "invite")
|
||||
|
||||
t.Log("Alice sends a sentinel message in her private room.")
|
||||
sentinel = alice.SendEventSynced(t, aliceRoom, Event{
|
||||
Type: "m.room.message",
|
||||
Content: map[string]any{
|
||||
"body": "Hello, world, again",
|
||||
"msgtype": "m.text",
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice does an incremental sliding sync.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, func(response *sync3.Response) error {
|
||||
if m.MatchRoomSubscription(room2)(response) == nil {
|
||||
err := fmt.Errorf("unexpectedly got subscription for room 2 (%s)", room2)
|
||||
t.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
gotSentinel := m.MatchRoomSubscription(aliceRoom, MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}}))
|
||||
return gotSentinel(response)
|
||||
})
|
||||
|
||||
}
|
@ -1297,3 +1297,391 @@ func TestRangeOutsideTotalRooms(t *testing.T) {
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// Nicked from Synapse's tests, see
|
||||
// https://github.com/matrix-org/synapse/blob/2cacd0849a02d43f88b6c15ee862398159ab827c/tests/test_utils/__init__.py#L154-L161
|
||||
// Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
|
||||
var smallPNG = []byte(
|
||||
"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82",
|
||||
)
|
||||
|
||||
func TestAvatarFieldInRoomResponse(t *testing.T) {
|
||||
alice := registerNamedUser(t, "alice")
|
||||
bob := registerNamedUser(t, "bob")
|
||||
chris := registerNamedUser(t, "chris")
|
||||
|
||||
avatarURLs := map[string]struct{}{}
|
||||
uploadAvatar := func(client *CSAPI, filename string) string {
|
||||
avatar := client.UploadContent(t, smallPNG, filename, "image/png")
|
||||
if _, exists := avatarURLs[avatar]; exists {
|
||||
t.Fatalf("New avatar %s has already been uploaded", avatar)
|
||||
}
|
||||
t.Logf("%s is uploaded as %s", filename, avatar)
|
||||
avatarURLs[avatar] = struct{}{}
|
||||
return avatar
|
||||
}
|
||||
|
||||
t.Log("Alice, Bob and Chris upload and set an avatar.")
|
||||
aliceAvatar := uploadAvatar(alice, "alice.png")
|
||||
bobAvatar := uploadAvatar(bob, "bob.png")
|
||||
chrisAvatar := uploadAvatar(chris, "chris.png")
|
||||
|
||||
alice.SetAvatar(t, aliceAvatar)
|
||||
bob.SetAvatar(t, bobAvatar)
|
||||
chris.SetAvatar(t, chrisAvatar)
|
||||
|
||||
t.Log("Alice makes a public room, a DM with herself, a DM with Bob, a DM with Chris, and a group-DM with Bob and Chris.")
|
||||
public := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
|
||||
// TODO: you can create a DM with yourself e.g. as below. It probably ought to have
|
||||
// your own face as an avatar.
|
||||
// dmAlice := alice.CreateRoom(t, map[string]interface{}{
|
||||
// "preset": "trusted_private_chat",
|
||||
// "is_direct": true,
|
||||
// })
|
||||
dmBob := alice.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "trusted_private_chat",
|
||||
"is_direct": true,
|
||||
"invite": []string{bob.UserID},
|
||||
})
|
||||
dmChris := alice.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "trusted_private_chat",
|
||||
"is_direct": true,
|
||||
"invite": []string{chris.UserID},
|
||||
})
|
||||
dmBobChris := alice.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "trusted_private_chat",
|
||||
"is_direct": true,
|
||||
"invite": []string{bob.UserID, chris.UserID},
|
||||
})
|
||||
|
||||
t.Logf("Rooms:\npublic=%s\ndmBob=%s\ndmChris=%s\ndmBobChris=%s", public, dmBob, dmChris, dmBobChris)
|
||||
t.Log("Bob accepts his invites. Chris accepts none.")
|
||||
bob.JoinRoom(t, dmBob, nil)
|
||||
bob.JoinRoom(t, dmBobChris, nil)
|
||||
|
||||
t.Log("Alice makes an initial sliding sync.")
|
||||
res := alice.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"rooms": {
|
||||
Ranges: sync3.SliceRanges{{0, 4}},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice should see each room in the sync response with an appropriate avatar")
|
||||
m.MatchResponse(
|
||||
t,
|
||||
res,
|
||||
m.MatchRoomSubscription(public, m.MatchRoomUnsetAvatar()),
|
||||
m.MatchRoomSubscription(dmBob, m.MatchRoomAvatar(bob.AvatarURL)),
|
||||
m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)),
|
||||
m.MatchRoomSubscription(dmBobChris, m.MatchRoomUnsetAvatar()),
|
||||
)
|
||||
|
||||
t.Run("Avatar not resent on message", func(t *testing.T) {
|
||||
t.Log("Bob sends a sentinel message.")
|
||||
sentinel := bob.SendEventSynced(t, dmBob, Event{
|
||||
Type: "m.room.message",
|
||||
Content: map[string]interface{}{
|
||||
"body": "Hello world",
|
||||
"msgtype": "m.text",
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees the sentinel. She should not see the DM avatar change.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, func(response *sync3.Response) error {
|
||||
matchNoAvatarChange := m.MatchRoomSubscription(dmBob, m.MatchRoomUnchangedAvatar())
|
||||
if err := matchNoAvatarChange(response); err != nil {
|
||||
t.Fatalf("Saw DM avatar change: %s", err)
|
||||
}
|
||||
matchSentinel := m.MatchRoomSubscription(dmBob, MatchRoomTimelineMostRecent(1, []Event{{ID: sentinel}}))
|
||||
return matchSentinel(response)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DM declined", func(t *testing.T) {
|
||||
t.Log("Chris leaves his DM with Alice.")
|
||||
chris.LeaveRoom(t, dmChris)
|
||||
|
||||
t.Log("Alice syncs until she sees Chris's leave.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "leave")
|
||||
|
||||
t.Log("Alice sees Chris's avatar vanish.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomUnsetAvatar()))
|
||||
})
|
||||
|
||||
t.Run("Group DM declined", func(t *testing.T) {
|
||||
t.Log("Chris leaves his group DM with Alice and Bob.")
|
||||
chris.LeaveRoom(t, dmBobChris)
|
||||
|
||||
t.Log("Alice syncs until she sees Chris's leave.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmBobChris, chris, "leave")
|
||||
|
||||
t.Log("Alice sees the room's avatar change to Bob's avatar.")
|
||||
// Because this is now a DM room with exactly one other (joined|invited) member.
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBobChris, m.MatchRoomAvatar(bob.AvatarURL)))
|
||||
})
|
||||
|
||||
t.Run("Bob's avatar change propagates", func(t *testing.T) {
|
||||
t.Log("Bob changes his avatar.")
|
||||
bobAvatar2 := uploadAvatar(bob, "bob2.png")
|
||||
bob.SetAvatar(t, bobAvatar2)
|
||||
|
||||
avatarChangeInDM := false
|
||||
avatarChangeInGroupDM := false
|
||||
t.Log("Alice syncs until she sees Bob's new avatar.")
|
||||
res = alice.SlidingSyncUntil(
|
||||
t,
|
||||
res.Pos,
|
||||
sync3.Request{},
|
||||
func(response *sync3.Response) error {
|
||||
if !avatarChangeInDM {
|
||||
err := m.MatchRoomSubscription(dmBob, m.MatchRoomAvatar(bob.AvatarURL))(response)
|
||||
if err == nil {
|
||||
avatarChangeInDM = true
|
||||
}
|
||||
}
|
||||
|
||||
if !avatarChangeInGroupDM {
|
||||
err := m.MatchRoomSubscription(dmBobChris, m.MatchRoomAvatar(bob.AvatarURL))(response)
|
||||
if err == nil {
|
||||
avatarChangeInGroupDM = true
|
||||
}
|
||||
}
|
||||
|
||||
if avatarChangeInDM && avatarChangeInGroupDM {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("still waiting: avatarChangeInDM=%t avatarChangeInGroupDM=%t", avatarChangeInDM, avatarChangeInGroupDM)
|
||||
},
|
||||
)
|
||||
|
||||
t.Log("Bob removes his avatar.")
|
||||
bob.SetAvatar(t, "")
|
||||
|
||||
avatarChangeInDM = false
|
||||
avatarChangeInGroupDM = false
|
||||
t.Log("Alice syncs until she sees Bob's avatars vanish.")
|
||||
res = alice.SlidingSyncUntil(
|
||||
t,
|
||||
res.Pos,
|
||||
sync3.Request{},
|
||||
func(response *sync3.Response) error {
|
||||
if !avatarChangeInDM {
|
||||
err := m.MatchRoomSubscription(dmBob, m.MatchRoomUnsetAvatar())(response)
|
||||
if err == nil {
|
||||
avatarChangeInDM = true
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
if !avatarChangeInGroupDM {
|
||||
err := m.MatchRoomSubscription(dmBobChris, m.MatchRoomUnsetAvatar())(response)
|
||||
if err == nil {
|
||||
avatarChangeInGroupDM = true
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
if avatarChangeInDM && avatarChangeInGroupDM {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("still waiting: avatarChangeInDM=%t avatarChangeInGroupDM=%t", avatarChangeInDM, avatarChangeInGroupDM)
|
||||
},
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Explicit avatar propagates in non-DM room", func(t *testing.T) {
|
||||
t.Log("Alice sets an avatar for the public room.")
|
||||
publicAvatar := uploadAvatar(alice, "public.png")
|
||||
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{
|
||||
"url": publicAvatar,
|
||||
})
|
||||
t.Log("Alice syncs until she sees that avatar.")
|
||||
res = alice.SlidingSyncUntil(
|
||||
t,
|
||||
res.Pos,
|
||||
sync3.Request{},
|
||||
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
|
||||
public: {m.MatchRoomAvatar(publicAvatar)},
|
||||
}),
|
||||
)
|
||||
|
||||
t.Log("Alice changes the avatar for the public room.")
|
||||
publicAvatar2 := uploadAvatar(alice, "public2.png")
|
||||
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{
|
||||
"url": publicAvatar2,
|
||||
})
|
||||
t.Log("Alice syncs until she sees that avatar.")
|
||||
res = alice.SlidingSyncUntil(
|
||||
t,
|
||||
res.Pos,
|
||||
sync3.Request{},
|
||||
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
|
||||
public: {m.MatchRoomAvatar(publicAvatar2)},
|
||||
}),
|
||||
)
|
||||
|
||||
t.Log("Alice removes the avatar for the public room.")
|
||||
alice.SetState(t, public, "m.room.avatar", "", map[string]interface{}{})
|
||||
t.Log("Alice syncs until she sees that avatar vanish.")
|
||||
res = alice.SlidingSyncUntil(
|
||||
t,
|
||||
res.Pos,
|
||||
sync3.Request{},
|
||||
m.MatchRoomSubscriptions(map[string][]m.RoomMatcher{
|
||||
public: {m.MatchRoomUnsetAvatar()},
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Explicit avatar propagates in DM room", func(t *testing.T) {
|
||||
t.Log("Alice re-invites Chris to their DM.")
|
||||
alice.InviteRoom(t, dmChris, chris.UserID)
|
||||
|
||||
t.Log("Alice syncs until she sees her invitation to Chris.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "invite")
|
||||
|
||||
t.Log("Alice should see the DM with Chris's avatar.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)))
|
||||
|
||||
t.Log("Chris joins the room.")
|
||||
chris.JoinRoom(t, dmChris, nil)
|
||||
|
||||
t.Log("Alice syncs until she sees Chris's join.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "join")
|
||||
|
||||
t.Log("Alice shouldn't see the DM's avatar change..")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, m.MatchRoomUnchangedAvatar()))
|
||||
|
||||
t.Log("Chris gives their DM a bespoke avatar.")
|
||||
dmAvatar := uploadAvatar(chris, "dm.png")
|
||||
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{
|
||||
"url": dmAvatar,
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees that avatar.")
|
||||
alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(dmAvatar)))
|
||||
|
||||
t.Log("Chris changes his global avatar, which adds a join event to the room.")
|
||||
chrisAvatar2 := uploadAvatar(chris, "chris2.png")
|
||||
chris.SetAvatar(t, chrisAvatar2)
|
||||
|
||||
t.Log("Alice syncs until she sees that join event.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmChris, chris, "join")
|
||||
|
||||
t.Log("Her response should have either no avatar change, or the same bespoke avatar.")
|
||||
// No change, ideally, but repeating the same avatar isn't _wrong_
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmChris, func(r sync3.Room) error {
|
||||
noChangeErr := m.MatchRoomUnchangedAvatar()(r)
|
||||
sameBespokeAvatarErr := m.MatchRoomAvatar(dmAvatar)(r)
|
||||
if noChangeErr == nil || sameBespokeAvatarErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected no change or the same bespoke avatar (%s), got '%s'", dmAvatar, r.AvatarChange)
|
||||
}))
|
||||
|
||||
t.Log("Chris updates the DM's avatar.")
|
||||
dmAvatar2 := uploadAvatar(chris, "dm2.png")
|
||||
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{
|
||||
"url": dmAvatar2,
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees that avatar.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(dmAvatar2)))
|
||||
|
||||
t.Log("Chris removes the DM's avatar.")
|
||||
chris.SetState(t, dmChris, "m.room.avatar", "", map[string]interface{}{})
|
||||
|
||||
t.Log("Alice syncs until the DM avatar returns to Chris's most recent avatar.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{}, m.MatchRoomSubscription(dmChris, m.MatchRoomAvatar(chris.AvatarURL)))
|
||||
})
|
||||
|
||||
t.Run("Changing DM flag", func(t *testing.T) {
|
||||
t.Skip("TODO: unimplemented")
|
||||
t.Log("Alice clears the DM flag on Bob's room.")
|
||||
alice.SetGlobalAccountData(t, "m.direct", map[string]interface{}{
|
||||
"content": map[string][]string{
|
||||
bob.UserID: {}, // no dmBob here
|
||||
chris.UserID: {dmChris, dmBobChris},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees a new set of account data.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
extensions.Core{Enabled: &boolTrue},
|
||||
},
|
||||
},
|
||||
}, func(response *sync3.Response) error {
|
||||
if response.Extensions.AccountData == nil {
|
||||
return fmt.Errorf("no account data yet")
|
||||
}
|
||||
if len(response.Extensions.AccountData.Global) == 0 {
|
||||
return fmt.Errorf("no global account data yet")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("The DM with Bob should no longer be a DM and should no longer have an avatar.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBob, func(r sync3.Room) error {
|
||||
if r.IsDM {
|
||||
return fmt.Errorf("dmBob is still a DM")
|
||||
}
|
||||
return m.MatchRoomUnsetAvatar()(r)
|
||||
}))
|
||||
|
||||
t.Log("Alice sets the DM flag on Bob's room.")
|
||||
alice.SetGlobalAccountData(t, "m.direct", map[string]interface{}{
|
||||
"content": map[string][]string{
|
||||
bob.UserID: {dmBob}, // dmBob reinstated
|
||||
chris.UserID: {dmChris, dmBobChris},
|
||||
},
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees a new set of account data.")
|
||||
res = alice.SlidingSyncUntil(t, res.Pos, sync3.Request{
|
||||
Extensions: extensions.Request{
|
||||
AccountData: &extensions.AccountDataRequest{
|
||||
extensions.Core{Enabled: &boolTrue},
|
||||
},
|
||||
},
|
||||
}, func(response *sync3.Response) error {
|
||||
if response.Extensions.AccountData == nil {
|
||||
return fmt.Errorf("no account data yet")
|
||||
}
|
||||
if len(response.Extensions.AccountData.Global) == 0 {
|
||||
return fmt.Errorf("no global account data yet")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Log("The room should have Bob's avatar again.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmBob, func(r sync3.Room) error {
|
||||
if !r.IsDM {
|
||||
return fmt.Errorf("dmBob is still not a DM")
|
||||
}
|
||||
return m.MatchRoomAvatar(bob.AvatarURL)(r)
|
||||
}))
|
||||
|
||||
})
|
||||
|
||||
t.Run("See avatar when invited", func(t *testing.T) {
|
||||
t.Log("Chris invites Alice to a DM.")
|
||||
dmInvited := chris.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "trusted_private_chat",
|
||||
"is_direct": true,
|
||||
"invite": []string{alice.UserID},
|
||||
})
|
||||
|
||||
t.Log("Alice syncs until she sees the invite.")
|
||||
res = alice.SlidingSyncUntilMembership(t, res.Pos, dmInvited, alice, "invite")
|
||||
|
||||
t.Log("The new room should use Chris's avatar.")
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(dmInvited, m.MatchRoomAvatar(chris.AvatarURL)))
|
||||
})
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@ -13,6 +14,7 @@ import (
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -159,13 +161,41 @@ func MatchRoomInviteState(events []Event, partial bool) m.RoomMatcher {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist", want)
|
||||
return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist or failed to pass equality checks", want)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MatchGlobalAccountData builds a matcher which asserts that the account data in a sync
|
||||
// response matches the given `globals`, with any ordering.
|
||||
// If there is no account data extension in the response, the match fails.
|
||||
func MatchGlobalAccountData(globals []Event) m.RespMatcher {
|
||||
// sort want list by type
|
||||
sort.Slice(globals, func(i, j int) bool {
|
||||
return globals[i].Type < globals[j].Type
|
||||
})
|
||||
|
||||
return func(res *sync3.Response) error {
|
||||
if res.Extensions.AccountData == nil {
|
||||
return fmt.Errorf("MatchGlobalAccountData: no account_data extension")
|
||||
}
|
||||
if len(globals) != len(res.Extensions.AccountData.Global) {
|
||||
return fmt.Errorf("MatchGlobalAccountData: got %v global account data, want %v", len(res.Extensions.AccountData.Global), len(globals))
|
||||
}
|
||||
// sort the got list by type
|
||||
got := res.Extensions.AccountData.Global
|
||||
sort.Slice(got, func(i, j int) bool {
|
||||
return gjson.GetBytes(got[i], "type").Str < gjson.GetBytes(got[j], "type").Str
|
||||
})
|
||||
if err := eventsEqual(globals, got); err != nil {
|
||||
return fmt.Errorf("MatchGlobalAccountData: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func registerNewUser(t *testing.T) *CSAPI {
|
||||
return registerNamedUser(t, "user")
|
||||
}
|
||||
|
@ -64,7 +64,22 @@ func TestRoomStateTransitions(t *testing.T) {
|
||||
m.MatchRoomHighlightCount(1),
|
||||
m.MatchRoomInitial(true),
|
||||
m.MatchRoomRequiredState(nil),
|
||||
// TODO m.MatchRoomInviteState(inviteStrippedState.InviteState.Events),
|
||||
m.MatchInviteCount(1),
|
||||
m.MatchJoinCount(1),
|
||||
MatchRoomInviteState([]Event{
|
||||
{
|
||||
Type: "m.room.create",
|
||||
StateKey: ptr(""),
|
||||
// no content as it includes the room version which we don't want to guess/hardcode
|
||||
},
|
||||
{
|
||||
Type: "m.room.join_rules",
|
||||
StateKey: ptr(""),
|
||||
Content: map[string]interface{}{
|
||||
"join_rule": "public",
|
||||
},
|
||||
},
|
||||
}, true),
|
||||
},
|
||||
joinRoomID: {},
|
||||
}),
|
||||
@ -105,6 +120,8 @@ func TestRoomStateTransitions(t *testing.T) {
|
||||
},
|
||||
}),
|
||||
m.MatchRoomInitial(true),
|
||||
m.MatchJoinCount(2),
|
||||
m.MatchInviteCount(0),
|
||||
m.MatchRoomHighlightCount(0),
|
||||
))
|
||||
}
|
||||
@ -216,17 +233,18 @@ func TestInviteRejection(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInviteAcceptance(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
bob := registerNewUser(t)
|
||||
alice := registerNamedUser(t, "alice")
|
||||
bob := registerNamedUser(t, "bob")
|
||||
|
||||
// ensure that invite state correctly propagates. One room will already be in 'invite' state
|
||||
// prior to the first proxy sync, whereas the 2nd will transition.
|
||||
t.Logf("Alice creates two rooms and invites Bob to the first.")
|
||||
firstInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "First"})
|
||||
alice.InviteRoom(t, firstInviteRoomID, bob.UserID)
|
||||
secondInviteRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "private_chat", "name": "Second"})
|
||||
t.Logf("first %s second %s", firstInviteRoomID, secondInviteRoomID)
|
||||
|
||||
// sync as bob, we should see 1 invite
|
||||
t.Log("Sync as Bob, requesting invites only. He should see 1 invite")
|
||||
res := bob.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
@ -256,10 +274,12 @@ func TestInviteAcceptance(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
|
||||
// now invite bob
|
||||
t.Log("Alice invites bob to room 2.")
|
||||
alice.InviteRoom(t, secondInviteRoomID, bob.UserID)
|
||||
t.Log("Alice syncs until she sees Bob's invite.")
|
||||
alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "invite")
|
||||
|
||||
t.Log("Bob syncs. He should see the invite to room 2 as well.")
|
||||
res = bob.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
@ -287,13 +307,16 @@ func TestInviteAcceptance(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
|
||||
// now accept the invites
|
||||
t.Log("Bob accept the invites.")
|
||||
bob.JoinRoom(t, firstInviteRoomID, nil)
|
||||
bob.JoinRoom(t, secondInviteRoomID, nil)
|
||||
|
||||
t.Log("Alice syncs until she sees Bob join room 1.")
|
||||
alice.SlidingSyncUntilMembership(t, "", firstInviteRoomID, bob, "join")
|
||||
t.Log("Alice syncs until she sees Bob join room 2.")
|
||||
alice.SlidingSyncUntilMembership(t, "", secondInviteRoomID, bob, "join")
|
||||
|
||||
// the list should be purged
|
||||
t.Log("Bob does an incremental sync")
|
||||
res = bob.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
@ -301,12 +324,13 @@ func TestInviteAcceptance(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}, WithPos(res.Pos))
|
||||
t.Log("Both of his invites should be purged.")
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(
|
||||
m.MatchV3DeleteOp(1),
|
||||
m.MatchV3DeleteOp(0),
|
||||
)))
|
||||
|
||||
// fresh sync -> no invites
|
||||
t.Log("Bob makes a fresh sliding sync request.")
|
||||
res = bob.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
@ -317,6 +341,7 @@ func TestInviteAcceptance(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
t.Log("He should see no invites.")
|
||||
m.MatchResponse(t, res, m.MatchNoV3Ops(), m.MatchRoomSubscriptionsStrict(nil), m.MatchList("a", m.MatchV3Count(0)))
|
||||
}
|
||||
|
||||
@ -467,7 +492,7 @@ func TestMemberCounts(t *testing.T) {
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
secondRoomID: {
|
||||
m.MatchRoomInitial(false),
|
||||
m.MatchInviteCount(0),
|
||||
m.MatchNoInviteCount(),
|
||||
m.MatchJoinCount(0), // omitempty
|
||||
},
|
||||
}))
|
||||
@ -486,7 +511,7 @@ func TestMemberCounts(t *testing.T) {
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
secondRoomID: {
|
||||
m.MatchRoomInitial(false),
|
||||
m.MatchInviteCount(0),
|
||||
m.MatchNoInviteCount(),
|
||||
m.MatchJoinCount(2),
|
||||
},
|
||||
}))
|
||||
|
@ -1,10 +1,13 @@
|
||||
package syncv3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNumLive(t *testing.T) {
|
||||
@ -126,3 +129,70 @@ func TestNumLive(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Test that if you constantly change req params, we still see live traffic. It does this by:
|
||||
// - Creating 11 rooms.
|
||||
// - Hitting /sync with a range [0,1] then [0,2] then [0,3]. Each time this causes a new room to be returned.
|
||||
// - Interleaving each /sync request with genuine events sent into a room.
|
||||
// - ensuring we see the genuine events by the time we finish.
|
||||
func TestReqParamStarvation(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
bob := registerNewUser(t)
|
||||
roomID := alice.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "public_chat",
|
||||
})
|
||||
numOtherRooms := 10
|
||||
for i := 0; i < numOtherRooms; i++ {
|
||||
bob.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "public_chat",
|
||||
})
|
||||
}
|
||||
bob.JoinRoom(t, roomID, nil)
|
||||
res := bob.SlidingSyncUntilMembership(t, "", roomID, bob, "join")
|
||||
|
||||
wantEventIDs := make(map[string]bool)
|
||||
for i := 0; i < numOtherRooms; i++ {
|
||||
res = bob.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
Ranges: sync3.SliceRanges{{0, int64(i)}}, // [0,0], [0,1], ... [0,9]
|
||||
},
|
||||
},
|
||||
}, WithPos(res.Pos))
|
||||
|
||||
// mark off any event we see in wantEventIDs
|
||||
for _, r := range res.Rooms {
|
||||
for _, ev := range r.Timeline {
|
||||
gotEventID := gjson.GetBytes(ev, "event_id").Str
|
||||
wantEventIDs[gotEventID] = false
|
||||
}
|
||||
}
|
||||
|
||||
// send an event in the first few syncs to add to wantEventIDs
|
||||
// We do this for the first few /syncs and don't dictate which response they should arrive
|
||||
// in, as we do not know and cannot force the proxy to deliver the event in a particular response.
|
||||
if i < 3 {
|
||||
eventID := alice.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.room.message",
|
||||
Content: map[string]interface{}{
|
||||
"msgtype": "m.text",
|
||||
"body": fmt.Sprintf("msg %d", i),
|
||||
},
|
||||
})
|
||||
wantEventIDs[eventID] = true
|
||||
}
|
||||
|
||||
// it's possible the proxy won't see this event before the next /sync
|
||||
// and that is the reason why we don't see it, as opposed to starvation.
|
||||
// To try to counter this, sleep a bit. This is why we sleep on every cycle and
|
||||
// why we send the events early on.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// at this point wantEventIDs should all have false values if we got the events
|
||||
for evID, unseen := range wantEventIDs {
|
||||
if unseen {
|
||||
t.Errorf("failed to see event %v", evID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -100,15 +100,16 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
|
||||
})
|
||||
|
||||
// check Alice sees both events
|
||||
assertEventsEqual(t, []Event{
|
||||
{
|
||||
Type: "m.room.member",
|
||||
StateKey: ptr(eve.UserID),
|
||||
Content: map[string]interface{}{
|
||||
"membership": "leave",
|
||||
},
|
||||
Sender: alice.UserID,
|
||||
kickEvent := Event{
|
||||
Type: "m.room.member",
|
||||
StateKey: ptr(eve.UserID),
|
||||
Content: map[string]interface{}{
|
||||
"membership": "leave",
|
||||
},
|
||||
Sender: alice.UserID,
|
||||
}
|
||||
assertEventsEqual(t, []Event{
|
||||
kickEvent,
|
||||
{
|
||||
Type: "m.room.name",
|
||||
StateKey: ptr(""),
|
||||
@ -120,7 +121,6 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
|
||||
},
|
||||
}, timeline)
|
||||
|
||||
kickEvent := timeline[0]
|
||||
// Ensure Eve doesn't see this message in the timeline, name calc or required_state
|
||||
eveRes = eve.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
@ -140,7 +140,7 @@ func TestSecurityLiveStreamEventLeftLeak(t *testing.T) {
|
||||
}, WithPos(eveRes.Pos))
|
||||
// the room is deleted from eve's point of view and she sees up to and including her kick event
|
||||
m.MatchResponse(t, eveRes, m.MatchList("a", m.MatchV3Count(0), m.MatchV3Ops(m.MatchV3DeleteOp(0))), m.MatchRoomSubscription(
|
||||
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{kickEvent}),
|
||||
roomID, m.MatchRoomName(""), m.MatchRoomRequiredState(nil), MatchRoomTimelineMostRecent(1, []Event{kickEvent}),
|
||||
))
|
||||
}
|
||||
|
||||
|
180
tests-e2e/timestamp_test.go
Normal file
180
tests-e2e/timestamp_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package syncv3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
)
|
||||
|
||||
func TestTimestamp(t *testing.T) {
|
||||
alice := registerNewUser(t)
|
||||
bob := registerNewUser(t)
|
||||
charlie := registerNewUser(t)
|
||||
|
||||
roomID := alice.CreateRoom(t, map[string]interface{}{
|
||||
"preset": "public_chat",
|
||||
})
|
||||
|
||||
var gotTs, expectedTs uint64
|
||||
|
||||
lists := map[string]sync3.RequestList{
|
||||
"myFirstList": {
|
||||
Ranges: [][2]int64{{0, 1}},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
BumpEventTypes: []string{"m.room.message"}, // only messages bump the timestamp
|
||||
},
|
||||
"mySecondList": {
|
||||
Ranges: [][2]int64{{0, 1}},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
BumpEventTypes: []string{"m.reaction"}, // only reactions bump the timestamp
|
||||
},
|
||||
}
|
||||
|
||||
// Init sync to get the latest timestamp
|
||||
resAlice := alice.SlidingSync(t, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, resAlice, m.MatchRoomSubscription(roomID, m.MatchRoomInitial(true)))
|
||||
timestampBeforeBobJoined := resAlice.Rooms[roomID].Timestamp
|
||||
|
||||
bob.JoinRoom(t, roomID, nil)
|
||||
resAlice = alice.SlidingSyncUntilMembership(t, resAlice.Pos, roomID, bob, "join")
|
||||
resBob := bob.SlidingSync(t, sync3.Request{
|
||||
Lists: lists,
|
||||
})
|
||||
|
||||
// Bob should see the same timestamp as Alice, since her latest response reflects his join
|
||||
gotTs = resBob.Rooms[roomID].Timestamp
|
||||
expectedTs = resAlice.Rooms[roomID].Timestamp
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("expected timestamp to be equal, but got: %v vs %v", gotTs, expectedTs)
|
||||
}
|
||||
// ... the timestamp should still differ from what Alice received before the join
|
||||
if gotTs == timestampBeforeBobJoined {
|
||||
t.Fatalf("expected timestamp to differ, but got: %v vs %v", gotTs, timestampBeforeBobJoined)
|
||||
}
|
||||
|
||||
// Send an event which should NOT bump Bob's timestamp, because its type is not listed in
// any of the BumpEventTypes
|
||||
emptyStateKey := ""
|
||||
eventID := alice.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.room.topic",
|
||||
StateKey: &emptyStateKey,
|
||||
Content: map[string]interface{}{
|
||||
"topic": "random topic",
|
||||
},
|
||||
})
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
|
||||
gotTs = resBob.Rooms[roomID].Timestamp
|
||||
expectedTs = resAlice.Rooms[roomID].Timestamp
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("expected timestamps to be the same, but they aren't: %v vs %v", gotTs, expectedTs)
|
||||
}
|
||||
expectedTs = gotTs
|
||||
|
||||
// Now send a message which bumps the timestamp in myFirstList
|
||||
eventID = alice.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.room.message",
|
||||
Content: map[string]interface{}{
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello, world!",
|
||||
},
|
||||
})
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
|
||||
gotTs = resBob.Rooms[roomID].Timestamp
|
||||
if expectedTs == gotTs {
|
||||
t.Fatalf("expected timestamps to be different, but they aren't: %v vs %v", gotTs, expectedTs)
|
||||
}
|
||||
expectedTs = gotTs
|
||||
|
||||
// Now send a message which bumps the timestamp in mySecondList
|
||||
eventID = alice.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.reaction",
|
||||
Content: map[string]interface{}{
|
||||
"m.relates.to": map[string]interface{}{
|
||||
"event_id": eventID,
|
||||
"key": "✅",
|
||||
"rel_type": "m.annotation",
|
||||
},
|
||||
},
|
||||
})
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
|
||||
bobTimestampReaction := resBob.Rooms[roomID].Timestamp
|
||||
if bobTimestampReaction == expectedTs {
|
||||
t.Fatalf("expected timestamps to be different, but they aren't: %v vs %v", expectedTs, bobTimestampReaction)
|
||||
}
|
||||
expectedTs = bobTimestampReaction
|
||||
|
||||
// Send another event which should NOT bump Bob's timestamp
|
||||
eventID = alice.SendEventSynced(t, roomID, Event{
|
||||
Type: "m.room.name",
|
||||
StateKey: &emptyStateKey,
|
||||
Content: map[string]interface{}{
|
||||
"name": "random name",
|
||||
},
|
||||
})
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
resBob = bob.SlidingSyncUntilEventID(t, resBob.Pos, roomID, eventID)
|
||||
gotTs = resBob.Rooms[roomID].Timestamp
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("expected timestamps to be the same, but they aren't: %v, expected %v", gotTs, expectedTs)
|
||||
}
|
||||
|
||||
// Bob makes an initial sync again, he should still see the m.reaction timestamp
|
||||
resBob = bob.SlidingSync(t, sync3.Request{
|
||||
Lists: lists,
|
||||
})
|
||||
|
||||
gotTs = resBob.Rooms[roomID].Timestamp
|
||||
expectedTs = bobTimestampReaction
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("initial sync contains wrong timestamp: %d, expected %d", gotTs, expectedTs)
|
||||
}
|
||||
|
||||
// Charlie joins the room
|
||||
charlie.JoinRoom(t, roomID, nil)
|
||||
resAlice = alice.SlidingSyncUntilMembership(t, resAlice.Pos, roomID, charlie, "join")
|
||||
|
||||
resCharlie := charlie.SlidingSync(t, sync3.Request{
|
||||
Lists: lists,
|
||||
})
|
||||
|
||||
// Charlie just joined, so he should see the same timestamp as Alice. Even though
// Charlie has the same BumpEventTypes as Bob, we don't leak Bob's timestamps to him.
|
||||
gotTs = resCharlie.Rooms[roomID].Timestamp
|
||||
expectedTs = resAlice.Rooms[roomID].Timestamp
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("Charlie should see the timestamp they joined, but didn't: %d, expected %d", gotTs, expectedTs)
|
||||
}
|
||||
|
||||
// Initial sync without bump types should see the most recent timestamp
|
||||
resAlice = alice.SlidingSync(t, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
// expected TS stays the same, i.e. the timestamp of Charlie's join
|
||||
gotTs = resAlice.Rooms[roomID].Timestamp
|
||||
if gotTs != expectedTs {
|
||||
t.Fatalf("Alice should see the timestamp of Charlie joining, but didn't: %d, expected %d", gotTs, expectedTs)
|
||||
}
|
||||
}
|
@ -32,33 +32,10 @@ func TestTransactionIDsAppear(t *testing.T) {
|
||||
|
||||
// we cannot use MatchTimeline here because the Unsigned section contains 'age' which is not
|
||||
// deterministic and MatchTimeline does not do partial matches.
|
||||
matchTransactionID := func(eventID, txnID string) m.RoomMatcher {
|
||||
return func(r sync3.Room) error {
|
||||
for _, ev := range r.Timeline {
|
||||
var got Event
|
||||
if err := json.Unmarshal(ev, &got); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal event: %s", err)
|
||||
}
|
||||
if got.ID != eventID {
|
||||
continue
|
||||
}
|
||||
tx, ok := got.Unsigned["transaction_id"]
|
||||
if !ok {
|
||||
return fmt.Errorf("unsigned block for %s has no transaction_id", eventID)
|
||||
}
|
||||
gotTxnID := tx.(string)
|
||||
if gotTxnID != txnID {
|
||||
return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID)
|
||||
}
|
||||
t.Logf("%s has txn ID %s", eventID, gotTxnID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("not found event %s", eventID)
|
||||
}
|
||||
}
|
||||
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
roomID: {
|
||||
matchTransactionID(eventID, "foobar"),
|
||||
matchTransactionID(t, eventID, "foobar"),
|
||||
},
|
||||
}))
|
||||
|
||||
@ -74,8 +51,85 @@ func TestTransactionIDsAppear(t *testing.T) {
|
||||
res = client.SlidingSyncUntilEvent(t, res.Pos, sync3.Request{}, roomID, Event{ID: eventID})
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
roomID: {
|
||||
matchTransactionID(eventID, "foobar2"),
|
||||
matchTransactionID(t, eventID, "foobar2"),
|
||||
},
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
// This test has 1 poller expecting a txn ID and 10 others that won't see one.
|
||||
// We test that the sending device sees a txnID. Without the TxnIDWaiter logic in place,
|
||||
// this test is likely (but not guaranteed) to fail.
|
||||
func TestTransactionIDsAppearWithMultiplePollers(t *testing.T) {
|
||||
alice := registerNamedUser(t, "alice")
|
||||
|
||||
t.Log("Alice creates a room and syncs until she sees it.")
|
||||
roomID := alice.CreateRoom(t, map[string]interface{}{})
|
||||
res := alice.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
Ranges: sync3.SliceRanges{{0, 20}},
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID))
|
||||
|
||||
t.Log("Alice makes other devices and starts them syncing.")
|
||||
for i := 0; i < 10; i++ {
|
||||
device := *alice
|
||||
device.Login(t, "password", fmt.Sprintf("device_%d", i))
|
||||
device.SlidingSync(t, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{
|
||||
"a": {
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 10,
|
||||
},
|
||||
Ranges: sync3.SliceRanges{{0, 20}},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("Alice sends a message with a transaction ID.")
|
||||
const txnID = "foobar"
|
||||
sendRes := alice.MustDoFunc(t, "PUT", []string{"_matrix", "client", "v3", "rooms", roomID, "send", "m.room.message", txnID},
|
||||
WithJSONBody(t, map[string]interface{}{
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello, world!",
|
||||
}))
|
||||
body := ParseJSON(t, sendRes)
|
||||
eventID := GetJSONFieldStr(t, body, "event_id")
|
||||
|
||||
t.Log("Alice syncs on her main devices until she sees her message.")
|
||||
res = alice.SlidingSyncUntilEventID(t, res.Pos, roomID, eventID)
|
||||
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, matchTransactionID(t, eventID, txnID)))
|
||||
}
|
||||
|
||||
func matchTransactionID(t *testing.T, eventID, txnID string) m.RoomMatcher {
|
||||
return func(r sync3.Room) error {
|
||||
for _, ev := range r.Timeline {
|
||||
var got Event
|
||||
if err := json.Unmarshal(ev, &got); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal event: %s", err)
|
||||
}
|
||||
if got.ID != eventID {
|
||||
continue
|
||||
}
|
||||
tx, ok := got.Unsigned["transaction_id"]
|
||||
if !ok {
|
||||
return fmt.Errorf("unsigned block for %s has no transaction_id", eventID)
|
||||
}
|
||||
gotTxnID := tx.(string)
|
||||
if gotTxnID != txnID {
|
||||
return fmt.Errorf("wrong transaction_id, got %s want %s", gotTxnID, txnID)
|
||||
}
|
||||
t.Logf("%s has txn ID %s", eventID, gotTxnID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("not found event %s", eventID)
|
||||
}
|
||||
}
|
||||
|
@ -622,6 +622,20 @@ func TestSessionExpiryOnBufferFill(t *testing.T) {
|
||||
if gjson.ParseBytes(body).Get("errcode").Str != "M_UNKNOWN_POS" {
|
||||
t.Errorf("got %v want errcode=M_UNKNOWN_POS", string(body))
|
||||
}
|
||||
|
||||
// make sure we can sync from fresh (regression for when we deadlocked after this point)
|
||||
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
|
||||
RoomSubscriptions: map[string]sync3.RoomSubscription{
|
||||
roomID: {
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
roomID: {
|
||||
m.MatchJoinCount(1),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestExpiredAccessToken(t *testing.T) {
|
||||
|
103
tests-integration/db_test.go
Normal file
103
tests-integration/db_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package syncv3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
syncv3 "github.com/matrix-org/sliding-sync"
|
||||
"github.com/matrix-org/sliding-sync/sync2"
|
||||
"github.com/matrix-org/sliding-sync/sync3"
|
||||
"github.com/matrix-org/sliding-sync/testutils"
|
||||
"github.com/matrix-org/sliding-sync/testutils/m"
|
||||
)
|
||||
|
||||
// Test that the proxy works fine with low max conns. Low max conns can be a problem
|
||||
// if a request A needs 2 conns to respond and that blocks forward progress on the server,
|
||||
// and the request can only obtain 1 conn.
|
||||
func TestMaxDBConns(t *testing.T) {
|
||||
pqString := testutils.PrepareDBConnectionString()
|
||||
// setup code
|
||||
v2 := runTestV2Server(t)
|
||||
opts := syncv3.Opts{
|
||||
DBMaxConns: 1,
|
||||
}
|
||||
v3 := runTestServer(t, v2, pqString, opts)
|
||||
defer v2.close()
|
||||
defer v3.close()
|
||||
|
||||
testMaxDBConns := func() {
|
||||
// make N users and drip feed some events, make sure they are all seen
|
||||
numUsers := 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
userID := fmt.Sprintf("@maxconns_%d:localhost", n)
|
||||
token := fmt.Sprintf("maxconns_%d", n)
|
||||
roomID := fmt.Sprintf("!maxconns_%d", n)
|
||||
v2.addAccount(t, userID, token)
|
||||
v2.queueResponse(userID, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
Join: v2JoinTimeline(roomEvents{
|
||||
roomID: roomID,
|
||||
state: createRoomState(t, userID, time.Now()),
|
||||
}),
|
||||
},
|
||||
})
|
||||
// initial sync
|
||||
res := v3.mustDoV3Request(t, token, sync3.Request{
|
||||
Lists: map[string]sync3.RequestList{"a": {
|
||||
Ranges: sync3.SliceRanges{
|
||||
[2]int64{0, 1},
|
||||
},
|
||||
RoomSubscription: sync3.RoomSubscription{
|
||||
TimelineLimit: 1,
|
||||
},
|
||||
}},
|
||||
})
|
||||
t.Logf("user %s has done an initial /sync OK", userID)
|
||||
m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(
|
||||
m.MatchV3SyncOp(0, 0, []string{roomID}),
|
||||
)), m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
roomID: {
|
||||
m.MatchJoinCount(1),
|
||||
},
|
||||
}))
|
||||
// drip feed and get update
|
||||
dripMsg := testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{
|
||||
"msgtype": "m.text",
|
||||
"body": "drip drip",
|
||||
})
|
||||
v2.queueResponse(userID, sync2.SyncResponse{
|
||||
Rooms: sync2.SyncRoomsResponse{
|
||||
Join: v2JoinTimeline(roomEvents{
|
||||
roomID: roomID,
|
||||
events: []json.RawMessage{
|
||||
dripMsg,
|
||||
},
|
||||
}),
|
||||
},
|
||||
})
|
||||
t.Logf("user %s has queued the drip", userID)
|
||||
v2.waitUntilEmpty(t, userID)
|
||||
t.Logf("user %s poller has received the drip", userID)
|
||||
res = v3.mustDoV3RequestWithPos(t, token, res.Pos, sync3.Request{})
|
||||
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
|
||||
roomID: {
|
||||
m.MatchRoomTimelineMostRecent(1, []json.RawMessage{dripMsg}),
|
||||
},
|
||||
}))
|
||||
t.Logf("user %s has received the drip", userID)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
testMaxDBConns()
|
||||
v3.restart(t, v2, pqString, opts)
|
||||
testMaxDBConns()
|
||||
}
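For context, a minimal sketch of what a DBMaxConns-style cap boils down to at the database/sql layer. The DSN and the lib/pq driver import are assumptions for illustration only; the proxy's real wiring is shown in the v3.go hunk further down, and the tests derive their connection string from SYNCV3_DB / PrepareDBConnectionString.

```go
package main

import (
	"time"

	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq" // assumed Postgres driver; any database/sql driver is configured the same way
)

func main() {
	// Placeholder DSN for illustration.
	db, err := sqlx.Open("postgres", "user=postgres dbname=syncv3 sslmode=disable")
	if err != nil {
		panic(err)
	}
	maxConns := 1 // the value TestMaxDBConns exercises via syncv3.Opts{DBMaxConns: 1}
	// Keeping idle == open avoids connections being opened and closed repeatedly
	// when the cap is this small.
	db.SetMaxOpenConns(maxConns)
	db.SetMaxIdleConns(maxConns)
	db.SetConnMaxIdleTime(time.Hour)
}
```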
@ -604,3 +604,171 @@ func TestExtensionLateEnable(t *testing.T) {
},
})
}

func TestTypingMultiplePoller(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

roomA := "!a:localhost"

v2.addAccountWithDeviceID(alice, "first", aliceToken)
v2.addAccountWithDeviceID(bob, "second", bobToken)

// Create the room state and join with Bob
roomState := createRoomState(t, alice, time.Now())
joinEv := testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]interface{}{
"membership": "join",
})

// Queue the response with Alice typing
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@alice:localhost"]}}`)},
},
},
},
},
})

// Queue another response for Bob with Bob typing.
v2.queueResponse(bobToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)}},
},
},
},
})

// Start the pollers. Since Alice's poller is started first, the poller is in
// charge of handling typing notifications for roomA.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{})

// Get the response from v3
for _, token := range []string{aliceToken, bobToken} {
pos := aliceRes.Pos
if token == bobToken {
pos = bobRes.Pos
}

res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{
Extensions: extensions.Request{
Typing: &extensions.TypingRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// We expect only Alice typing, as only Alice's poller is "allowed"
// to update typing notifications.
m.MatchResponse(t, res, m.MatchTyping(roomA, []string{alice}))
if token == bobToken {
bobRes = res
}
if token == aliceToken {
aliceRes = res
}
}

// Queue the response with Bob typing
v2.queueResponse(aliceToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@bob:localhost"]}}`)},
},
},
},
},
})

// Queue another response for Bob with Charlie typing.
// Since Alice's poller is in charge of handling typing notifications, this shouldn't
// show up on future responses.
v2.queueResponse(bobToken, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: map[string]sync2.SyncV2JoinResponse{
roomA: {
State: sync2.EventsResponse{
Events: roomState,
},
Timeline: sync2.TimelineResponse{
Events: []json.RawMessage{joinEv},
},
Ephemeral: sync2.EventsResponse{
Events: []json.RawMessage{json.RawMessage(`{"type":"m.typing","content":{"user_ids":["@charlie:localhost"]}}`)},
},
},
},
},
})

// Wait for the queued responses to be processed.
v2.waitUntilEmpty(t, aliceToken)
v2.waitUntilEmpty(t, bobToken)

// Check that only Bob is typing and not Charlie.
for _, token := range []string{aliceToken, bobToken} {
pos := aliceRes.Pos
if token == bobToken {
pos = bobRes.Pos
}

res := v3.mustDoV3RequestWithPos(t, token, pos, sync3.Request{
Extensions: extensions.Request{
Typing: &extensions.TypingRequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 1},
},
Sort: []string{sync3.SortByRecency},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 0,
},
}},
})
// We expect only Bob typing, as only Alice's poller is "allowed"
// to update typing notifications.
m.MatchResponse(t, res, m.MatchTyping(roomA, []string{bob}))
}
}
188
tests-integration/ignored_users_test.go
Normal file
@ -0,0 +1,188 @@
package syncv3

import (
"encoding/json"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
"testing"
"time"
)

// Test that messages from ignored users are not sent to clients, even if they appear
// on someone else's poller first.
func TestIgnoredUsersDuringLiveUpdate(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

const nigel = "@nigel:localhost"
roomID := "!unimportant"

v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)

// Bob creates a room. Nigel and Alice join.
state := createRoomState(t, bob, time.Now())
state = append(state, testutils.NewStateEvent(t, "m.room.member", nigel, nigel, map[string]interface{}{
"membership": "join",
}))
aliceJoin := testutils.NewStateEvent(t, "m.room.member", alice, alice, map[string]interface{}{
"membership": "join",
})

t.Log("Alice's and Bob's pollers see Alice's join.")
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: state,
events: []json.RawMessage{aliceJoin},
}),
},
NextBatch: "alice_sync_1",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
state: state,
events: []json.RawMessage{aliceJoin},
}),
},
NextBatch: "bob_sync_1",
})

t.Log("Alice sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 10}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
},
},
})
t.Log("She should see her join.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceJoin})))

t.Log("Bob sliding syncs.")
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 10}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
},
},
})
t.Log("He should see Alice's join.")
m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceJoin})))

t.Log("Alice ignores Nigel.")
v2.queueResponse(alice, sync2.SyncResponse{
AccountData: sync2.EventsResponse{
Events: []json.RawMessage{
testutils.NewAccountData(t, "m.ignored_user_list", map[string]any{
"ignored_users": map[string]any{
nigel: map[string]any{},
},
}),
},
},
NextBatch: "alice_sync_2",
})
v2.waitUntilEmpty(t, alice)

t.Log("Bob's poller sees a message from Nigel, then a message from Alice.")
nigelMsg := testutils.NewMessageEvent(t, nigel, "naughty nigel")
aliceMsg := testutils.NewMessageEvent(t, alice, "angelic alice")
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelMsg, aliceMsg},
}),
},
NextBatch: "bob_sync_2",
})
v2.waitUntilEmpty(t, bob)

t.Log("Bob syncs. He should see both messages.")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{nigelMsg, aliceMsg})))

t.Log("Alice syncs. She should see her message, but not Nigel's.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceMsg})))

t.Log("Bob's poller sees Nigel set a custom state event")
nigelState := testutils.NewStateEvent(t, "com.example.fruit", "banana", nigel, map[string]any{})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelState},
}),
},
NextBatch: "bob_sync_3",
})
v2.waitUntilEmpty(t, bob)

t.Log("Alice syncs. She should see Nigel's state event.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{nigelState})))

t.Log("Bob's poller sees Alice send a message.")
aliceMsg2 := testutils.NewMessageEvent(t, alice, "angelic alice 2")

v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{aliceMsg2},
}),
},
NextBatch: "bob_sync_4",
})
v2.waitUntilEmpty(t, bob)

t.Log("Alice syncs, making a new conn with a direct room subscription.")
aliceRes = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{},
RoomSubscriptions: map[string]sync3.RoomSubscription{
roomID: {
TimelineLimit: 20,
},
},
})
t.Log("Alice sees her join, her messages and Nigel's state in the timeline.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{
aliceJoin, aliceMsg, nigelState, aliceMsg2,
})))

t.Log("Bob's poller sees Nigel and Alice send a message.")
nigelMsg2 := testutils.NewMessageEvent(t, nigel, "naughty nigel 3")
aliceMsg3 := testutils.NewMessageEvent(t, alice, "angelic alice 3")

v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{nigelMsg2, aliceMsg3},
}),
},
NextBatch: "bob_sync_5",
})
v2.waitUntilEmpty(t, bob)

t.Log("Alice syncs. She should only see her message.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, m.MatchRoomTimeline([]json.RawMessage{aliceMsg3})))

}
230
tests-integration/regressions_test.go
Normal file
@ -0,0 +1,230 @@
package syncv3

import (
"encoding/json"
"testing"
"time"

"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/matrix-org/sliding-sync/testutils/m"
)

// catch-all file for any kind of regression test which doesn't fall into a unique category

// Regression test for https://github.com/matrix-org/sliding-sync/issues/192
// - Bob on his server invites Alice to a room.
// - Alice joins the room first over federation. Proxy does the right thing and sets her membership to join. There is no timeline though due to not having backfilled.
// - Alice's client backfills in the room which pulls in the invite event, but the SS proxy doesn't see it as it's backfill, not /sync.
// - Charlie joins the same room via SS, which makes the SS proxy see 50 timeline events, which includes the invite.
// As the proxy has never seen this invite event before, it assumes it is newer than the join event and inserts it, corrupting state.
//
// Manually confirmed this can happen with 3x Element clients. We need to make sure we drop those earlier events.
// The first join over federation presents itself as a single join event in the timeline, with the create event, etc in state.
func TestBackfillInviteDoesntCorruptState(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

fedBob := "@bob:over_federation"
charlie := "@charlie:localhost"
charlieToken := "CHARLIE_TOKEN"
joinEvent := testutils.NewJoinEvent(t, alice)

room := roomEvents{
roomID: "!TestBackfillInviteDoesntCorruptState:localhost",
events: []json.RawMessage{
joinEvent,
},
state: createRoomState(t, fedBob, time.Now()),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})

// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 5,
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))))

// Alice's client "backfills" new data in, meaning the next user who joins is going to see a different set of timeline events
dummyMsg := testutils.NewMessageEvent(t, fedBob, "you didn't see this before joining")
charlieJoinEvent := testutils.NewJoinEvent(t, charlie)
backfilledTimelineEvents := append(
room.state, []json.RawMessage{
dummyMsg,
testutils.NewStateEvent(t, "m.room.member", alice, fedBob, map[string]interface{}{
"membership": "invite",
}),
joinEvent,
charlieJoinEvent,
}...,
)

// now charlie also joins the room, causing a different response from /sync v2
v2.addAccount(t, charlie, charlieToken)
v2.queueResponse(charlie, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: room.roomID,
events: backfilledTimelineEvents,
}),
},
})

// and now charlie hits SS, which might corrupt membership state for alice.
charlieRes := v3.mustDoV3Request(t, charlieToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
},
},
})
m.MatchResponse(t, charlieRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))))

// alice should not see dummyMsg or the invite
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
m.MatchResponse(t, aliceRes, m.MatchNoV3Ops(), m.LogResponse(t), m.MatchRoomSubscriptionsStrict(
map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(3), // alice, bob, charlie
m.MatchNoInviteCount(),
m.MatchNumLive(1),
m.MatchRoomTimeline([]json.RawMessage{charlieJoinEvent}),
},
},
))
}

func TestMalformedEventsTimeline(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

// unusual events ARE VALID EVENTS and should be sent to the client, but are unusual for some reason.
unusualEvents := []json.RawMessage{
testutils.NewStateEvent(t, "", "", alice, map[string]interface{}{
"empty string": "for event type",
}),
}
// malformed events are INVALID and should be ignored by the proxy.
malformedEvents := []json.RawMessage{
json.RawMessage(`{}`), // empty object
json.RawMessage(`{"type":5}`), // type is an integer
json.RawMessage(`{"type":"foo","content":{},"event_id":""}`), // 0-length string as event ID
json.RawMessage(`{"type":"foo","content":{}}`), // missing event ID
}

room := roomEvents{
roomID: "!TestMalformedEventsTimeline:localhost",
// append malformed after unusual. All malformed events should be dropped,
// leaving only unusualEvents.
events: append(unusualEvents, malformedEvents...),
state: createRoomState(t, alice, time.Now()),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})

// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: int64(len(unusualEvents)),
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))),
m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(1),
m.MatchRoomTimeline(unusualEvents),
},
}))
}

func TestMalformedEventsState(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString)
defer v2.close()
defer v3.close()

// unusual events ARE VALID EVENTS and should be sent to the client, but are unusual for some reason.
unusualEvents := []json.RawMessage{
testutils.NewStateEvent(t, "", "", alice, map[string]interface{}{
"empty string": "for event type",
}),
}
// malformed events are INVALID and should be ignored by the proxy.
malformedEvents := []json.RawMessage{
json.RawMessage(`{}`), // empty object
json.RawMessage(`{"type":5,"content":{},"event_id":"f","state_key":""}`), // type is an integer
json.RawMessage(`{"type":"foo","content":{},"event_id":"","state_key":""}`), // 0-length string as event ID
json.RawMessage(`{"type":"foo","content":{},"state_key":""}`), // missing event ID
}

latestEvent := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"})

room := roomEvents{
roomID: "!TestMalformedEventsState:localhost",
events: []json.RawMessage{latestEvent},
// append malformed after unusual. All malformed events should be dropped,
// leaving only unusualEvents.
state: append(createRoomState(t, alice, time.Now()), append(unusualEvents, malformedEvents...)...),
}
v2.addAccount(t, alice, aliceToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
})

// alice syncs and should see the room.
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
Ranges: sync3.SliceRanges{{0, 20}},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: int64(len(unusualEvents)),
RequiredState: [][2]string{{"", ""}},
},
},
},
})
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID}))),
m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
m.MatchJoinCount(1),
m.MatchRoomTimeline([]json.RawMessage{latestEvent}),
m.MatchRoomRequiredState([]json.RawMessage{
unusualEvents[0],
}),
},
}))
}
@ -137,12 +137,9 @@ func TestRoomSubscriptionMisorderedTimeline(t *testing.T) {
})
m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
room.roomID: {
// TODO: this is the correct result, but due to how timeline loading works currently
// it will be returning the last 5 events BEFORE D,E, which isn't ideal but also isn't
// incorrect per se due to the fact that clients don't know when D,E have been processed
// on the server.
// m.MatchRoomTimeline(append(abcInitialEvents, deLiveEvents...)),
m.MatchRoomTimeline(append(roomState[len(roomState)-2:], abcInitialEvents...)),
// we append live events AFTER processing the new timeline limit, so 7 events not 5.
// TODO: ideally we'd just return abcde here.
m.MatchRoomTimeline(append(roomState[len(roomState)-2:], append(abcInitialEvents, deLiveEvents...)...)),
},
}), m.LogResponse(t))

@ -3,6 +3,7 @@ package syncv3
import (
"encoding/json"
"fmt"
slidingsync "github.com/matrix-org/sliding-sync"
"testing"
"time"

@ -689,6 +690,247 @@ func TestTimelineTxnID(t *testing.T) {
))
}

// TestTimelineTxnID checks that Alice sees her transaction_id if
// - Bob's poller sees Alice's event,
// - Alice's poller sees Alice's event with txn_id, and
// - Alice syncs.
//
// This test is similar but not identical. It checks that Alice sees her transaction_id if
// - Bob's poller sees Alice's event,
// - Alice does an incremental sync, which should omit her event,
// - Alice's poller sees Alice's event with txn_id, and
// - Alice syncs, seeing her event with txn_id.
func TestTimelineTxnIDBuffersForTxnID(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
// This needs to be greater than the request timeout, which is hardcoded to a
// minimum of 100ms in connStateLive.liveUpdate. This ensures that the
// liveUpdate call finishes before the TxnIDWaiter publishes the update,
// meaning that Alice doesn't see her event before the txn ID is known.
MaxTransactionIDDelay: 200 * time.Millisecond,
})
defer v2.close()
defer v3.close()
roomID := "!a:localhost"
latestTimestamp := time.Now()
t.Log("Alice and Bob are in the same room")
room := roomEvents{
roomID: roomID,
events: append(
createRoomState(t, alice, latestTimestamp),
testutils.NewJoinEvent(t, bob),
),
}
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "alice_after_initial_poll",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "bob_after_initial_poll",
})

t.Log("Alice and Bob make initial sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})

t.Log("Alice has sent a message... but it arrives down Bob's poller first, without a transaction_id")
txnID := "m1234567890"
newEvent := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"}, testutils.WithUnsigned(map[string]interface{}{
"transaction_id": txnID,
}))
newEventNoUnsigned, err := sjson.DeleteBytes(newEvent, "unsigned")
if err != nil {
t.Fatalf("failed to delete bytes: %s", err)
}

v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoUnsigned},
}),
},
})
t.Log("Bob's poller sees the message.")
v2.waitUntilEmpty(t, bob)

t.Log("Bob makes an incremental sliding sync")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
t.Log("Bob should see the message without a transaction_id")
m.MatchResponse(t, bobRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoUnsigned}),
))

t.Log("Alice requests an incremental sliding sync with no request changes.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice should see no messages.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscriptionsStrict(nil))

// Now the message arrives down Alice's poller.
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEvent},
}),
},
})
t.Log("Alice's poller sees the message with transaction_id.")
v2.waitUntilEmpty(t, alice)

t.Log("Alice makes another incremental sync request.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice's sync response includes the message with the txn ID.")
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEvent}),
))

}

// Similar to TestTimelineTxnIDBuffersForTxnID, this test checks:
// - Bob's poller sees Alice's event,
// - Alice does an incremental sync, which should omit her event,
// - Alice's poller sees Alice's event without a txn_id, and
// - Alice syncs, seeing her event without txn_id.
// I.e. we're checking that the "all clear" empties out the buffer of events.
func TestTimelineTxnIDRespectsAllClear(t *testing.T) {
pqString := testutils.PrepareDBConnectionString()
// setup code
v2 := runTestV2Server(t)
v3 := runTestServer(t, v2, pqString, slidingsync.Opts{
// This needs to be greater than the request timeout, which is hardcoded to a
// minimum of 100ms in connStateLive.liveUpdate. This ensures that the
// liveUpdate call finishes before the TxnIDWaiter publishes the update,
// meaning that Alice doesn't see her event before the txn ID is known.
MaxTransactionIDDelay: 200 * time.Millisecond,
})
defer v2.close()
defer v3.close()
roomID := "!a:localhost"
latestTimestamp := time.Now()
t.Log("Alice and Bob are in the same room")
room := roomEvents{
roomID: roomID,
events: append(
createRoomState(t, alice, latestTimestamp),
testutils.NewJoinEvent(t, bob),
),
}
v2.addAccount(t, alice, aliceToken)
v2.addAccount(t, bob, bobToken)
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "alice_after_initial_poll",
})
v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(room),
},
NextBatch: "bob_after_initial_poll",
})

t.Log("Alice and Bob make initial sliding syncs.")
aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})
bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10},
},
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 2,
},
},
},
})

t.Log("Alice has sent a message... but it arrives down Bob's poller first, without a transaction_id")
newEventNoTxn := testutils.NewEvent(t, "m.room.message", alice, map[string]interface{}{"body": "hi"})

v2.queueResponse(bob, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoTxn},
}),
},
})
t.Log("Bob's poller sees the message.")
v2.waitUntilEmpty(t, bob)

t.Log("Bob makes an incremental sliding sync")
bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{})
t.Log("Bob should see the message without a transaction_id")
m.MatchResponse(t, bobRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoTxn}),
))

t.Log("Alice requests an incremental sliding sync with no request changes.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice should see no messages.")
m.MatchResponse(t, aliceRes, m.MatchRoomSubscriptionsStrict(nil))

// Now the message arrives down Alice's poller.
v2.queueResponse(alice, sync2.SyncResponse{
Rooms: sync2.SyncRoomsResponse{
Join: v2JoinTimeline(roomEvents{
roomID: roomID,
events: []json.RawMessage{newEventNoTxn},
}),
},
})
t.Log("Alice's poller sees the message without transaction_id.")
v2.waitUntilEmpty(t, alice)

t.Log("Alice makes another incremental sync request.")
aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{})
t.Log("Alice's sync response includes the event without a txn ID.")
m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1)), m.MatchNoV3Ops(), m.MatchRoomSubscription(
roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{newEventNoTxn}),
))

}
// Executes a sync v3 request without a ?pos and asserts that the count, rooms and timeline events m.Match the inputs given.
func testTimelineLoadInitialEvents(v3 *testV3Server, token string, count int, wantRooms []roomEvents, numTimelineEventsPerRoom int) func(t *testing.T) {
return func(t *testing.T) {

@ -3,11 +3,12 @@ package syncv3
import (
"context"
"encoding/json"
"github.com/tidwall/gjson"
"net/http"
"testing"
"time"

"github.com/tidwall/gjson"

"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/testutils"

@ -291,11 +291,11 @@ func (s *testV3Server) close() {
s.h2.Teardown()
}

func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string) {
func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string, opts ...syncv3.Opts) {
t.Helper()
log.Printf("restarting server")
s.close()
ss := runTestServer(t, v2, pq)
ss := runTestServer(t, v2, pq, opts...)
// replace all the fields which will be close()d to ensure we don't leak
s.srv = ss.srv
s.h2 = ss.h2

@ -366,20 +366,24 @@ func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postg
//tests often repeat requests. To ensure tests remain fast, reduce the spam protection limits.
sync3.SpamProtectionInterval = time.Millisecond

metricsEnabled := false
maxPendingEventUpdates := 200
combinedOpts := syncv3.Opts{
TestingSynchronousPubsub: true, // critical to avoid flakey tests
AddPrometheusMetrics: false,
MaxPendingEventUpdates: 200,
MaxTransactionIDDelay: 0, // disable the txnID buffering to avoid flakey tests
}
if len(opts) > 0 {
metricsEnabled = opts[0].AddPrometheusMetrics
if opts[0].MaxPendingEventUpdates > 0 {
maxPendingEventUpdates = opts[0].MaxPendingEventUpdates
opt := opts[0]
combinedOpts.AddPrometheusMetrics = opt.AddPrometheusMetrics
combinedOpts.DBConnMaxIdleTime = opt.DBConnMaxIdleTime
combinedOpts.DBMaxConns = opt.DBMaxConns
combinedOpts.MaxTransactionIDDelay = opt.MaxTransactionIDDelay
if opt.MaxPendingEventUpdates > 0 {
combinedOpts.MaxPendingEventUpdates = opt.MaxPendingEventUpdates
handler.BufferWaitTime = 5 * time.Millisecond
}
}
h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{
TestingSynchronousPubsub: true, // critical to avoid flakey tests
MaxPendingEventUpdates: maxPendingEventUpdates,
AddPrometheusMetrics: metricsEnabled,
})
h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), combinedOpts)
// for ease of use we don't start v2 pollers at startup in tests
r := mux.NewRouter()
r.Use(hlog.NewHandler(logger))

@ -39,6 +39,39 @@ func MatchRoomName(name string) RoomMatcher {
}
}

// MatchRoomAvatar builds a RoomMatcher which checks that the given room response has
// set the room's avatar to the given value.
func MatchRoomAvatar(wantAvatar string) RoomMatcher {
return func(r sync3.Room) error {
if string(r.AvatarChange) != wantAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, wantAvatar)
}
return nil
}
}

// MatchRoomUnsetAvatar builds a RoomMatcher which checks that the given room has no
// avatar, or has had its avatar deleted.
func MatchRoomUnsetAvatar() RoomMatcher {
return func(r sync3.Room) error {
if r.AvatarChange != sync3.DeletedAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.DeletedAvatar)
}
return nil
}
}

// MatchRoomUnchangedAvatar builds a RoomMatcher which checks that the given room has no
// change to its avatar, or has had its avatar deleted.
func MatchRoomUnchangedAvatar() RoomMatcher {
return func(r sync3.Room) error {
if r.AvatarChange != sync3.UnchangedAvatar {
return fmt.Errorf("MatchRoomAvatar: got \"%s\" want \"%s\"", r.AvatarChange, sync3.UnchangedAvatar)
}
return nil
}
}

func MatchJoinCount(count int) RoomMatcher {
return func(r sync3.Room) error {
if r.JoinedCount != count {

@ -48,10 +81,22 @@ func MatchJoinCount(count int) RoomMatcher {
}
}

func MatchNoInviteCount() RoomMatcher {
return func(r sync3.Room) error {
if r.InvitedCount != nil {
return fmt.Errorf("MatchInviteCount: invited_count is present when it should be missing: val=%v", *r.InvitedCount)
}
return nil
}
}

func MatchInviteCount(count int) RoomMatcher {
return func(r sync3.Room) error {
if r.InvitedCount != count {
return fmt.Errorf("MatchInviteCount: got %v want %v", r.InvitedCount, count)
if r.InvitedCount == nil {
return fmt.Errorf("MatchInviteCount: invited_count is missing")
}
if *r.InvitedCount != count {
return fmt.Errorf("MatchInviteCount: got %v want %v", *r.InvitedCount, count)
}
return nil
}
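As a usage note, here is a small self-contained sketch of how the reworked invite-count matchers behave, assuming sync3.Room exposes InvitedCount as an *int (which is what the matcher bodies above imply); the values are illustrative only.

```go
package main

import (
	"fmt"

	"github.com/matrix-org/sliding-sync/sync3"
	"github.com/matrix-org/sliding-sync/testutils/m"
)

func main() {
	two := 2
	withInvites := sync3.Room{InvitedCount: &two} // room reporting invited_count: 2
	noInvites := sync3.Room{}                     // room omitting invited_count entirely

	// MatchInviteCount now insists the field is present and equal...
	fmt.Println(m.MatchInviteCount(2)(withInvites)) // nil
	fmt.Println(m.MatchInviteCount(2)(noInvites))   // error: invited_count is missing
	// ...while MatchNoInviteCount asserts it is omitted entirely.
	fmt.Println(m.MatchNoInviteCount()(noInvites))   // nil
	fmt.Println(m.MatchNoInviteCount()(withInvites)) // error: present when it should be missing
}
```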
@ -210,11 +255,11 @@ func MatchRoomSubscription(roomID string, matchers ...RoomMatcher) RespMatcher {
return func(res *sync3.Response) error {
room, ok := res.Rooms[roomID]
if !ok {
return fmt.Errorf("MatchRoomSubscription: want sub for %s but it was missing", roomID)
return fmt.Errorf("MatchRoomSubscription[%s]: want sub but it was missing", roomID)
}
for _, m := range matchers {
if err := m(room); err != nil {
return fmt.Errorf("MatchRoomSubscription: %s", err)
return fmt.Errorf("MatchRoomSubscription[%s]: %s", roomID, err)
}
}
return nil

@ -633,6 +678,15 @@ func LogResponse(t *testing.T) RespMatcher {
}
}

// LogRooms is like LogResponse, but only logs the rooms section of the response.
func LogRooms(t *testing.T) RespMatcher {
return func(res *sync3.Response) error {
dump, _ := json.MarshalIndent(res.Rooms, "", " ")
t.Logf("Response rooms were: %s", dump)
return nil
}
}

func CheckList(listKey string, res sync3.ResponseList, matchers ...ListMatcher) error {
for _, m := range matchers {
if err := m(res); err != nil {

@ -675,13 +729,16 @@ func MatchLists(matchers map[string][]ListMatcher) RespMatcher {
}
}

const AnsiRedForeground = "\x1b[31m"
const AnsiResetForeground = "\x1b[39m"

func MatchResponse(t *testing.T, res *sync3.Response, matchers ...RespMatcher) {
t.Helper()
for _, m := range matchers {
err := m(res)
if err != nil {
b, _ := json.Marshal(res)
t.Errorf("MatchResponse: %s\n%+v", err, string(b))
b, _ := json.MarshalIndent(res, "", " ")
t.Errorf("%vMatchResponse: %s\n%s%v", AnsiRedForeground, err, string(b), AnsiResetForeground)
}
}
}

51
v3.go
@ -1,6 +1,7 @@
package slidingsync

import (
"embed"
"encoding/json"
"fmt"
"net/http"

@ -9,18 +10,24 @@ import (
"time"

"github.com/getsentry/sentry-go"

"github.com/gorilla/mux"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/pubsub"
"github.com/matrix-org/sliding-sync/state"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/sync2/handler2"
"github.com/matrix-org/sliding-sync/sync3/handler"
"github.com/pressly/goose/v3"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"

_ "github.com/matrix-org/sliding-sync/state/migrations"
)

//go:embed state/migrations/*
var EmbedMigrations embed.FS

var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",

@ -36,6 +43,13 @@ type Opts struct {
// if true, publishing messages will block until the consumer has consumed it.
// Assumes a single producer and a single consumer.
TestingSynchronousPubsub bool
// MaxTransactionIDDelay is the longest amount of time that we will wait for
// confirmation of an event's transaction_id before sending it to its sender.
// Set to 0 to disable this delay mechanism entirely.
MaxTransactionIDDelay time.Duration

DBMaxConns int
DBConnMaxIdleTime time.Duration
}

type server struct {
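For callers embedding the proxy as a library, a hedged sketch of passing the new knobs through Setup; the homeserver URL, postgres URI and secret are placeholders, and the values shown are illustrative rather than recommendations.

```go
package main

import (
	"time"

	slidingsync "github.com/matrix-org/sliding-sync"
)

func main() {
	h2, h3 := slidingsync.Setup(
		"https://synapse.example.org",                 // placeholder destination homeserver
		"user=postgres dbname=syncv3 sslmode=disable", // placeholder postgres URI
		"itsasecret",                                  // placeholder secret
		slidingsync.Opts{
			MaxTransactionIDDelay: 200 * time.Millisecond, // 0 disables the txn-ID buffering
			DBMaxConns:            10,                     // caps SetMaxOpenConns/SetMaxIdleConns
			DBConnMaxIdleTime:     5 * time.Minute,
		},
	)
	_, _ = h2, h3 // wire the handlers into your HTTP router as usual
}
```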
@ -73,11 +87,38 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
},
DestinationServer: destHomeserver,
}
store := state.NewStorage(postgresURI)
storev2 := sync2.NewStore(postgresURI, secret)
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
sentry.CaptureException(err)
// TODO: if we panic(), will sentry have a chance to flush the event?
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}

if opts.DBMaxConns > 0 {
// https://github.com/go-sql-driver/mysql#important-settings
// "db.SetMaxIdleConns() is recommended to be set same to db.SetMaxOpenConns(). When it is smaller
// than SetMaxOpenConns(), connections can be opened and closed much more frequently than you expect."
db.SetMaxOpenConns(opts.DBMaxConns)
db.SetMaxIdleConns(opts.DBMaxConns)
}
if opts.DBConnMaxIdleTime > 0 {
db.SetConnMaxIdleTime(opts.DBConnMaxIdleTime)
}
store := state.NewStorageWithDB(db)
storev2 := sync2.NewStoreWithDB(db, secret)

// Automatically execute migrations
goose.SetBaseFS(EmbedMigrations)
err = goose.Up(db.DB, "state/migrations", goose.WithAllowMissing())
if err != nil {
logger.Panic().Err(err).Msg("failed to execute migrations")
}

bufferSize := 50
deviceDataUpdateFrequency := time.Second
if opts.TestingSynchronousPubsub {
bufferSize = 0
deviceDataUpdateFrequency = 0 // don't batch
}
if opts.MaxPendingEventUpdates == 0 {
opts.MaxPendingEventUpdates = 2000

@ -86,14 +127,14 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han

pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
// create v2 handler
h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics)
h2, err := handler2.NewHandler(pMap, storev2, store, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
if err != nil {
panic(err)
}
pMap.SetCallbacks(h2)

// create v3 handler
h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates, opts.MaxTransactionIDDelay)
if err != nil {
panic(err)
}