sliding-sync/v3.go

312 lines
8.3 KiB
Go
Raw Normal View History

2021-05-14 16:49:33 +01:00
package syncv3
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"sync"
2021-05-14 16:49:33 +01:00
"time"
"github.com/gorilla/mux"
2021-06-03 16:18:01 +01:00
"github.com/matrix-org/sync-v3/state"
"github.com/matrix-org/sync-v3/sync2"
"github.com/matrix-org/sync-v3/sync3"
2021-05-14 16:49:33 +01:00
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
)
// log is the package-wide logger. Every entry gets a timestamp and is written
// in human-readable console format (times shown as HH:MM:SS) to stderr.
var log = zerolog.New(zerolog.ConsoleWriter{
	Out:        os.Stderr,
	TimeFormat: "15:04:05",
}).With().Timestamp().Logger()
2021-07-21 12:12:57 +01:00
// server is an http.Handler that runs every request through a chain of
// middleware before reaching final. chain[0] is the outermost wrapper, so it
// observes the request first.
type server struct {
	chain []func(next http.Handler) http.Handler
	final http.Handler
}

// ServeHTTP wraps final in the middleware chain (innermost = last element)
// and dispatches the request. The chain is composed afresh on every call.
func (s *server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	handler := s.final
	for idx := len(s.chain) - 1; idx >= 0; idx-- {
		handler = s.chain[idx](handler)
	}
	handler.ServeHTTP(w, req)
}
2021-05-14 16:49:33 +01:00
// RunSyncV3Server is the main entry point to the server
func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
// dependency inject all components together
sh := &SyncV3Handler{
V2: &sync2.HTTPClient{
2021-05-14 16:49:33 +01:00
Client: &http.Client{
2021-06-11 17:20:32 +01:00
Timeout: 5 * time.Minute,
2021-05-14 16:49:33 +01:00
},
DestinationServer: destinationServer,
},
Sessions: sync3.NewSessions(postgresDBURI),
Storage: state.NewStorage(postgresDBURI),
Pollers: make(map[string]*sync2.Poller),
pollerMu: &sync.Mutex{},
2021-05-14 16:49:33 +01:00
}
// HTTP path routing
r := mux.NewRouter()
r.Handle("/_matrix/client/v3/sync", sh)
2021-07-21 12:12:57 +01:00
srv := &server{
chain: []func(next http.Handler) http.Handler{
hlog.NewHandler(log),
hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
hlog.FromRequest(r).Info().
Str("method", r.Method).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Str("since", r.URL.Query().Get("since")).
Msg("")
}),
hlog.RemoteAddrHandler("ip"),
},
final: r,
}
2021-05-14 16:49:33 +01:00
// Block forever
log.Info().Msgf("listening on %s", bindAddr)
2021-07-21 12:12:57 +01:00
if err := http.ListenAndServe(bindAddr, srv); err != nil {
2021-05-14 16:49:33 +01:00
log.Fatal().Err(err).Msg("failed to listen and serve")
}
}
// handlerError is returned by the request-handling helpers. It carries the
// HTTP status code to send to the client together with the underlying cause.
type handlerError struct {
	StatusCode int
	err        error
}

// Error formats the error as "HTTP <status> : <cause>". A nil cause is
// tolerated and renders as "HTTP <status>" instead of panicking.
func (e *handlerError) Error() string {
	if e.err == nil {
		return fmt.Sprintf("HTTP %d", e.StatusCode)
	}
	return fmt.Sprintf("HTTP %d : %s", e.StatusCode, e.err.Error())
}

