Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions pkg/roomkeysender/roomkeysender.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,41 @@ func NewSender(pub Publisher) *Sender {
return &Sender{pub: pub}
}

// Send publishes evt to the room key update subject for the given user account.
// The event is accepted by value; Send stamps its own Timestamp before publishing.
// The value copy is intentional: Send must not mutate the caller's struct.
// Marshal stamps the event Timestamp and serializes it once into a payload that
// can be fanned out to many accounts via SendData without re-marshaling per
// recipient. The event is accepted by value; Marshal must not mutate the
// caller's struct.
//
//nolint:gocritic // hugeParam: by-value is intentional for immutability; the copy cost is acceptable.
func (s *Sender) Send(account string, evt model.RoomKeyEvent) error {
func (s *Sender) Marshal(evt model.RoomKeyEvent) ([]byte, error) {
evt.Timestamp = time.Now().UTC().UnixMilli()
// #nosec G117 -- RoomKeyEvent.PrivateKey is the intended payload: room-key distribution to the authorized account over its auth-callout-gated per-user subject, not a leak
data, err := json.Marshal(evt)
if err != nil {
return fmt.Errorf("marshal room key event: %w", err)
return nil, fmt.Errorf("marshal room key event: %w", err)
}
return data, nil
}

// SendData delivers a payload that has already been serialized (normally the
// output of Marshal) to the room key update subject derived from account. The
// bytes are handed to the publisher as-is. A publish failure is returned
// wrapped with "publish room key event".
func (s *Sender) SendData(account string, data []byte) error {
	if pubErr := s.pub.Publish(subject.RoomKeyUpdate(account), data); pubErr != nil {
		return fmt.Errorf("publish room key event: %w", pubErr)
	}
	return nil
}

// Send is the single-recipient path: it serializes evt with Marshal (which
// stamps the Timestamp) and publishes the resulting bytes to the room key
// update subject for account via SendData. evt is received by value, so the
// caller's struct is never modified. Errors from either step are returned
// unchanged.
//
//nolint:gocritic // hugeParam: by-value is intentional for immutability; the copy cost is acceptable.
func (s *Sender) Send(account string, evt model.RoomKeyEvent) error {
	payload, marshalErr := s.Marshal(evt)
	if marshalErr != nil {
		return marshalErr
	}
	return s.SendData(account, payload)
}
45 changes: 45 additions & 0 deletions pkg/roomkeysender/roomkeysender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,48 @@ func TestSender_Send(t *testing.T) {
})
}
}

// TestSender_Marshal checks the Marshal contract: the returned bytes decode
// back to the input fields plus a positive Timestamp, and the caller's event
// is left exactly as it was passed in.
func TestSender_Marshal(t *testing.T) {
	input := model.RoomKeyEvent{RoomID: "room-1", Version: 7, PrivateKey: []byte{0x01, 0x02}}

	// Keep an independent copy (including the key bytes) to compare against
	// after the call.
	snapshot := input
	snapshot.PrivateKey = append([]byte(nil), input.PrivateKey...)

	payload, err := roomkeysender.NewSender(&mockPublisher{}).Marshal(input)
	require.NoError(t, err)

	// Non-mutation contract: Marshal takes the event by value.
	assert.Equal(t, snapshot, input, "Marshal must not mutate the caller's RoomKeyEvent")

	var decoded model.RoomKeyEvent
	require.NoError(t, json.Unmarshal(payload, &decoded))
	assert.Equal(t, input.RoomID, decoded.RoomID)
	assert.Equal(t, input.Version, decoded.Version)
	assert.Equal(t, input.PrivateKey, decoded.PrivateKey)
	assert.Greater(t, decoded.Timestamp, int64(0), "Marshal must stamp a timestamp")
}

