Refactor token handling to more easily support additional positions

This commit is contained in:
Kegan Dougal 2021-08-03 16:25:57 +01:00
parent 5b4b1a10ed
commit 89a6c2e3d7
4 changed files with 121 additions and 81 deletions

View File

@ -38,7 +38,7 @@ func (t *TypingTable) SelectHighestID() (id int64, err error) {
return
}
func (t *TypingTable) SetTyping(roomID string, userIDs []string) (streamID int64, err error) {
func (t *TypingTable) SetTyping(roomID string, userIDs []string) (position int64, err error) {
if userIDs == nil {
userIDs = []string{}
}
@ -46,8 +46,8 @@ func (t *TypingTable) SetTyping(roomID string, userIDs []string) (streamID int64
INSERT INTO syncv3_typing(room_id, user_ids) VALUES($1, $2)
ON CONFLICT (room_id) DO UPDATE SET user_ids = $2, stream_id = nextval('syncv3_typing_seq') RETURNING stream_id`,
roomID, pq.Array(userIDs),
).Scan(&streamID)
return streamID, err
).Scan(&position)
return position, err
}
func (t *TypingTable) Typing(roomID string, fromStreamIDExcl int64) (userIDs []string, latest int64, err error) {

View File

@ -6,22 +6,51 @@ import (
"strings"
)
// V3_S1_57423_123_F9
// "V3_S" $SESSION "_" $NID "_" $TYPING "_F" $FILTER
// To add a new position, specify the const here then increment `totalStreamPositions` and implement
// getters/setters for the new stream position.
const (
IndexEventPosition = iota
IndexTypingPosition
IndexToDevicePosition
)
const totalStreamPositions = 3
// V3_S1_F9_57423_123_5183
// "V3_S" $SESSION "_F" $FILTER "_" $A "_" $B "_" $C
type Token struct {
SessionID int64
NID int64
TypingPosition int64
ToDevicePosition int64
FilterID int64
// User associated values (this is different to sync v2 which doesn't have this concept)
SessionID int64
FilterID int64
// Server-side stream positions (same as sync v2)
positions []int64
}
func (t *Token) EventPosition() int64 {
return t.positions[IndexEventPosition]
}
func (t *Token) TypingPosition() int64 {
return t.positions[IndexTypingPosition]
}
func (t *Token) ToDevicePosition() int64 {
return t.positions[IndexToDevicePosition]
}
func (t *Token) SetEventPosition(pos int64) {
t.positions[IndexEventPosition] = pos
}
func (t *Token) SetTypingPosition(pos int64) {
t.positions[IndexTypingPosition] = pos
}
func (t *Token) SetToDevicePosition(pos int64) {
t.positions[IndexToDevicePosition] = pos
}
func (t *Token) IsAfter(x Token) bool {
if t.NID > x.NID {
return true
}
if t.TypingPosition > x.TypingPosition {
return true
for i := range t.positions {
if t.positions[i] > x.positions[i] {
return true
}
}
return false
}
@ -35,57 +64,70 @@ func (t *Token) AssociateWithUser(userToken Token) {
// ApplyUpdates increments the counters associated with server-side data from `other`, if and only
// if the counters in `other` are newer/higher.
func (t *Token) ApplyUpdates(other Token) {
if other.NID > t.NID {
t.NID = other.NID
}
if other.TypingPosition > t.TypingPosition {
t.TypingPosition = other.TypingPosition
for i := range t.positions {
if other.positions[i] > t.positions[i] {
t.positions[i] = other.positions[i]
}
}
}
func (t *Token) String() string {
var filterID string
if t.FilterID != 0 {
filterID = fmt.Sprintf("%d", t.FilterID)
posStr := make([]string, len(t.positions))
for i := range t.positions {
posStr[i] = strconv.FormatInt(t.positions[i], 10)
}
positions := strings.Join(posStr, "_")
if t.FilterID == 0 {
return fmt.Sprintf("V3_S%d_%s", t.SessionID, positions)
}
return fmt.Sprintf("V3_S%d_F%d_%s", t.SessionID, t.FilterID, positions)
}
func NewBlankSyncToken(sessionID, filterID int64) *Token {
return &Token{
SessionID: sessionID,
FilterID: filterID,
positions: make([]int64, totalStreamPositions),
}
return fmt.Sprintf("V3_S%d_%d_%d_F%s", t.SessionID, t.NID, t.TypingPosition, filterID)
}
func NewSyncToken(since string) (*Token, error) {
segments := strings.SplitN(since, "_", 5)
if len(segments) != 5 {
return nil, fmt.Errorf("not a sync v3 token")
}
segments := strings.Split(since, "_")
if segments[0] != "V3" {
return nil, fmt.Errorf("not a sync v3 token: %s", since)
}
filterstr := strings.TrimPrefix(segments[4], "F")
var fid int64
segments = segments[1:]
var sessionID int64
var filterID int64
var positions []int64
var err error
if len(filterstr) > 0 {
fid, err = strconv.ParseInt(filterstr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid filter id: %s", filterstr)
for _, segment := range segments {
if strings.HasPrefix(segment, "F") {
filterStr := strings.TrimPrefix(segment, "F")
filterID, err = strconv.ParseInt(filterStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid filter ID: %s", segment)
}
} else if strings.HasPrefix(segment, "S") {
sessionStr := strings.TrimPrefix(segment, "S")
sessionID, err = strconv.ParseInt(sessionStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid session: %s", segment)
}
} else {
pos, err := strconv.ParseInt(segment, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid segment '%s': %s", segment, err)
}
positions = append(positions, pos)
}
}
sidstr := strings.TrimPrefix(segments[1], "S")
sid, err := strconv.ParseInt(sidstr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid session: %s", sidstr)
}
nid, err := strconv.ParseInt(segments[2], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid nid: %s", segments[2])
}
typingid, err := strconv.ParseInt(segments[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid typing: %s", segments[3])
if len(positions) != totalStreamPositions {
return nil, fmt.Errorf("expected %d stream positions, got %d", totalStreamPositions, len(positions))
}
return &Token{
SessionID: sid,
NID: nid,
FilterID: fid,
TypingPosition: typingid,
SessionID: sessionID,
FilterID: filterID,
positions: positions,
}, nil
}

View File

@ -22,21 +22,20 @@ func TestNewSyncToken(t *testing.T) {
},
{
// with filter
in: "V3_S1_12_19_F6",
in: "V3_S1_F6_12_19_11",
outToken: &Token{
SessionID: 1,
NID: 12,
TypingPosition: 19,
FilterID: 6,
SessionID: 1,
FilterID: 6,
positions: []int64{12, 19, 11},
},
},
{
// without filter
in: "V3_S1_33_100_F",
in: "V3_S1_33_100_1313",
outToken: &Token{
SessionID: 1,
NID: 33,
TypingPosition: 100,
SessionID: 1,
FilterID: 0,
positions: []int64{33, 100, 1313},
},
},
}

41
v3.go
View File

@ -111,18 +111,18 @@ func NewSyncV3Handler(v2Client sync2.Client, postgresDBURI string) *SyncV3Handle
pollerMu: &sync.Mutex{},
}
sh.typingStream = streams.NewTyping(sh.Storage)
latestToken := sync3.Token{}
latestToken := sync3.NewBlankSyncToken(0, 0)
nid, err := sh.Storage.LatestEventNID()
if err != nil {
panic(err)
}
latestToken.NID = nid
typingID, err := sh.Storage.LatestTypingID()
latestToken.SetEventPosition(nid)
typingPos, err := sh.Storage.LatestTypingID()
if err != nil {
panic(err)
}
latestToken.TypingPosition = typingID
sh.Notifier = notifier.NewNotifier(latestToken)
latestToken.SetTypingPosition(typingPos)
sh.Notifier = notifier.NewNotifier(*latestToken)
// TODO: load up the membership states so the notifier knows who to wake up
/*
roomIDToUserIDs, err := sh.Storage.AllJoinedMembers()
@ -200,14 +200,14 @@ func (h *SyncV3Handler) serve(w http.ResponseWriter, req *http.Request) *handler
// invoke streams to get responses
if syncReq.Typing != nil {
typingResp, typingTo, err := h.typingStream.Process(session.UserID, fromToken.TypingPosition, syncReq.Typing)
typingResp, typingTo, err := h.typingStream.Process(session.UserID, fromToken.TypingPosition(), syncReq.Typing)
if err != nil {
return &handlerError{
StatusCode: 500,
err: fmt.Errorf("typing stream: %s", err),
}
}
upcoming.TypingPosition = typingTo
upcoming.SetTypingPosition(typingTo)
resp.Typing = typingResp
}
if syncReq.ToDevice != nil {
@ -275,9 +275,7 @@ func (h *SyncV3Handler) getOrCreateSession(req *http.Request) (*sync3.Session, *
h.Sessions.UpdateUserIDForDevice(deviceID, userID)
}
if tokv3 == nil {
tokv3 = &sync3.Token{
SessionID: session.ID,
}
tokv3 = sync3.NewBlankSyncToken(session.ID, 0)
}
return session, tokv3, nil
}
@ -319,18 +317,19 @@ func (h *SyncV3Handler) Accumulate(roomID string, timeline []json.RawMessage) er
if numNew == 0 {
return nil
}
var updateToken sync3.Token
updateToken := sync3.NewBlankSyncToken(0, 0)
// TODO: read from memory, persist in Storage?
updateToken.NID, err = h.Storage.LatestEventNID()
nid, err := h.Storage.LatestEventNID()
if err != nil {
return err
}
updateToken.SetEventPosition(nid)
newEvents := timeline[len(timeline)-numNew:]
for _, eventJSON := range newEvents {
event := gjson.ParseBytes(eventJSON)
h.Notifier.OnNewEvent(
roomID, event.Get("sender").Str, event.Get("type").Str,
event.Get("state_key").Str, event.Get("content.membership").Str, nil, updateToken,
event.Get("state_key").Str, event.Get("content.membership").Str, nil, *updateToken,
)
}
return nil
@ -361,14 +360,14 @@ func (h *SyncV3Handler) Initialise(roomID string, state []json.RawMessage) error
// Called from the v2 poller, implements V2DataReceiver
func (h *SyncV3Handler) SetTyping(roomID string, userIDs []string) (int64, error) {
typingID, err := h.Storage.TypingTable.SetTyping(roomID, userIDs)
pos, err := h.Storage.TypingTable.SetTyping(roomID, userIDs)
if err != nil {
return 0, err
}
var updateToken sync3.Token
updateToken.TypingPosition = typingID
h.Notifier.OnNewTyping(roomID, updateToken)
return typingID, nil
updateToken := sync3.NewBlankSyncToken(0, 0)
updateToken.SetTypingPosition(pos)
h.Notifier.OnNewTyping(roomID, *updateToken)
return pos, nil
}
func (h *SyncV3Handler) AddToDeviceMessages(userID, deviceID string, msgs []gomatrixserverlib.SendToDeviceEvent) error {
@ -376,9 +375,9 @@ func (h *SyncV3Handler) AddToDeviceMessages(userID, deviceID string, msgs []goma
if err != nil {
return err
}
var updateToken sync3.Token
updateToken.ToDevicePosition = pos
h.Notifier.OnNewSendToDevice(userID, []string{deviceID}, updateToken)
updateToken := sync3.NewBlankSyncToken(0, 0)
updateToken.SetToDevicePosition(pos)
h.Notifier.OnNewSendToDevice(userID, []string{deviceID}, *updateToken)
return nil
}