2021-05-14 16:49:33 +01:00
|
|
|
package syncv3
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/sha256"
|
|
|
|
"encoding/hex"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"strings"
|
2021-06-04 13:02:28 +01:00
|
|
|
"sync"
|
2021-05-14 16:49:33 +01:00
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/gorilla/mux"
|
2021-06-03 16:18:01 +01:00
|
|
|
"github.com/matrix-org/sync-v3/state"
|
2021-06-09 17:27:54 +01:00
|
|
|
"github.com/matrix-org/sync-v3/sync2"
|
|
|
|
"github.com/matrix-org/sync-v3/sync3"
|
2021-05-14 16:49:33 +01:00
|
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/rs/zerolog/hlog"
|
|
|
|
)
|
|
|
|
|
|
|
|
var log = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
|
|
|
|
Out: os.Stderr,
|
|
|
|
TimeFormat: "15:04:05",
|
|
|
|
})
|
|
|
|
|
2021-07-21 12:12:57 +01:00
|
|
|
// server is the root http.Handler. Requests pass through every middleware in
// chain before reaching final; chain[0] is the outermost layer (it sees the
// request first) and final is the innermost.
type server struct {
	chain []func(next http.Handler) http.Handler
	final http.Handler
}

// ServeHTTP wraps s.final in the middleware chain, innermost first, and then
// dispatches req to the outermost layer. The chain is rebuilt on every call.
func (s *server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	var handler http.Handler = s.final
	for idx := len(s.chain) - 1; idx >= 0; idx-- {
		wrap := s.chain[idx]
		handler = wrap(handler)
	}
	handler.ServeHTTP(w, req)
}
|
|
|
|
|
2021-05-14 16:49:33 +01:00
|
|
|
// RunSyncV3Server is the main entry point to the server
|
|
|
|
func RunSyncV3Server(destinationServer, bindAddr, postgresDBURI string) {
|
|
|
|
// dependency inject all components together
|
|
|
|
sh := &SyncV3Handler{
|
2021-07-21 16:35:36 +01:00
|
|
|
V2: &sync2.HTTPClient{
|
2021-05-14 16:49:33 +01:00
|
|
|
Client: &http.Client{
|
2021-06-11 17:20:32 +01:00
|
|
|
Timeout: 5 * time.Minute,
|
2021-05-14 16:49:33 +01:00
|
|
|
},
|
|
|
|
DestinationServer: destinationServer,
|
|
|
|
},
|
2021-06-16 17:18:04 +01:00
|
|
|
Sessions: sync3.NewSessions(postgresDBURI),
|
|
|
|
Storage: state.NewStorage(postgresDBURI),
|
|
|
|
Pollers: make(map[string]*sync2.Poller),
|
|
|
|
pollerMu: &sync.Mutex{},
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// HTTP path routing
|
|
|
|
r := mux.NewRouter()
|
|
|
|
r.Handle("/_matrix/client/v3/sync", sh)
|
2021-07-21 12:12:57 +01:00
|
|
|
|
|
|
|
srv := &server{
|
|
|
|
chain: []func(next http.Handler) http.Handler{
|
|
|
|
hlog.NewHandler(log),
|
|
|
|
hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
|
|
|
hlog.FromRequest(r).Info().
|
|
|
|
Str("method", r.Method).
|
|
|
|
Int("status", status).
|
|
|
|
Int("size", size).
|
|
|
|
Dur("duration", duration).
|
|
|
|
Str("since", r.URL.Query().Get("since")).
|
|
|
|
Msg("")
|
|
|
|
}),
|
|
|
|
hlog.RemoteAddrHandler("ip"),
|
|
|
|
},
|
|
|
|
final: r,
|
|
|
|
}
|
2021-05-14 16:49:33 +01:00
|
|
|
|
|
|
|
// Block forever
|
|
|
|
log.Info().Msgf("listening on %s", bindAddr)
|
2021-07-21 12:12:57 +01:00
|
|
|
if err := http.ListenAndServe(bindAddr, srv); err != nil {
|
2021-05-14 16:49:33 +01:00
|
|
|
log.Fatal().Err(err).Msg("failed to listen and serve")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-09 18:13:55 +01:00
|
|
|
// handlerError is an error paired with the HTTP status code that should be
// sent to the client for it.
type handlerError struct {
	StatusCode int
	err        error
}

// Error formats the error as "HTTP <status> : <cause>". If no underlying
// cause was set it returns "HTTP <status>" rather than dereferencing a nil
// error.
func (e *handlerError) Error() string {
	if e.err == nil {
		return fmt.Sprintf("HTTP %d", e.StatusCode)
	}
	return fmt.Sprintf("HTTP %d : %s", e.StatusCode, e.err.Error())
}

// Unwrap returns the underlying cause so errors.Is / errors.As can inspect it.
func (e *handlerError) Unwrap() error {
	return e.err
}
|
|
|
|
|
2021-05-14 16:49:33 +01:00
|
|
|
type SyncV3Handler struct {
|
2021-07-21 16:35:36 +01:00
|
|
|
V2 sync2.Client
|
2021-06-16 17:18:04 +01:00
|
|
|
Sessions *sync3.Sessions
|
|
|
|
Storage *state.Storage
|
2021-06-04 13:02:28 +01:00
|
|
|
|
|
|
|
pollerMu *sync.Mutex
|
2021-06-09 17:27:54 +01:00
|
|
|
Pollers map[string]*sync2.Poller // device_id -> poller
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (h *SyncV3Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
2021-06-09 18:13:55 +01:00
|
|
|
err := h.serve(w, req)
|
2021-05-14 16:49:33 +01:00
|
|
|
if err != nil {
|
2021-06-09 18:13:55 +01:00
|
|
|
w.WriteHeader(err.StatusCode)
|
2021-05-14 16:49:33 +01:00
|
|
|
w.Write(asJSONError(err))
|
|
|
|
}
|
2021-06-09 18:13:55 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handlerError {
|
2021-06-16 18:56:31 +01:00
|
|
|
session, tokv3, herr := h.getOrCreateSession(req)
|
|
|
|
if herr != nil {
|
|
|
|
return herr
|
2021-06-09 18:13:55 +01:00
|
|
|
}
|
2021-06-16 18:11:40 +01:00
|
|
|
log.Info().Int64("session", session.ID).Str("device", session.DeviceID).Msg("recv /v3/sync")
|
2021-06-09 18:13:55 +01:00
|
|
|
|
|
|
|
// make sure we have a poller for this device
|
|
|
|
h.ensurePolling(req.Header.Get("Authorization"), session)
|
|
|
|
|
2021-06-16 18:56:31 +01:00
|
|
|
// fetch the latest value which we'll base our response on
|
|
|
|
latestNID, err := h.Storage.LatestEventNID()
|
|
|
|
if err != nil {
|
|
|
|
return &handlerError{
|
|
|
|
err: err,
|
|
|
|
StatusCode: 500,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
upcoming := sync3.Token{
|
|
|
|
SessionID: session.ID,
|
|
|
|
NID: latestNID,
|
|
|
|
}
|
2021-07-21 15:03:09 +01:00
|
|
|
/*
|
|
|
|
var from int64
|
|
|
|
if tokv3 != nil {
|
|
|
|
from = tokv3.NID
|
|
|
|
} */
|
2021-06-16 18:56:31 +01:00
|
|
|
|
2021-07-21 15:03:09 +01:00
|
|
|
// read filters and mux in to form complete request
|
|
|
|
_, filterID, herr := h.parseRequest(req, tokv3, session)
|
|
|
|
if herr != nil {
|
|
|
|
return herr
|
2021-06-16 18:56:31 +01:00
|
|
|
}
|
2021-07-21 15:03:09 +01:00
|
|
|
// if there was a change to the filters, update the filter ID
|
|
|
|
if filterID != 0 {
|
|
|
|
upcoming.FilterID = filterID
|
2021-06-16 18:56:31 +01:00
|
|
|
}
|
|
|
|
|
2021-07-21 15:03:09 +01:00
|
|
|
// TODO: invoke streams to get responses
|
|
|
|
/*
|
|
|
|
f := false
|
|
|
|
filter := &streams.FilterRoomList{
|
|
|
|
EntriesPerBatch: 5,
|
|
|
|
RoomNameSize: 70,
|
|
|
|
IncludeRoomAvatarMXC: &f,
|
|
|
|
SummaryEventTypes: []string{"m.room.message", "m.room.member"},
|
|
|
|
}
|
|
|
|
stream := streams.NewRoomList(h.Storage)
|
|
|
|
_, _, err = stream.Process(session.DeviceID, from, latestNID, "", filter)
|
|
|
|
if err != nil {
|
|
|
|
return &handlerError{
|
|
|
|
err: err,
|
|
|
|
StatusCode: 500,
|
|
|
|
}
|
|
|
|
} */
|
|
|
|
|
2021-06-16 18:56:31 +01:00
|
|
|
// finally update our records: confirm that the client received the token they sent us, and mark this
|
|
|
|
// response as unconfirmed
|
|
|
|
var confirmed string
|
|
|
|
if tokv3 != nil {
|
|
|
|
confirmed = tokv3.String()
|
|
|
|
}
|
|
|
|
if err := h.Sessions.UpdateLastTokens(session.ID, confirmed, upcoming.String()); err != nil {
|
|
|
|
return &handlerError{
|
|
|
|
err: err,
|
|
|
|
StatusCode: 500,
|
|
|
|
}
|
|
|
|
}
|
2021-06-09 18:13:55 +01:00
|
|
|
|
|
|
|
w.WriteHeader(200)
|
2021-07-21 15:03:09 +01:00
|
|
|
resp := sync3.Response{
|
|
|
|
Next: upcoming.String(),
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
|
|
|
log.Warn().Err(err).Msg("failed to marshal response")
|
|
|
|
}
|
2021-06-09 18:13:55 +01:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// getOrCreateSession retrieves an existing session if ?since= is set, else makes a new session.
|
|
|
|
// Returns a session or an error. Returns a token if and only if there is an existing session.
|
|
|
|
func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *sync3.Token, *handlerError) {
|
|
|
|
var session *sync3.Session
|
2021-06-09 17:27:54 +01:00
|
|
|
var tokv3 *sync3.Token
|
2021-06-09 18:13:55 +01:00
|
|
|
deviceID, err := deviceIDFromRequest(req)
|
|
|
|
if err != nil {
|
|
|
|
log.Warn().Err(err).Msg("failed to get device ID from request")
|
|
|
|
return nil, nil, &handlerError{400, err}
|
|
|
|
}
|
2021-05-14 16:49:33 +01:00
|
|
|
sincev3 := req.URL.Query().Get("since")
|
|
|
|
if sincev3 == "" {
|
|
|
|
session, err = h.Sessions.NewSession(deviceID)
|
|
|
|
} else {
|
2021-06-09 17:27:54 +01:00
|
|
|
tokv3, err = sync3.NewSyncToken(sincev3)
|
2021-05-14 16:49:33 +01:00
|
|
|
if err != nil {
|
|
|
|
log.Warn().Err(err).Msg("failed to parse sync v3 token")
|
2021-06-09 18:13:55 +01:00
|
|
|
return nil, nil, &handlerError{400, err}
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
2021-06-09 17:27:54 +01:00
|
|
|
session, err = h.Sessions.Session(tokv3.SessionID, deviceID)
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
log.Warn().Err(err).Str("device", deviceID).Msg("failed to ensure Session existed for device")
|
2021-06-09 18:13:55 +01:00
|
|
|
return nil, nil, &handlerError{500, err}
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
2021-06-16 18:56:31 +01:00
|
|
|
if session.UserID == "" {
|
|
|
|
// we need to work out the user ID to do membership queries
|
|
|
|
userID, err := h.userIDFromRequest(req)
|
|
|
|
if err != nil {
|
2021-07-21 12:12:57 +01:00
|
|
|
log.Warn().Err(err).Msg("failed to work out user ID from request, is the authorization header valid?")
|
2021-06-16 18:56:31 +01:00
|
|
|
return nil, nil, &handlerError{400, err}
|
|
|
|
}
|
|
|
|
session.UserID = userID
|
|
|
|
h.Sessions.UpdateUserIDForDevice(deviceID, userID)
|
|
|
|
}
|
2021-06-09 18:13:55 +01:00
|
|
|
return session, tokv3, nil
|
2021-05-14 16:49:33 +01:00
|
|
|
}
|
|
|
|
|
2021-06-04 13:02:28 +01:00
|
|
|
// ensurePolling makes sure there is a poller for this device, making one if need be.
|
|
|
|
// Blocks until at least 1 sync is done if and only if the poller was just created.
|
|
|
|
// This ensures that calls to the database will return data.
|
2021-06-09 18:13:55 +01:00
|
|
|
func (h *SyncV3Handler) ensurePolling(authHeader string, session *sync3.Session) {
|
2021-06-04 13:02:28 +01:00
|
|
|
h.pollerMu.Lock()
|
2021-06-09 18:13:55 +01:00
|
|
|
poller, ok := h.Pollers[session.DeviceID]
|
2021-06-04 13:02:28 +01:00
|
|
|
// either no poller exists or it did but it died
|
|
|
|
if ok && !poller.Terminated {
|
|
|
|
h.pollerMu.Unlock()
|
|
|
|
return
|
2021-06-03 16:18:01 +01:00
|
|
|
}
|
2021-06-04 13:02:28 +01:00
|
|
|
// replace the poller
|
2021-06-16 17:18:04 +01:00
|
|
|
poller = sync2.NewPoller(authHeader, session.DeviceID, h.V2, h.Storage, h.Sessions)
|
2021-06-04 13:02:28 +01:00
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(1)
|
2021-06-11 17:20:32 +01:00
|
|
|
go poller.Poll(session.V2Since, func() {
|
2021-06-04 13:02:28 +01:00
|
|
|
wg.Done()
|
|
|
|
})
|
2021-06-09 18:13:55 +01:00
|
|
|
h.Pollers[session.DeviceID] = poller
|
2021-06-04 13:02:28 +01:00
|
|
|
h.pollerMu.Unlock()
|
|
|
|
wg.Wait()
|
2021-06-03 16:18:01 +01:00
|
|
|
}
|
|
|
|
|
2021-07-21 15:03:09 +01:00
|
|
|
func (h *SyncV3Handler) parseRequest(req *http.Request, tok *sync3.Token, session *sync3.Session) (*sync3.Request, int64, *handlerError) {
|
|
|
|
existing := &sync3.Request{} // first request
|
|
|
|
var err error
|
|
|
|
if tok != nil && tok.FilterID != 0 {
|
2021-07-21 15:15:50 +01:00
|
|
|
// load existing filter
|
|
|
|
existing, err = h.Sessions.Filter(tok.SessionID, tok.FilterID)
|
2021-07-21 15:03:09 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, 0, &handlerError{
|
|
|
|
StatusCode: 400,
|
|
|
|
err: fmt.Errorf("failed to load filters from sync token: %s", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// load new delta from request
|
|
|
|
defer req.Body.Close()
|
|
|
|
var delta sync3.Request
|
|
|
|
if err := json.NewDecoder(req.Body).Decode(&delta); err != nil {
|
|
|
|
return nil, 0, &handlerError{
|
|
|
|
StatusCode: 400,
|
|
|
|
err: fmt.Errorf("failed to parse request body as JSON: %s", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var filterID int64
|
|
|
|
if existing.ApplyDeltas(&delta) {
|
|
|
|
// persist new filters if there were deltas
|
2021-07-21 15:15:50 +01:00
|
|
|
filterID, err = h.Sessions.InsertFilter(session.ID, existing)
|
2021-07-21 15:03:09 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, 0, &handlerError{
|
|
|
|
StatusCode: 500,
|
|
|
|
err: fmt.Errorf("failed to persist filters: %s", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return existing, filterID, nil
|
|
|
|
}
|
|
|
|
|
2021-06-16 18:56:31 +01:00
|
|
|
func (h *SyncV3Handler) userIDFromRequest(req *http.Request) (string, error) {
|
|
|
|
return h.V2.WhoAmI(req.Header.Get("Authorization"))
|
|
|
|
}
|
|
|
|
|
2021-05-14 16:49:33 +01:00
|
|
|
// jsonError is the response body sent for failed requests: {"error":"<message>"}.
type jsonError struct {
	Err string `json:"error"`
}

// asJSONError renders err.Error() as a jsonError document. err must be non-nil.
func asJSONError(err error) []byte {
	payload := jsonError{Err: err.Error()}
	// Marshalling a struct with a single string field cannot fail (invalid
	// UTF-8 is replaced rather than rejected), so the error is deliberately
	// discarded.
	encoded, _ := json.Marshal(&payload)
	return encoded
}
|
|
|
|
|
|
|
|
// deviceIDFromRequest derives a stable identifier for the caller from its
// access token. The result is the hex-encoded SHA-256 of the Authorization
// header value with a leading "Bearer " (case-sensitive) stripped. It returns
// an error when the header is absent or empty.
func deviceIDFromRequest(req *http.Request) (string, error) {
	authHeader := req.Header.Get("Authorization")
	if len(authHeader) == 0 {
		return "", fmt.Errorf("missing Authorization header")
	}
	digest := sha256.Sum256([]byte(strings.TrimPrefix(authHeader, "Bearer ")))
	return hex.EncodeToString(digest[:]), nil
}
|