// TestSender_SendData checks that SendData forwards pre-marshaled bytes
// untouched to the account's room-key subject (the per-recipient step of the
// marshal-once fan-out) and that publisher failures come back wrapped.
func TestSender_SendData(t *testing.T) {
	t.Run("publishes bytes to the account subject", func(t *testing.T) {
		recorder := &mockPublisher{}
		raw := []byte(`{"roomId":"r","version":3}`)

		err := roomkeysender.NewSender(recorder).SendData("alice", raw)

		require.NoError(t, err)
		assert.Equal(t, "chat.user.alice.event.room.key", recorder.subject)
		assert.Equal(t, raw, recorder.data, "SendData must publish the bytes verbatim")
	})

	t.Run("wraps publish errors", func(t *testing.T) {
		publishFailure := errors.New("connection lost")
		failing := &mockPublisher{err: publishFailure}

		err := roomkeysender.NewSender(failing).SendData("bob", []byte("{}"))

		require.Error(t, err)
		// The context prefix is added and the original error stays reachable.
		assert.ErrorIs(t, err, publishFailure)
		assert.Contains(t, err.Error(), "publish room key event")
	})
}
57 changes: 35 additions & 22 deletions room-worker/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ func (h *Handler) processRemoveMember(ctx context.Context, data []byte) error {

// rotateAndFanOut generates v+1, fans it out to survivors, then commits via Valkey Rotate.
// Fan-out before Rotate is intentional so survivors hold v+1 before broadcast-worker switches.
func (h *Handler) rotateAndFanOut(ctx context.Context, roomID string, currentPair *roomkeystore.VersionedKeyPair, survivors []model.Subscription) error {
// survivorAccounts is a pre-computed post-deletion snapshot of the room's member accounts.
func (h *Handler) rotateAndFanOut(ctx context.Context, roomID string, currentPair *roomkeystore.VersionedKeyPair, survivorAccounts []string) error {
newPair, err := roomkeystore.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate room key: %w", err)
Expand All @@ -335,7 +336,7 @@ func (h *Handler) rotateAndFanOut(ctx context.Context, roomID string, currentPai
predictedVersion = currentPair.Version + 1
}
versioned := &roomkeystore.VersionedKeyPair{Version: predictedVersion, KeyPair: *newPair}
h.fanOutRoomKeyToSurvivors(ctx, roomID, versioned, survivors)
h.fanOutRoomKeyToSurvivors(ctx, roomID, versioned, survivorAccounts)

if currentPair == nil {
if _, err := h.keyStore.Set(ctx, roomID, *newPair); err != nil {
Expand Down Expand Up @@ -400,12 +401,14 @@ func (h *Handler) processRemoveIndividual(ctx context.Context, req *model.Remove
return fmt.Errorf("reconcile member counts: %w", err)
}

// Rotate after delete + reconcile; ListByRoom returns post-deletion survivors.
survivors, listErr := h.store.ListByRoom(ctx, req.RoomID)
// Rotate after delete + reconcile; GetSubscriptionAccounts returns the
// post-deletion survivor accounts (projected — fan-out only needs accounts,
// not full subscription docs).
survivorAccounts, listErr := h.store.GetSubscriptionAccounts(ctx, req.RoomID)
if listErr != nil {
return fmt.Errorf("list survivors for key fan-out (room %s): %w", req.RoomID, listErr)
}
if err := h.rotateAndFanOut(ctx, req.RoomID, currentPair, survivors); err != nil {
if err := h.rotateAndFanOut(ctx, req.RoomID, currentPair, survivorAccounts); err != nil {
return fmt.Errorf("rotate and fan-out room key after remove-individual: %w", err)
}

Expand Down Expand Up @@ -608,13 +611,15 @@ func (h *Handler) processRemoveOrg(ctx context.Context, req *model.RemoveMemberR
return fmt.Errorf("reconcile member counts: %w", err)
}

// Rotate only when something was actually deleted; ListByRoom returns post-deletion survivors.
// Rotate only when something was actually deleted; GetSubscriptionAccounts
// returns the post-deletion survivor accounts (projected — fan-out only
// needs accounts, not full subscription docs).
if len(accounts) > 0 {
survivors, listErr := h.store.ListByRoom(ctx, req.RoomID)
survivorAccounts, listErr := h.store.GetSubscriptionAccounts(ctx, req.RoomID)
if listErr != nil {
return fmt.Errorf("list survivors for key fan-out (room %s): %w", req.RoomID, listErr)
}
if err := h.rotateAndFanOut(ctx, req.RoomID, currentPair, survivors); err != nil {
if err := h.rotateAndFanOut(ctx, req.RoomID, currentPair, survivorAccounts); err != nil {
return fmt.Errorf("rotate and fan-out room key after remove-org: %w", err)
}
}
Expand Down Expand Up @@ -1821,21 +1826,18 @@ func (h *Handler) natsServerCreateDM(m otelnats.Msg) {
natsutil.ReplyJSON(m.Msg, reply)
}

// fanOutRoomKeyToSurvivors sends the already-fetched room key to every room member in survivors
// (local + remote). NATS supercluster routes user-subjects to home sites.
// survivors is a pre-computed post-deletion snapshot supplied by the caller; pair must be non-nil.
func (h *Handler) fanOutRoomKeyToSurvivors(ctx context.Context, roomID string, pair *roomkeystore.VersionedKeyPair, survivors []model.Subscription) {
// fanOutRoomKeyToSurvivors sends the already-fetched room key to every survivor
// account (local + remote). NATS supercluster routes user-subjects to home
// sites. survivorAccounts is a pre-computed post-deletion snapshot supplied by
// the caller; pair must be non-nil.
func (h *Handler) fanOutRoomKeyToSurvivors(ctx context.Context, roomID string, pair *roomkeystore.VersionedKeyPair, survivorAccounts []string) {
// PublicKey omitted: server-side only, read from Valkey by broadcast-worker.
evt := model.RoomKeyEvent{
RoomID: roomID,
Version: pair.Version,
PrivateKey: pair.KeyPair.PrivateKey,
}
accounts := make([]string, len(survivors))
for i := range survivors {
accounts[i] = survivors[i].User.Account
}
h.fanOutKey(ctx, roomID, accounts, &evt)
h.fanOutKey(ctx, roomID, survivorAccounts, &evt)
}

// buildAndFanOutRoomKey publishes pair as a RoomKeyEvent to every account in users.
Expand All @@ -1859,17 +1861,28 @@ func (h *Handler) buildAndFanOutRoomKey(ctx context.Context, roomID string, pair
return nil
}

// fanOutKey distributes evt to every account via h.keySender.Send using up to
// h.keyFanoutWorkers concurrent goroutines. Per-account errors are logged and
// counted via roomkeymetrics; partial fan-out is acceptable because JetStream
// redelivers on permanent failure and recipients are idempotent on key version.
// fanOutKey distributes evt to every account using up to h.keyFanoutWorkers
// concurrent goroutines. The event is marshaled exactly once and the resulting
// bytes are published to each account — on a giant room this avoids one
// json.Marshal per recipient. Per-account errors are logged and counted via
// roomkeymetrics; partial fan-out is acceptable because JetStream redelivers on
// permanent failure and recipients are idempotent on key version.
//
// evt is taken by pointer so the 80-byte struct isn't copied per fan-out call;
// callers must not mutate it after passing it in.
func (h *Handler) fanOutKey(ctx context.Context, roomID string, accounts []string, evt *model.RoomKeyEvent) {
if len(accounts) == 0 {
return
}
data, err := h.keySender.Marshal(*evt)
if err != nil {
// Marshaling a RoomKeyEvent effectively never fails; if it somehow does,
// no recipient can be served, so count the whole batch and bail. The
// caller treats fan-out as best-effort and JetStream redelivers.
slog.Error("marshal room key for fan-out", "error", err, "roomId", roomID, "accounts", len(accounts))
roomkeymetrics.FanoutErrors.Add(ctx, int64(len(accounts)), metric.WithAttributes(attribute.String("roomId", roomID)))
return
}
workers := h.keyFanoutWorkers
if workers <= 0 {
// Defensive default for tests and any future construction path that
Expand All @@ -1890,7 +1903,7 @@ func (h *Handler) fanOutKey(ctx context.Context, roomID string, accounts []strin
<-sem
wg.Done()
}()
if err := h.keySender.Send(acct, *evt); err != nil {
if err := h.keySender.SendData(acct, data); err != nil {
slog.Error("send room key", "error", err, "account", acct, "roomId", roomID)
roomkeymetrics.FanoutErrors.Add(ctx, 1, metric.WithAttributes(attribute.String("roomId", roomID)))
}
Expand Down
28 changes: 13 additions & 15 deletions room-worker/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ func TestHandler_ProcessRemoveMember_SelfLeave_IndividualOnly(t *testing.T) {
store.EXPECT().
ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().
ListByRoom(gomock.Any(), roomID).Return(nil, nil)
GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)

var published []publishedMsg
h := NewHandler(store, siteID, func(_ context.Context, subj string, data []byte, _ string) error {
Expand Down Expand Up @@ -582,7 +582,7 @@ func TestHandler_ProcessRemoveMember_OwnerRemovesIndividual(t *testing.T) {
store.EXPECT().
ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().
ListByRoom(gomock.Any(), roomID).Return(nil, nil)
GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)
store.EXPECT().
GetUser(gomock.Any(), requester).
Return(&model.User{ID: "u_alice", Account: requester, SiteID: siteID, EngName: "Alice", ChineseName: "愛"}, nil)
Expand Down Expand Up @@ -1168,7 +1168,7 @@ func TestHandler_ProcessRemoveMember_OwnerRemovesOrg(t *testing.T) {
store.EXPECT().
ReconcileMemberCounts(gomock.Any(), roomID).Return(nil) // recount after removal
store.EXPECT().
ListByRoom(gomock.Any(), roomID).Return(nil, nil)
GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)
store.EXPECT().
GetUser(gomock.Any(), requester).
Return(&model.User{ID: "u_alice", Account: requester, SiteID: siteID, EngName: "Alice", ChineseName: "愛"}, nil)
Expand Down Expand Up @@ -1242,7 +1242,7 @@ func TestHandler_ProcessRemoveMember_CrossSiteOutbox(t *testing.T) {
store.EXPECT().
ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().
ListByRoom(gomock.Any(), roomID).Return(nil, nil)
GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)

var published []publishedMsg
h := NewHandler(store, localSite, func(_ context.Context, subj string, data []byte, _ string) error {
Expand Down Expand Up @@ -1534,7 +1534,7 @@ func TestHandler_ProcessRemoveIndividual_OutboxFailurePropagates(t *testing.T) {
store.EXPECT().
ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().
ListByRoom(gomock.Any(), roomID).Return(nil, nil)
GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)

outboxSubj := subject.Outbox(localSite, userSite, "member_removed")
publish := func(_ context.Context, subj string, _ []byte, _ string) error {
Expand Down Expand Up @@ -1573,7 +1573,7 @@ func TestHandler_ProcessRemoveOrg_OutboxFailurePropagates(t *testing.T) {
store.EXPECT().DeleteSubscriptionsByAccounts(gomock.Any(), roomID, []string{"carol"}).Return(int64(1), nil)
store.EXPECT().DeleteRoomMember(gomock.Any(), roomID, model.RoomMemberOrg, orgID).Return(nil)
store.EXPECT().ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().ListByRoom(gomock.Any(), roomID).Return(nil, nil)
store.EXPECT().GetSubscriptionAccounts(gomock.Any(), roomID).Return(nil, nil)
store.EXPECT().GetUser(gomock.Any(), requester).
Return(&model.User{ID: "u_alice", Account: requester, SiteID: localSite, EngName: "Alice", ChineseName: "愛"}, nil)

Expand Down Expand Up @@ -3635,14 +3635,12 @@ func TestFanOutRoomKeyToSurvivors_SendsToAllSurvivorsIncludingRemoteSite(t *test
pair := &roomkeystore.VersionedKeyPair{Version: 5, KeyPair: roomkeystore.RoomKeyPair{
PrivateKey: bytes.Repeat([]byte{0x03}, 32),
}}
survivors := []model.Subscription{
{User: model.SubscriptionUser{Account: "alice"}, RoomID: "r1", SiteID: "site-a"},
{User: model.SubscriptionUser{Account: "bob"}, RoomID: "r1", SiteID: "site-a"},
{User: model.SubscriptionUser{Account: "remote-carol"}, RoomID: "r1", SiteID: "site-b"},
}
// Survivor accounts span the local site (alice, bob) and a remote site
// (remote-carol); the caller projects these out of the subscriptions.
survivorAccounts := []string{"alice", "bob", "remote-carol"}

h := NewHandler(store, "site-a", func(_ context.Context, _ string, _ []byte, _ string) error { return nil }, nil, keySender)
h.fanOutRoomKeyToSurvivors(context.Background(), "r1", pair, survivors)
h.fanOutRoomKeyToSurvivors(context.Background(), "r1", pair, survivorAccounts)
// alice, bob (site-a) and remote-carol (site-b) all receive the new key.
assert.ElementsMatch(t, []string{
"chat.user.alice.event.room.key",
Expand Down Expand Up @@ -4115,7 +4113,7 @@ func TestHandler_ProcessRemoveIndividual_SelfLeave_Content(t *testing.T) {
store.EXPECT().DeleteRoomMember(gomock.Any(), roomID, model.RoomMemberIndividual, "u_b").Return(nil)
store.EXPECT().DeleteSubscription(gomock.Any(), roomID, "bob").Return(int64(1), nil)
store.EXPECT().ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().ListByRoom(gomock.Any(), roomID).Return([]model.Subscription{}, nil)
store.EXPECT().GetSubscriptionAccounts(gomock.Any(), roomID).Return([]string{}, nil)

var published []publishedMsg
h := &Handler{store: store, siteID: "site-a", publish: func(_ context.Context, subj string, data []byte, _ string) error {
Expand Down Expand Up @@ -4145,7 +4143,7 @@ func TestHandler_ProcessRemoveIndividual_RemovedByOther_Content(t *testing.T) {
store.EXPECT().DeleteRoomMember(gomock.Any(), roomID, model.RoomMemberIndividual, "u_b").Return(nil)
store.EXPECT().DeleteSubscription(gomock.Any(), roomID, "bob").Return(int64(1), nil)
store.EXPECT().ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
store.EXPECT().ListByRoom(gomock.Any(), roomID).Return([]model.Subscription{}, nil)
store.EXPECT().GetSubscriptionAccounts(gomock.Any(), roomID).Return([]string{}, nil)
store.EXPECT().GetUser(gomock.Any(), "alice").
Return(&model.User{ID: "u_a", Account: "alice", SiteID: "site-a", EngName: "Alice", ChineseName: "愛"}, nil)

Expand Down Expand Up @@ -4257,7 +4255,7 @@ func TestHandler_ProcessRemoveOrg_OtherOrgCovers_PreservesSub(t *testing.T) {
// MUST NOT be called — alice is still covered by the sibling org.
store.EXPECT().DeleteSubscriptionsByAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
// MUST NOT rotate — no survivors were displaced.
store.EXPECT().ListByRoom(gomock.Any(), gomock.Any()).Times(0)
store.EXPECT().GetSubscriptionAccounts(gomock.Any(), gomock.Any()).Times(0)
// The X org row still gets deleted; the count gets reconciled.
store.EXPECT().DeleteRoomMember(gomock.Any(), roomID, model.RoomMemberOrg, "X").Return(nil)
store.EXPECT().ReconcileMemberCounts(gomock.Any(), roomID).Return(nil)
Expand Down
50 changes: 50 additions & 0 deletions room-worker/keyfanout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,56 @@ func TestFanOutKey_PublishesEveryAccount(t *testing.T) {
require.Len(t, got, accounts, "must publish once per account")
}

// dataRecordingPublisher is a test publisher that keeps a private copy of the
// payload passed to each Publish call, in call order. All access to the
// recorded payloads goes through mu, so concurrent Publish calls are safe.
type dataRecordingPublisher struct {
	mu       sync.Mutex
	payloads [][]byte
}

// Publish records a copy of data and always returns nil. The subject is
// ignored.
func (d *dataRecordingPublisher) Publish(_ string, data []byte) error {
	// Detach from the caller's backing array before recording: the same
	// slice may be passed to several Publish calls or reused afterwards.
	dup := make([]byte, len(data))
	copy(dup, data)

	d.mu.Lock()
	d.payloads = append(d.payloads, dup)
	d.mu.Unlock()
	return nil
}

// snapshot returns a new outer slice holding the payloads recorded so far.
// Later Publish calls do not affect the returned slice; the inner byte slices
// are shared with the recorder, not re-copied.
func (d *dataRecordingPublisher) snapshot() [][]byte {
	d.mu.Lock()
	defer d.mu.Unlock()
	return append(make([][]byte, 0, len(d.payloads)), d.payloads...)
}

// TestFanOutKey_MarshalsOnce fans one event out to many accounts and asserts
// that every published payload is byte-identical to the first. That is the
// observable result of serializing the event a single time and reusing the
// bytes, as opposed to re-marshaling per account with a per-recipient
// timestamp.
func TestFanOutKey_MarshalsOnce(t *testing.T) {
	const accounts = 50

	recipients := make([]string, 0, accounts)
	for n := 0; n < accounts; n++ {
		recipients = append(recipients, fmt.Sprintf("acct-%03d", n))
	}

	recorder := &dataRecordingPublisher{}
	handler := newFanoutTestHandler(t, roomkeysender.NewSender(recorder), 8)
	handler.fanOutKey(context.Background(), "r", recipients, &model.RoomKeyEvent{
		RoomID:     "r",
		Version:    2,
		PrivateKey: []byte{0xaa, 0xbb},
	})

	published := recorder.snapshot()
	require.Len(t, published, accounts)

	// The first recorded payload is the reference every other one must match.
	reference := published[0]
	require.NotEmpty(t, reference)
	for idx := range published {
		assert.Equal(t, reference, published[idx], "payload %d differs; event was not marshaled exactly once", idx)
	}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

func TestFanOutKey_NoAccountsIsNoOp(t *testing.T) {
rp := &recordingPublisher{}
h := newFanoutTestHandler(t, roomkeysender.NewSender(rp), 16)
Expand Down
Loading
Loading