From 075a0d19a08f88a4c87463aa9b785113a2ab831a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 3 Aug 2021 17:43:41 +0100 Subject: [PATCH] Add a Streamer interface and make liberal use of it This will be extended further, because where we're going, we need a lot of streams... --- state/typing_table.go | 5 ++-- state/typing_table_test.go | 2 +- sync3/streams/request.go | 8 +++++++ sync3/streams/response.go | 5 ++-- sync3/streams/stream.go | 20 ++++++++++++++++ sync3/streams/to_device.go | 17 +++++++++---- sync3/streams/typing.go | 20 +++++++++++----- v3.go | 26 +++++++++++--------- v3_test.go | 49 ++++++++++++++++++++++++++++++++++++++ 9 files changed, 126 insertions(+), 26 deletions(-) create mode 100644 sync3/streams/stream.go diff --git a/state/typing_table.go b/state/typing_table.go index d2bdf70..22150c4 100644 --- a/state/typing_table.go +++ b/state/typing_table.go @@ -50,10 +50,11 @@ func (t *TypingTable) SetTyping(roomID string, userIDs []string) (position int64 return position, err } -func (t *TypingTable) Typing(roomID string, fromStreamIDExcl int64) (userIDs []string, latest int64, err error) { +func (t *TypingTable) Typing(roomID string, fromStreamIDExcl, toStreamIDIncl int64) (userIDs []string, latest int64, err error) { var userIDsArray pq.StringArray err = t.db.QueryRow( - `SELECT stream_id, user_ids FROM syncv3_typing WHERE room_id=$1 AND stream_id > $2 `, roomID, fromStreamIDExcl, + `SELECT stream_id, user_ids FROM syncv3_typing WHERE room_id=$1 AND stream_id > $2 AND stream_id <= $3`, + roomID, fromStreamIDExcl, toStreamIDIncl, ).Scan(&latest, &userIDsArray) if err == sql.ErrNoRows { err = nil diff --git a/state/typing_table_test.go b/state/typing_table_test.go index 8e7c84d..0521bdf 100644 --- a/state/typing_table_test.go +++ b/state/typing_table_test.go @@ -32,7 +32,7 @@ func TestTypingTable(t *testing.T) { t.Errorf("SetTyping: streamID returned should always be increasing but it wasn't, got %d, last %d", streamID, lastStreamID) } lastStreamID = streamID - gotUserIDs, _, err := table.Typing(roomID, streamID-1) + gotUserIDs, _, err := table.Typing(roomID, streamID-1, lastStreamID) if err != nil { t.Fatalf("failed to Typing: %s", err) } diff --git a/sync3/streams/request.go b/sync3/streams/request.go index a1f22f9..a6d0672 100644 --- a/sync3/streams/request.go +++ b/sync3/streams/request.go @@ -29,5 +29,13 @@ func (r *Request) ApplyDeltas(req2 *Request) bool { r.Typing = r.Typing.Combine(req2.Typing) } } + if req2.ToDevice != nil { + deltasExist = true + if r.ToDevice == nil { + r.ToDevice = req2.ToDevice + } else { + r.ToDevice = r.ToDevice.Combine(req2.ToDevice) + } + } return deltasExist } diff --git a/sync3/streams/response.go b/sync3/streams/response.go index b15dc10..83addaa 100644 --- a/sync3/streams/response.go +++ b/sync3/streams/response.go @@ -1,6 +1,7 @@ package streams type Response struct { - Next string `json:"next"` - Typing *TypingResponse `json:"typing,omitempty"` + Next string `json:"next"` + Typing *TypingResponse `json:"typing,omitempty"` + ToDevice *ToDeviceResponse `json:"to_device,omitempty"` } diff --git a/sync3/streams/stream.go b/sync3/streams/stream.go new file mode 100644 index 0000000..269394a --- /dev/null +++ b/sync3/streams/stream.go @@ -0,0 +1,20 @@ +package streams + +import ( + "errors" + + "github.com/matrix-org/sync-v3/sync3" +) + +// Streamer specifies an interface which, if satisfied, can be used to make a stream. +type Streamer interface { + // Return the position in the correct stream based on a sync v3 token + Position(tok *sync3.Token) int64 + // Set the stream position for this stream in the sync v3 token + // SetPosition(tok *sync3.Token, pos int64) + // Extract data between the two stream positions and assign to Response. + DataInRange(session *sync3.Session, fromExcl, toIncl int64, req *Request, resp *Response) error +} + +// ErrNotRequested should be returned in DataInRange if the request does not ask for this stream. +var ErrNotRequested = errors.New("stream not requested") diff --git a/sync3/streams/to_device.go b/sync3/streams/to_device.go index 3759871..1ebae9a 100644 --- a/sync3/streams/to_device.go +++ b/sync3/streams/to_device.go @@ -35,12 +35,21 @@ func NewToDevice(s *state.Storage) *ToDevice { return &ToDevice{s} } -func (s *ToDevice) Process(session *sync3.Session, from, to int64, f *FilterToDevice, resp *ToDeviceResponse) error { - msgs, err := s.storage.ToDeviceTable.Messages(session.DeviceID, from, to) +func (s *ToDevice) Position(tok *sync3.Token) int64 { + return tok.ToDevicePosition() +} + +func (s *ToDevice) DataInRange(session *sync3.Session, fromExcl, toIncl int64, request *Request, resp *Response) error { + if request.ToDevice == nil { + return ErrNotRequested + } + msgs, err := s.storage.ToDeviceTable.Messages(session.DeviceID, fromExcl, toIncl) if err != nil { return err } - resp.Limit = f.Limit - resp.Events = msgs + resp.ToDevice = &ToDeviceResponse{ + Limit: request.ToDevice.Limit, + Events: msgs, + } return nil } diff --git a/sync3/streams/typing.go b/sync3/streams/typing.go index d36c715..f5e8aa1 100644 --- a/sync3/streams/typing.go +++ b/sync3/streams/typing.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/matrix-org/sync-v3/state" + "github.com/matrix-org/sync-v3/sync3" ) type FilterTyping struct { @@ -33,13 +34,20 @@ func NewTyping(s *state.Storage) *Typing { return &Typing{s} } -func (s *Typing) Process(userID string, from int64, f *FilterTyping) (resp *TypingResponse, next int64, err error) { - userIDs, to, err := s.storage.TypingTable.Typing(f.RoomID, from) - if err != nil { - return nil, 0, fmt.Errorf("Typing: %s", err) +func (s *Typing) Position(tok *sync3.Token) int64 { + return tok.TypingPosition() +} + +func (s *Typing) DataInRange(session *sync3.Session, fromExcl, toIncl int64, request *Request, resp *Response) error { + if request.Typing == nil { + return ErrNotRequested } - resp = &TypingResponse{ + userIDs, _, err := s.storage.TypingTable.Typing(request.Typing.RoomID, fromExcl, toIncl) + if err != nil { + return fmt.Errorf("Typing: %s", err) + } + resp.Typing = &TypingResponse{ UserIDs: userIDs, } - return resp, to, nil + return nil } diff --git a/v3.go b/v3.go index f38b6a0..e194c3b 100644 --- a/v3.go +++ b/v3.go @@ -96,7 +96,7 @@ type SyncV3Handler struct { Storage *state.Storage Notifier *notifier.Notifier - typingStream *streams.Typing + streams []streams.Streamer pollerMu *sync.Mutex Pollers map[string]*sync2.Poller // device_id -> poller @@ -110,7 +110,9 @@ func NewSyncV3Handler(v2Client sync2.Client, postgresDBURI string) *SyncV3Handle Pollers: make(map[string]*sync2.Poller), pollerMu: &sync.Mutex{}, } - sh.typingStream = streams.NewTyping(sh.Storage) + sh.streams = append(sh.streams, streams.NewTyping(sh.Storage)) + sh.streams = append(sh.streams, streams.NewToDevice(sh.Storage)) + latestToken := sync3.NewBlankSyncToken(0, 0) nid, err := sh.Storage.LatestEventNID() if err != nil { @@ -199,18 +201,19 @@ func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handler resp := streams.Response{} // invoke streams to get responses - if syncReq.Typing != nil { - typingResp, typingTo, err := h.typingStream.Process(session.UserID, fromToken.TypingPosition(), syncReq.Typing) + for _, stream := range h.streams { + fromExcl := stream.Position(fromToken) + toIncl := stream.Position(&upcoming) + err = stream.DataInRange(session, fromExcl, toIncl, syncReq, &resp) + if err == streams.ErrNotRequested { + continue + } if err != nil { return &handlerError{ StatusCode: 500, - err: fmt.Errorf("typing stream: %s", err), + err: fmt.Errorf("stream error: %s", err), } } - upcoming.SetTypingPosition(typingTo) - resp.Typing = typingResp - } - if syncReq.ToDevice != nil { } resp.Next = upcoming.String() @@ -218,8 +221,8 @@ func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handler // finally update our records: confirm that the client received the token they sent us, and mark this // response as unconfirmed confirmed := fromToken.String() - log.Info().Str("since", confirmed).Str("new_since", upcoming.String()).Bool( - "typing_stream", syncReq.Typing != nil, + log.Info().Str("since", confirmed).Str("new_since", upcoming.String()).Bools( + "request[typing,to_device]", []bool{syncReq.Typing != nil, syncReq.ToDevice != nil}, ).Msg("responding") if err := h.Sessions.UpdateLastTokens(session.ID, confirmed, upcoming.String()); err != nil { return &handlerError{ @@ -377,6 +380,7 @@ func (h *SyncV3Handler) AddToDeviceMessages(userID, deviceID string, msgs []goma } updateToken := sync3.NewBlankSyncToken(0, 0) updateToken.SetToDevicePosition(pos) + fmt.Println("AddToDeviceMessages ", userID, deviceID, len(msgs)) h.Notifier.OnNewSendToDevice(userID, []string{deviceID}, *updateToken) return nil } diff --git a/v3_test.go b/v3_test.go index ac56c1f..a5c9b4d 100644 --- a/v3_test.go +++ b/v3_test.go @@ -1,10 +1,12 @@ package syncv3 import ( + "bytes" "encoding/json" "testing" "time" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/sync-v3/sync2" ) @@ -138,3 +140,50 @@ func TestHandler(t *testing.T) { t.Fatalf("typing got %s want %s", v3resp.Typing.UserIDs[0], charlie) } } + +// Test to_device stream: +// - Injecting a to_device event gets received. +// - TODO Repeating the request without having ACKed the position returns the event again. +// - TODO After ACKing the position, going back to the old position returns no event. +// - TODO If 2 sessions exist, both session must ACK the position before the event is deleted. +func TestHandlerToDevice(t *testing.T) { + alice := "@alice:localhost" + aliceBearer := "Bearer alice_access_token" + server, v2Client := newSync3Server(t) + aliceV2Stream := v2Client.v2StreamForUser(alice, aliceBearer) + + // prepare a response from v2 + toDeviceEvent := gomatrixserverlib.SendToDeviceEvent{ + Sender: alice, + Type: "to_device.test", + Content: []byte(`{"foo":"bar"}`), + } + v2Resp := &sync2.SyncResponse{ + NextBatch: "don't care", + ToDevice: struct { + Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` + }{ + Events: []gomatrixserverlib.SendToDeviceEvent{ + toDeviceEvent, + }, + }, + } + + aliceV2Stream <- v2Resp + + v3resp := mustDoSync3Request(t, server, aliceBearer, "", map[string]interface{}{ + "to_device": map[string]interface{}{ + "limit": 5, + }, + }) + if v3resp.ToDevice == nil { + t.Fatalf("expected to_device response, got none: %+v", v3resp) + } + if len(v3resp.ToDevice.Events) != 1 { + t.Fatalf("expected 1 to_device message, got %d", len(v3resp.ToDevice.Events)) + } + want, _ := json.Marshal(toDeviceEvent) + if !bytes.Equal(v3resp.ToDevice.Events[0], want) { + t.Fatalf("wrong event returned, got %s want %s", string(v3resp.ToDevice.Events[0]), string(want)) + } +}