diff --git a/state/accumulator.go b/state/accumulator.go index acab4555..7e29badb 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -483,25 +483,3 @@ func (a *Accumulator) filterAndParseTimelineEvents(txn *sqlx.Tx, roomID string, // A is seen event s[A,B,C] => s[0+1:] => [B,C] return dedupedEvents[seenIndex+1:], nil } - -// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`. -// Returns the latest NID of the last event (most recent) -func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) { - txn, err := a.db.Beginx() - if err != nil { - return nil, 0, err - } - defer txn.Commit() - events, err := a.eventsTable.SelectEventsBetween(txn, roomID, lastEventNID, EventsEnd, limit) - if err != nil { - return nil, 0, err - } - if len(events) == 0 { - return nil, lastEventNID, nil - } - eventsJSON = make([]json.RawMessage, len(events)) - for i := range events { - eventsJSON[i] = events[i].JSON - } - return eventsJSON, int64(events[len(events)-1].NID), nil -} diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 64ee6c86..0e546ae5 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -200,59 +200,6 @@ func TestAccumulatorAccumulate(t *testing.T) { } } -func TestAccumulatorDelta(t *testing.T) { - roomID := "!TestAccumulatorDelta:localhost" - db, close := connectToDB(t) - defer close() - accumulator := NewAccumulator(db) - _, err := accumulator.Initialise(roomID, nil) - if err != nil { - t.Fatalf("failed to Initialise accumulator: %s", err) - } - roomEvents := []json.RawMessage{ - []byte(`{"event_id":"aD", "type":"m.room.create", "state_key":"", "content":{"creator":"@TestAccumulatorDelta:localhost"}}`), - []byte(`{"event_id":"aE", "type":"m.room.member", "state_key":"@TestAccumulatorDelta:localhost", "content":{"membership":"join"}}`), - []byte(`{"event_id":"aF", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), - []byte(`{"event_id":"aG", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`), - []byte(`{"event_id":"aH", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), - []byte(`{"event_id":"aI", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), - } - err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { - _, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents) - return err - }) - if err != nil { - t.Fatalf("failed to Accumulate: %s", err) - } - - // Draw the create event, tests limits - events, position, err := accumulator.Delta(roomID, EventsStart, 1) - if err != nil { - t.Fatalf("failed to Delta: %s", err) - } - if len(events) != 1 { - t.Fatalf("failed to get events from Delta, got %d want 1", len(events)) - } - if gjson.GetBytes(events[0], "event_id").Str != gjson.GetBytes(roomEvents[0], "event_id").Str { - t.Fatalf("failed to draw first event, got %s want %s", string(events[0]), string(roomEvents[0])) - } - if position == 0 { - t.Errorf("Delta returned zero position") - } - - // Draw up to the end - events, position, err = accumulator.Delta(roomID, position, 1000) - if err != nil { - t.Fatalf("failed to Delta: %s", err) - } - if len(events) != len(roomEvents)-1 { - t.Fatalf("failed to get events from Delta, got %d want %d", len(events), len(roomEvents)-1) - } - if position == 0 { - t.Errorf("Delta returned zero position") - } -} - func TestAccumulatorMembershipLogs(t *testing.T) { roomID := "!TestAccumulatorMembershipLogs:localhost" db, close := connectToDB(t) diff --git a/state/event_table.go b/state/event_table.go index 578afbf5..46f43df4 100644 --- a/state/event_table.go +++ b/state/event_table.go @@ -336,14 +336,6 @@ func (t *EventTable) LatestEventNIDInRooms(txn *sqlx.Tx, roomIDs []string, highe return } -func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) { - var events []Event - err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`, - lowerExclusive, upperInclusive, roomID, limit, - ) - return events, err -} - func (t *EventTable) SelectLatestEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) { var events []Event // do not pull in events which were in the v2 state block diff --git a/state/event_table_test.go b/state/event_table_test.go index c015b2b2..db4bab36 100644 --- a/state/event_table_test.go +++ b/state/event_table_test.go @@ -297,125 +297,6 @@ func TestEventTableDupeInsert(t *testing.T) { } } -func TestEventTableSelectEventsBetween(t *testing.T) { - db, close := connectToDB(t) - defer close() - txn, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - table := NewEventTable(db) - searchRoomID := "!0TestEventTableSelectEventsBetween:localhost" - eventIDs := []string{ - "100TestEventTableSelectEventsBetween", - "101TestEventTableSelectEventsBetween", - "102TestEventTableSelectEventsBetween", - "103TestEventTableSelectEventsBetween", - "104TestEventTableSelectEventsBetween", - } - events := []Event{ - { - JSON: []byte(`{"event_id":"` + eventIDs[0] + `","type": "T1", "state_key":"S1", "room_id":"` + searchRoomID + `"}`), - }, - { - JSON: []byte(`{"event_id":"` + eventIDs[1] + `","type": "T2", "state_key":"S2", "room_id":"` + searchRoomID + `"}`), - }, - { - JSON: []byte(`{"event_id":"` + eventIDs[2] + `","type": "T3", "state_key":"", "room_id":"` + searchRoomID + `"}`), - }, - { - // different room - JSON: []byte(`{"event_id":"` + eventIDs[3] + `","type": "T4", "state_key":"", "room_id":"!1TestEventTableSelectEventsBetween:localhost"}`), - }, - { - JSON: []byte(`{"event_id":"` + eventIDs[4] + `","type": "T5", "state_key":"", "room_id":"` + searchRoomID + `"}`), - }, - } - idToNID, err := table.Insert(txn, events, true) - if err != nil { - t.Fatalf("Insert failed: %s", err) - } - if len(idToNID) != len(events) { - t.Fatalf("failed to insert events: got %d want %d", len(idToNID), len(events)) - } - txn.Commit() - - t.Run("subgroup", func(t *testing.T) { - t.Run("selecting multiple events known lower bound", func(t *testing.T) { - t.Parallel() - txn2, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn2.Rollback() - events, err := table.SelectByIDs(txn2, true, []string{eventIDs[0]}) - if err != nil || len(events) == 0 { - t.Fatalf("failed to extract event for lower bound: %s", err) - } - events, err = table.SelectEventsBetween(txn2, searchRoomID, int64(events[0].NID), EventsEnd, 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // 3 as 1 is from a different room - if len(events) != 3 { - t.Fatalf("wanted 3 events, got %d", len(events)) - } - }) - t.Run("selecting multiple events known lower and upper bound", func(t *testing.T) { - t.Parallel() - txn3, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn3.Rollback() - events, err := table.SelectByIDs(txn3, true, []string{eventIDs[0], eventIDs[2]}) - if err != nil || len(events) == 0 { - t.Fatalf("failed to extract event for lower/upper bound: %s", err) - } - events, err = table.SelectEventsBetween(txn3, searchRoomID, int64(events[0].NID), int64(events[1].NID), 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // eventIDs[1] and eventIDs[2] - if len(events) != 2 { - t.Fatalf("wanted 2 events, got %d", len(events)) - } - }) - t.Run("selecting multiple events unknown bounds (all events)", func(t *testing.T) { - t.Parallel() - txn4, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn4.Rollback() - gotEvents, err := table.SelectEventsBetween(txn4, searchRoomID, EventsStart, EventsEnd, 1000) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - // one less as one event is for a different room - if len(gotEvents) != (len(events) - 1) { - t.Fatalf("wanted %d events, got %d", len(events)-1, len(gotEvents)) - } - }) - t.Run("selecting multiple events hitting the limit", func(t *testing.T) { - t.Parallel() - txn5, err := db.Beginx() - if err != nil { - t.Fatalf("failed to start txn: %s", err) - } - defer txn5.Rollback() - limit := 2 - gotEvents, err := table.SelectEventsBetween(txn5, searchRoomID, EventsStart, EventsEnd, limit) - if err != nil { - t.Fatalf("failed to SelectEventsBetween: %s", err) - } - if len(gotEvents) != limit { - t.Fatalf("wanted %d events, got %d", limit, len(gotEvents)) - } - }) - }) -} - func TestEventTableMembershipDetection(t *testing.T) { db, close := connectToDB(t) defer close() diff --git a/state/storage.go b/state/storage.go index bdff086e..d0cb63bd 100644 --- a/state/storage.go +++ b/state/storage.go @@ -235,45 +235,6 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result result[roomID] = metadata } - // Select the most recent members for each room to serve as Heroes. The spec is ambiguous here: - // "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited." - // Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent - // ones, and select 6 of them so we can always use 5 no matter who is requesting the room name. - rows, err := txn.Query(` - SELECT rf.* FROM ( - SELECT room_id, event, rank() OVER ( - PARTITION BY room_id ORDER BY event_nid DESC - ) FROM syncv3_events INNER JOIN ` + tempTableName + ` ON membership_nid=event_nid WHERE ( - membership='join' OR membership='invite' OR membership='_join' - ) AND event_type='m.room.member' - ) rf WHERE rank <= 6;`) - if err != nil { - return fmt.Errorf("failed to query heroes: %s", err) - } - defer rows.Close() - seen := map[string]bool{} - for rows.Next() { - var roomID string - var event json.RawMessage - var rank int - if err := rows.Scan(&roomID, &event, &rank); err != nil { - return err - } - ev := gjson.ParseBytes(event) - targetUser := ev.Get("state_key").Str - key := roomID + " " + targetUser - if seen[key] { - continue - } - seen[key] = true - metadata := loadMetadata(roomID) - metadata.Heroes = append(metadata.Heroes, internal.Hero{ - ID: targetUser, - Name: ev.Get("content.displayname").Str, - Avatar: ev.Get("content.avatar_url").Str, - }) - result[roomID] = metadata - } roomInfos, err := s.Accumulator.roomsTable.SelectRoomInfos(txn) if err != nil { return fmt.Errorf("failed to select room infos: %s", err) @@ -803,9 +764,14 @@ func (s *Storage) RoomMembershipDelta(roomID string, from, to int64, limit int) } // Extract all rooms with joined members, and include the joined user list. Requires a prepared snapshot in order to be called. +// Populates the join/invite count and heroes for the returned metadata. func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMembers map[string][]string, metadata map[string]internal.RoomMetadata, err error) { + // Select the most recent members for each room to serve as Heroes. The spec is ambiguous here: + // "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited." + // Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent + // ones, and select 6 of them so we can always use 5 no matter who is requesting the room name. rows, err := txn.Query( - `SELECT room_id, state_key, membership from ` + tempTableName + ` INNER JOIN syncv3_events + `SELECT membership_nid, room_id, state_key, membership from ` + tempTableName + ` INNER JOIN syncv3_events on membership_nid = event_nid WHERE membership='join' OR membership='_join' OR membership='invite' OR membership='_invite' ORDER BY event_nid ASC`, ) if err != nil { @@ -813,14 +779,21 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe } defer rows.Close() joinedMembers = make(map[string][]string) - var roomID string inviteCounts := make(map[string]int) + heroNIDs := make(map[string]*circularSlice) var stateKey string var membership string + var roomID string + var nid int64 for rows.Next() { - if err := rows.Scan(&roomID, &stateKey, &membership); err != nil { + if err := rows.Scan(&nid, &roomID, &stateKey, &membership); err != nil { return nil, nil, err } + heroes := heroNIDs[roomID] + if heroes == nil { + heroes = &circularSlice{max: 6} + heroNIDs[roomID] = heroes + } switch membership { case "join": fallthrough @@ -828,17 +801,44 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe users := joinedMembers[roomID] users = append(users, stateKey) joinedMembers[roomID] = users + heroes.append(nid) case "invite": fallthrough case "_invite": inviteCounts[roomID] = inviteCounts[roomID] + 1 + heroes.append(nid) } } + + // now select the membership events for the heroes + var allHeroNIDs []int64 + for _, nids := range heroNIDs { + allHeroNIDs = append(allHeroNIDs, nids.vals...) + } + heroEvents, err := s.EventsTable.SelectByNIDs(txn, true, allHeroNIDs) + if err != nil { + return nil, nil, err + } + heroes := make(map[string][]internal.Hero) + // loop backwards so the most recent hero is first in the hero list + for i := len(heroEvents) - 1; i >= 0; i-- { + ev := heroEvents[i] + evJSON := gjson.ParseBytes(ev.JSON) + roomHeroes := heroes[ev.RoomID] + roomHeroes = append(roomHeroes, internal.Hero{ + ID: ev.StateKey, + Name: evJSON.Get("content.displayname").Str, + Avatar: evJSON.Get("content.avatar_url").Str, + }) + heroes[ev.RoomID] = roomHeroes + } + metadata = make(map[string]internal.RoomMetadata) for roomID, members := range joinedMembers { m := internal.NewRoomMetadata(roomID) m.JoinCount = len(members) m.InviteCount = inviteCounts[roomID] + m.Heroes = heroes[roomID] metadata[roomID] = *m } return joinedMembers, metadata, nil @@ -938,3 +938,27 @@ func (s *Storage) Teardown() { panic("Storage.Teardown: " + err.Error()) } } + +// circularSlice is a slice which can be appended to which will wraparound at `max`. +// Mostly useful for lazily calculating heroes. The values returned aren't sorted. +type circularSlice struct { + i int + vals []int64 + max int +} + +func (s *circularSlice) append(val int64) { + if len(s.vals) < s.max { + // populate up to max + s.vals = append(s.vals, val) + s.i++ + return + } + // wraparound + if s.i == s.max { + s.i = 0 + } + // replace this entry + s.vals[s.i] = val + s.i++ +} diff --git a/state/storage_test.go b/state/storage_test.go index c5eaa192..f5ffeec6 100644 --- a/state/storage_test.go +++ b/state/storage_test.go @@ -822,6 +822,56 @@ func TestAllJoinedMembers(t *testing.T) { } } +func TestCircularSlice(t *testing.T) { + testCases := []struct { + name string + max int + appends []int64 + want []int64 // these get sorted in the test + }{ + { + name: "wraparound", + max: 5, + appends: []int64{9, 8, 7, 6, 5, 4, 3, 2}, + want: []int64{2, 3, 4, 5, 6}, + }, + { + name: "exact", + max: 5, + appends: []int64{9, 8, 7, 6, 5}, + want: []int64{5, 6, 7, 8, 9}, + }, + { + name: "unfilled", + max: 5, + appends: []int64{9, 8, 7}, + want: []int64{7, 8, 9}, + }, + { + name: "wraparound x2", + max: 5, + appends: []int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10}, + want: []int64{0, 1, 2, 3, 10}, + }, + } + for _, tc := range testCases { + cs := &circularSlice{ + max: tc.max, + } + for _, val := range tc.appends { + cs.append(val) + } + sort.Slice(cs.vals, func(i, j int) bool { + return cs.vals[i] < cs.vals[j] + }) + if !reflect.DeepEqual(cs.vals, tc.want) { + t.Errorf("%s: got %v want %v", tc.name, cs.vals, tc.want) + } + + } + +} + func cleanDB(t *testing.T) error { // make a fresh DB which is unpolluted from other tests db, close := connectToDB(t)