// Unwrap returns the underlying cause so errors.Is / errors.As can inspect it.
func (e *handlerError) Unwrap() error {
	return e.err
}
2021-05-14 16:49:33 +01:00
type SyncV3Handler struct {
V2 sync2.Client
Sessions *sync3.Sessions
Storage *state.Storage
pollerMu *sync.Mutex
Pollers map[string]*sync2.Poller // device_id -> poller
2021-05-14 16:49:33 +01:00
}
func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
err := h.serve(w, req)
2021-05-14 16:49:33 +01:00
if err != nil {
w.WriteHeader(err.StatusCode)
2021-05-14 16:49:33 +01:00
w.Write(asJSONError(err))
}
}
func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handlerError {
session, tokv3, herr := h.getOrCreateSession(req)
if herr != nil {
return herr
}
log.Info().Int64("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")
// make sure we have a poller for this device
h.ensurePolling(req.Header.Get("Authorization"), session)
// fetch the latest value which we'll base our response on
latestNID, err := h.Storage.LatestEventNID()
if err != nil {
return &handlerError{
err: err,
StatusCode: 500,
}
}
upcoming := sync3.Token{
SessionID: session.ID,
NID: latestNID,
}
2021-07-21 15:03:09 +01:00
/*
var from int64
if tokv3 != nil {
from = tokv3.NID
} */
2021-07-21 15:03:09 +01:00
// read filters and mux in to form complete request
_, filterID, herr := h.parseRequest(req, tokv3, session)
if herr != nil {
return herr
}
2021-07-21 15:03:09 +01:00
// if there was a change to the filters, update the filter ID
if filterID != 0 {
upcoming.FilterID = filterID
}
2021-07-21 15:03:09 +01:00
// TODO: invoke streams to get responses
/*
f := false
filter := &streams.FilterRoomList{
EntriesPerBatch: 5,
RoomNameSize: 70,
IncludeRoomAvatarMXC: &f,
SummaryEventTypes: []string{"m.room.message", "m.room.member"},
}
stream := streams.NewRoomList(h.Storage)
_, _, err = stream.Process(session.DeviceID, from, latestNID, "", filter)
if err != nil {
return &handlerError{
err: err,
StatusCode: 500,
}
} */
// finally update our records: confirm that the client received the token they sent us, and mark this
// response as unconfirmed
var confirmed string
if tokv3 != nil {
confirmed = tokv3.String()
}
if err := h.Sessions.UpdateLastTokens(session.ID, confirmed, upcoming.String()); err != nil {
return &handlerError{
err: err,
StatusCode: 500,
}
}
w.WriteHeader(200)
2021-07-21 15:03:09 +01:00
resp := sync3.Response{
Next: upcoming.String(),
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
log.Warn().Err(err).Msg("failed to marshal response")
}
return nil
}
// getOrCreateSession retrieves an existing session if ?since= is set, else makes a new session.
// Returns a session or an error. Returns a token if and only if there is an existing session.
func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *sync3.Token, *handlerError) {
var session *sync3.Session
var tokv3 *sync3.Token
deviceID, err := deviceIDFromRequest(req)
if err != nil {
log.Warn().Err(err).Msg("failed to get device ID from request")
return nil, nil, &handlerError{400, err}
}
2021-05-14 16:49:33 +01:00
sincev3 := req.URL.Query().Get("since")
if sincev3 == "" {
session, err = h.Sessions.NewSession(deviceID)
} else {
tokv3, err = sync3.NewSyncToken(sincev3)
2021-05-14 16:49:33 +01:00
if err != nil {
log.Warn().Err(err).Msg("failed to parse sync v3 token")
return nil, nil, &handlerError{400, err}
2021-05-14 16:49:33 +01:00
}
session, err = h.Sessions.Session(tokv3.SessionID, deviceID)
2021-05-14 16:49:33 +01:00
}
if err != nil {
log.Warn().Err(err).Str("device", deviceID).Msg("failed to ensure Session existed for device")
return nil, nil, &handlerError{500, err}
2021-05-14 16:49:33 +01:00
}
if session.UserID == "" {
// we need to work out the user ID to do membership queries
userID, err := h.userIDFromRequest(req)
if err != nil {
2021-07-21 12:12:57 +01:00
log.Warn().Err(err).Msg("failed to work out user ID from request, is the authorization header valid?")
return nil, nil, &handlerError{400, err}
}
session.UserID = userID
h.Sessions.UpdateUserIDForDevice(deviceID, userID)
}
return session, tokv3, nil
2021-05-14 16:49:33 +01:00
}
// ensurePolling makes sure there is a poller for this device, making one if need be.
// Blocks until at least 1 sync is done if and only if the poller was just created.
// This ensures that calls to the database will return data.
func (h *SyncV3Handler) ensurePolling(authHeader string, session *sync3.Session) {
h.pollerMu.Lock()
poller, ok := h.Pollers[session.DeviceID]
// either no poller exists or it did but it died
if ok && !poller.Terminated {
h.pollerMu.Unlock()
return
2021-06-03 16:18:01 +01:00
}
// replace the poller
poller = sync2.NewPoller(authHeader, session.DeviceID, h.V2, h.Storage, h.Sessions)
var wg sync.WaitGroup
wg.Add(1)
2021-06-11 17:20:32 +01:00
go poller.Poll(session.V2Since, func() {
wg.Done()
})
h.Pollers[session.DeviceID] = poller
h.pollerMu.Unlock()
wg.Wait()
2021-06-03 16:18:01 +01:00
}
2021-07-21 15:03:09 +01:00
func (h *SyncV3Handler) parseRequest(req *http.Request, tok *sync3.Token, session *sync3.Session) (*sync3.Request, int64, *handlerError) {
existing := &sync3.Request{} // first request
var err error
if tok != nil && tok.FilterID != 0 {
2021-07-21 15:15:50 +01:00
// load existing filter
existing, err = h.Sessions.Filter(tok.SessionID, tok.FilterID)
2021-07-21 15:03:09 +01:00
if err != nil {
return nil, 0, &handlerError{
StatusCode: 400,
err: fmt.Errorf("failed to load filters from sync token: %s", err),
}
}
}
// load new delta from request
defer req.Body.Close()
var delta sync3.Request
if err := json.NewDecoder(req.Body).Decode(&delta); err != nil {
return nil, 0, &handlerError{
StatusCode: 400,
err: fmt.Errorf("failed to parse request body as JSON: %s", err),
}
}
var filterID int64
if existing.ApplyDeltas(&delta) {
// persist new filters if there were deltas
2021-07-21 15:15:50 +01:00
filterID, err = h.Sessions.InsertFilter(session.ID, existing)
2021-07-21 15:03:09 +01:00
if err != nil {
return nil, 0, &handlerError{
StatusCode: 500,
err: fmt.Errorf("failed to persist filters: %s", err),
}
}
}
return existing, filterID, nil
}
// userIDFromRequest resolves the user that owns the request's access token by
// passing the raw Authorization header to the v2 client's WhoAmI.
func (h *SyncV3Handler) userIDFromRequest(req *http.Request) (string, error) {
	authHeader := req.Header.Get("Authorization")
	return h.V2.WhoAmI(authHeader)
}
2021-05-14 16:49:33 +01:00
// jsonError is the body shape of error responses: {"error": "<message>"}.
type jsonError struct {
	Err string `json:"error"`
}

