mirror of
https://github.com/matrix-org/sliding-sync.git
synced 2025-03-10 13:37:11 +00:00
Refactor token handling to more easily support additional positions
This commit is contained in:
parent
5b4b1a10ed
commit
89a6c2e3d7
@ -38,7 +38,7 @@ func (t *TypingTable) SelectHighestID() (id int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (t *TypingTable) SetTyping(roomID string, userIDs []string) (streamID int64, err error) {
|
||||
func (t *TypingTable) SetTyping(roomID string, userIDs []string) (position int64, err error) {
|
||||
if userIDs == nil {
|
||||
userIDs = []string{}
|
||||
}
|
||||
@ -46,8 +46,8 @@ func (t *TypingTable) SetTyping(roomID string, userIDs []string) (streamID int64
|
||||
INSERT INTO syncv3_typing(room_id, user_ids) VALUES($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET user_ids = $2, stream_id = nextval('syncv3_typing_seq') RETURNING stream_id`,
|
||||
roomID, pq.Array(userIDs),
|
||||
).Scan(&streamID)
|
||||
return streamID, err
|
||||
).Scan(&position)
|
||||
return position, err
|
||||
}
|
||||
|
||||
func (t *TypingTable) Typing(roomID string, fromStreamIDExcl int64) (userIDs []string, latest int64, err error) {
|
||||
|
138
sync3/token.go
138
sync3/token.go
@ -6,22 +6,51 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// V3_S1_57423_123_F9
|
||||
// "V3_S" $SESSION "_" $NID "_" $TYPING "_F" $FILTER
|
||||
// To add a new position, specify the const here then increment `totalStreamPositions` and implement
|
||||
// getters/setters for the new stream position.
|
||||
|
||||
const (
|
||||
IndexEventPosition = iota
|
||||
IndexTypingPosition
|
||||
IndexToDevicePosition
|
||||
)
|
||||
const totalStreamPositions = 3
|
||||
|
||||
// V3_S1_F9_57423_123_5183
|
||||
// "V3_S" $SESSION "_F" $FILTER "_" $A "_" $B "_" $C
|
||||
type Token struct {
|
||||
SessionID int64
|
||||
NID int64
|
||||
TypingPosition int64
|
||||
ToDevicePosition int64
|
||||
FilterID int64
|
||||
// User associated values (this is different to sync v2 which doesn't have this concept)
|
||||
SessionID int64
|
||||
FilterID int64
|
||||
|
||||
// Server-side stream positions (same as sync v2)
|
||||
positions []int64
|
||||
}
|
||||
|
||||
func (t *Token) EventPosition() int64 {
|
||||
return t.positions[IndexEventPosition]
|
||||
}
|
||||
func (t *Token) TypingPosition() int64 {
|
||||
return t.positions[IndexTypingPosition]
|
||||
}
|
||||
func (t *Token) ToDevicePosition() int64 {
|
||||
return t.positions[IndexToDevicePosition]
|
||||
}
|
||||
func (t *Token) SetEventPosition(pos int64) {
|
||||
t.positions[IndexEventPosition] = pos
|
||||
}
|
||||
func (t *Token) SetTypingPosition(pos int64) {
|
||||
t.positions[IndexTypingPosition] = pos
|
||||
}
|
||||
func (t *Token) SetToDevicePosition(pos int64) {
|
||||
t.positions[IndexToDevicePosition] = pos
|
||||
}
|
||||
|
||||
func (t *Token) IsAfter(x Token) bool {
|
||||
if t.NID > x.NID {
|
||||
return true
|
||||
}
|
||||
if t.TypingPosition > x.TypingPosition {
|
||||
return true
|
||||
for i := range t.positions {
|
||||
if t.positions[i] > x.positions[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -35,57 +64,70 @@ func (t *Token) AssociateWithUser(userToken Token) {
|
||||
// ApplyUpdates increments the counters associated with server-side data from `other`, if and only
|
||||
// if the counters in `other` are newer/higher.
|
||||
func (t *Token) ApplyUpdates(other Token) {
|
||||
if other.NID > t.NID {
|
||||
t.NID = other.NID
|
||||
}
|
||||
if other.TypingPosition > t.TypingPosition {
|
||||
t.TypingPosition = other.TypingPosition
|
||||
for i := range t.positions {
|
||||
if other.positions[i] > t.positions[i] {
|
||||
t.positions[i] = other.positions[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Token) String() string {
|
||||
var filterID string
|
||||
if t.FilterID != 0 {
|
||||
filterID = fmt.Sprintf("%d", t.FilterID)
|
||||
posStr := make([]string, len(t.positions))
|
||||
for i := range t.positions {
|
||||
posStr[i] = strconv.FormatInt(t.positions[i], 10)
|
||||
}
|
||||
positions := strings.Join(posStr, "_")
|
||||
if t.FilterID == 0 {
|
||||
return fmt.Sprintf("V3_S%d_%s", t.SessionID, positions)
|
||||
}
|
||||
return fmt.Sprintf("V3_S%d_F%d_%s", t.SessionID, t.FilterID, positions)
|
||||
}
|
||||
|
||||
func NewBlankSyncToken(sessionID, filterID int64) *Token {
|
||||
return &Token{
|
||||
SessionID: sessionID,
|
||||
FilterID: filterID,
|
||||
positions: make([]int64, totalStreamPositions),
|
||||
}
|
||||
return fmt.Sprintf("V3_S%d_%d_%d_F%s", t.SessionID, t.NID, t.TypingPosition, filterID)
|
||||
}
|
||||
|
||||
func NewSyncToken(since string) (*Token, error) {
|
||||
segments := strings.SplitN(since, "_", 5)
|
||||
if len(segments) != 5 {
|
||||
return nil, fmt.Errorf("not a sync v3 token")
|
||||
}
|
||||
segments := strings.Split(since, "_")
|
||||
if segments[0] != "V3" {
|
||||
return nil, fmt.Errorf("not a sync v3 token: %s", since)
|
||||
}
|
||||
filterstr := strings.TrimPrefix(segments[4], "F")
|
||||
var fid int64
|
||||
segments = segments[1:]
|
||||
var sessionID int64
|
||||
var filterID int64
|
||||
var positions []int64
|
||||
var err error
|
||||
if len(filterstr) > 0 {
|
||||
fid, err = strconv.ParseInt(filterstr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid filter id: %s", filterstr)
|
||||
for _, segment := range segments {
|
||||
if strings.HasPrefix(segment, "F") {
|
||||
filterStr := strings.TrimPrefix(segment, "F")
|
||||
filterID, err = strconv.ParseInt(filterStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid filter ID: %s", segment)
|
||||
}
|
||||
} else if strings.HasPrefix(segment, "S") {
|
||||
sessionStr := strings.TrimPrefix(segment, "S")
|
||||
sessionID, err = strconv.ParseInt(sessionStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid session: %s", segment)
|
||||
}
|
||||
} else {
|
||||
pos, err := strconv.ParseInt(segment, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid segment '%s': %s", segment, err)
|
||||
}
|
||||
positions = append(positions, pos)
|
||||
}
|
||||
}
|
||||
|
||||
sidstr := strings.TrimPrefix(segments[1], "S")
|
||||
sid, err := strconv.ParseInt(sidstr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid session: %s", sidstr)
|
||||
}
|
||||
nid, err := strconv.ParseInt(segments[2], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid nid: %s", segments[2])
|
||||
}
|
||||
typingid, err := strconv.ParseInt(segments[3], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid typing: %s", segments[3])
|
||||
if len(positions) != totalStreamPositions {
|
||||
return nil, fmt.Errorf("expected %d stream positions, got %d", totalStreamPositions, len(positions))
|
||||
}
|
||||
return &Token{
|
||||
SessionID: sid,
|
||||
NID: nid,
|
||||
FilterID: fid,
|
||||
TypingPosition: typingid,
|
||||
SessionID: sessionID,
|
||||
FilterID: filterID,
|
||||
positions: positions,
|
||||
}, nil
|
||||
}
|
||||
|
@ -22,21 +22,20 @@ func TestNewSyncToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// with filter
|
||||
in: "V3_S1_12_19_F6",
|
||||
in: "V3_S1_F6_12_19_11",
|
||||
outToken: &Token{
|
||||
SessionID: 1,
|
||||
NID: 12,
|
||||
TypingPosition: 19,
|
||||
FilterID: 6,
|
||||
SessionID: 1,
|
||||
FilterID: 6,
|
||||
positions: []int64{12, 19, 11},
|
||||
},
|
||||
},
|
||||
{
|
||||
// without filter
|
||||
in: "V3_S1_33_100_F",
|
||||
in: "V3_S1_33_100_1313",
|
||||
outToken: &Token{
|
||||
SessionID: 1,
|
||||
NID: 33,
|
||||
TypingPosition: 100,
|
||||
SessionID: 1,
|
||||
FilterID: 0,
|
||||
positions: []int64{33, 100, 1313},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
41
v3.go
41
v3.go
@ -111,18 +111,18 @@ func NewSyncV3Handler(v2Client sync2.Client, postgresDBURI string) *SyncV3Handle
|
||||
pollerMu: &sync.Mutex{},
|
||||
}
|
||||
sh.typingStream = streams.NewTyping(sh.Storage)
|
||||
latestToken := sync3.Token{}
|
||||
latestToken := sync3.NewBlankSyncToken(0, 0)
|
||||
nid, err := sh.Storage.LatestEventNID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
latestToken.NID = nid
|
||||
typingID, err := sh.Storage.LatestTypingID()
|
||||
latestToken.SetEventPosition(nid)
|
||||
typingPos, err := sh.Storage.LatestTypingID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
latestToken.TypingPosition = typingID
|
||||
sh.Notifier = notifier.NewNotifier(latestToken)
|
||||
latestToken.SetTypingPosition(typingPos)
|
||||
sh.Notifier = notifier.NewNotifier(*latestToken)
|
||||
// TODO: load up the membership states so the notifier knows who to wake up
|
||||
/*
|
||||
roomIDToUserIDs, err := sh.Storage.AllJoinedMembers()
|
||||
@ -200,14 +200,14 @@ func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handler
|
||||
|
||||
// invoke streams to get responses
|
||||
if syncReq.Typing != nil {
|
||||
typingResp, typingTo, err := h.typingStream.Process(session.UserID, fromToken.TypingPosition, syncReq.Typing)
|
||||
typingResp, typingTo, err := h.typingStream.Process(session.UserID, fromToken.TypingPosition(), syncReq.Typing)
|
||||
if err != nil {
|
||||
return &handlerError{
|
||||
StatusCode: 500,
|
||||
err: fmt.Errorf("typing stream: %s", err),
|
||||
}
|
||||
}
|
||||
upcoming.TypingPosition = typingTo
|
||||
upcoming.SetTypingPosition(typingTo)
|
||||
resp.Typing = typingResp
|
||||
}
|
||||
if syncReq.ToDevice != nil {
|
||||
@ -275,9 +275,7 @@ func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *
|
||||
h.Sessions.UpdateUserIDForDevice(deviceID, userID)
|
||||
}
|
||||
if tokv3 == nil {
|
||||
tokv3 = &sync3.Token{
|
||||
SessionID: session.ID,
|
||||
}
|
||||
tokv3 = sync3.NewBlankSyncToken(session.ID, 0)
|
||||
}
|
||||
return session, tokv3, nil
|
||||
}
|
||||
@ -319,18 +317,19 @@ func (h *SyncV3Handler) Accumulate(roomID string, timeline []json.RawMessage) er
|
||||
if numNew == 0 {
|
||||
return nil
|
||||
}
|
||||
var updateToken sync3.Token
|
||||
updateToken := sync3.NewBlankSyncToken(0, 0)
|
||||
// TODO: read from memory, persist in Storage?
|
||||
updateToken.NID, err = h.Storage.LatestEventNID()
|
||||
nid, err := h.Storage.LatestEventNID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updateToken.SetEventPosition(nid)
|
||||
newEvents := timeline[len(timeline)-numNew:]
|
||||
for _, eventJSON := range newEvents {
|
||||
event := gjson.ParseBytes(eventJSON)
|
||||
h.Notifier.OnNewEvent(
|
||||
roomID, event.Get("sender").Str, event.Get("type").Str,
|
||||
event.Get("state_key").Str, event.Get("content.membership").Str, nil, updateToken,
|
||||
event.Get("state_key").Str, event.Get("content.membership").Str, nil, *updateToken,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
@ -361,14 +360,14 @@ func (h *SyncV3Handler) Initialise(roomID string, state []json.RawMessage) error
|
||||
|
||||
// Called from the v2 poller, implements V2DataReceiver
|
||||
func (h *SyncV3Handler) SetTyping(roomID string, userIDs []string) (int64, error) {
|
||||
typingID, err := h.Storage.TypingTable.SetTyping(roomID, userIDs)
|
||||
pos, err := h.Storage.TypingTable.SetTyping(roomID, userIDs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var updateToken sync3.Token
|
||||
updateToken.TypingPosition = typingID
|
||||
h.Notifier.OnNewTyping(roomID, updateToken)
|
||||
return typingID, nil
|
||||
updateToken := sync3.NewBlankSyncToken(0, 0)
|
||||
updateToken.SetTypingPosition(pos)
|
||||
h.Notifier.OnNewTyping(roomID, *updateToken)
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
func (h *SyncV3Handler) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) error {
|
||||
@ -376,9 +375,9 @@ func (h *SyncV3Handler) AddToDeviceMessages(userID, deviceID string, msgs []goma
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var updateToken sync3.Token
|
||||
updateToken.ToDevicePosition = pos
|
||||
h.Notifier.OnNewSendToDevice(userID, []string{deviceID}, updateToken)
|
||||
updateToken := sync3.NewBlankSyncToken(0, 0)
|
||||
updateToken.SetToDevicePosition(pos)
|
||||
h.Notifier.OnNewSendToDevice(userID, []string{deviceID}, *updateToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user