// asJSONError renders err's message as a JSON-encoded jsonError.
func asJSONError(err error) []byte {
	payload := jsonError{Err: err.Error()}
	// Marshalling a struct whose only field is a string cannot fail (invalid
	// UTF-8 is replaced, not rejected), so the error is deliberately ignored.
	encoded, _ := json.Marshal(payload)
	return encoded
}
// deviceIDFromRequest derives an opaque device identifier from the request's
// Authorization header. See deviceIDFromAuthHeader for the exact rules.
func deviceIDFromRequest(req *http.Request) (string, error) {
	return deviceIDFromAuthHeader(req.Header.Get("Authorization"))
}

// deviceIDFromAuthHeader returns the hex-encoded SHA-256 of the access token
// carried in authHeader.
//
// A leading "Bearer " is stripped first. The match is case-sensitive, so
// previously issued IDs are unchanged. Any other value is hashed verbatim.
//
// It errors when the header is empty, or when no token remains after
// stripping. Without that check every such client would hash the empty string
// and collide on the same device ID.
func deviceIDFromAuthHeader(authHeader string) (string, error) {
	if authHeader == "" {
		return "", fmt.Errorf("missing Authorization header")
	}
	accessToken := strings.TrimPrefix(authHeader, "Bearer ")
	if strings.TrimSpace(accessToken) == "" {
		return "", fmt.Errorf("missing access token in Authorization header")
	}
	sum := sha256.Sum256([]byte(accessToken))
	return hex.EncodeToString(sum[:]), nil
}