diff --git a/broadcast-worker/deploy/docker-compose.yml b/broadcast-worker/deploy/docker-compose.yml index 2b519e112..0bbdafaf6 100644 --- a/broadcast-worker/deploy/docker-compose.yml +++ b/broadcast-worker/deploy/docker-compose.yml @@ -11,10 +11,11 @@ services: - SITE_ID=site-local - MONGO_URI=mongodb://mongodb:27017 - MONGO_DB=chat - # In-process user cache. Defaults: 10000 entries, 5m TTL. - # Set USER_CACHE_SIZE=0 to disable caching. + # In-process user cache (pkg/userstore.Cache, shared with message-gatekeeper + # and message-worker). Defaults: 10000 entries, 5m TTL. - USER_CACHE_SIZE=10000 - USER_CACHE_TTL=5m + # Valkey is used for room encryption keys only (when ENCRYPTION_ENABLED=true). - VALKEY_ADDRS=valkey:6379 - VALKEY_KEY_GRACE_PERIOD=24h - BOOTSTRAP_STREAMS=true diff --git a/broadcast-worker/main.go b/broadcast-worker/main.go index 4fc3e65c6..b7923eeef 100644 --- a/broadcast-worker/main.go +++ b/broadcast-worker/main.go @@ -79,13 +79,13 @@ func main() { os.Exit(1) } slog.Info("room-meta-cache enabled", "size", cfg.RoomMetaCacheSize, "ttl", cfg.RoomMetaCacheTTL) - us := userstore.NewMongoStore(db.Collection("users")) - if cfg.UserCacheSize > 0 && cfg.UserCacheTTL > 0 { - us = NewCachedUserStore(us, cfg.UserCacheSize, cfg.UserCacheTTL) - slog.Info("user-cache enabled", "size", cfg.UserCacheSize, "ttl", cfg.UserCacheTTL) - } else { - slog.Info("user-cache disabled") + us, err := userstore.NewCache(userstore.NewMongoStore(db.Collection("users")), + cfg.UserCacheSize, cfg.UserCacheTTL) + if err != nil { + slog.Error("init user cache failed", "error", err) + os.Exit(1) } + slog.Info("user-cache enabled", "size", cfg.UserCacheSize, "ttl", cfg.UserCacheTTL) var keyStore roomkeystore.RoomKeyStore if cfg.Encryption.Enabled { diff --git a/broadcast-worker/usercache.go b/broadcast-worker/usercache.go deleted file mode 100644 index 2d11318c3..000000000 --- a/broadcast-worker/usercache.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "container/list" - 
"context" - "fmt" - "sync" - "time" - - "github.com/hmchangw/chat/pkg/model" - "github.com/hmchangw/chat/pkg/userstore" -) - -// userCacheEntry is the value stored in each LRU list element. -type userCacheEntry struct { - account string - user model.User - inserted time.Time -} - -// CachedUserStore wraps a userstore.UserStore with an in-process LRU+TTL -// cache of FindUsersByAccounts results. FindUserByID delegates to the -// inner store unchanged. -type CachedUserStore struct { - inner userstore.UserStore - ttl time.Duration - maxSize int - - mu sync.Mutex - lru *list.List // elements hold *userCacheEntry; front = MRU, back = LRU - index map[string]*list.Element - now func() time.Time -} - -// NewCachedUserStore returns a cache wrapping inner. maxSize and ttl must -// both be positive. -func NewCachedUserStore(inner userstore.UserStore, maxSize int, ttl time.Duration) *CachedUserStore { - return &CachedUserStore{ - inner: inner, - ttl: ttl, - maxSize: maxSize, - lru: list.New(), - index: make(map[string]*list.Element, maxSize), - now: time.Now, - } -} - -// FindUserByID delegates; no caching for single-ID lookups. -func (c *CachedUserStore) FindUserByID(ctx context.Context, id string) (*model.User, error) { - return c.inner.FindUserByID(ctx, id) -} - -// FindUsersByAccounts returns users for the requested accounts, serving -// cache hits without calling the inner store. Cache misses are forwarded -// in a single batched inner call. Missing users are not cached as -// negatives — an account the inner store didn't return is simply absent -// and will be re-fetched next time. When the inner store returns an -// error, partial cache hits collected so far are returned alongside the -// (wrapped) error so the caller can choose to log and continue. 
-func (c *CachedUserStore) FindUsersByAccounts(ctx context.Context, accounts []string) ([]model.User, error) { - if len(accounts) == 0 { - return nil, nil - } - - // Dedupe the input so cache-hit and cache-miss paths produce identical - // results regardless of whether the caller passes duplicates. - seen := make(map[string]struct{}, len(accounts)) - unique := make([]string, 0, len(accounts)) - for _, a := range accounts { - if _, ok := seen[a]; ok { - continue - } - seen[a] = struct{}{} - unique = append(unique, a) - } - accounts = unique - - now := c.now() - - c.mu.Lock() - hits := make([]model.User, 0, len(accounts)) - missing := make([]string, 0, len(accounts)) - for _, account := range accounts { - elem, ok := c.index[account] - if !ok { - missing = append(missing, account) - continue - } - entry := elem.Value.(*userCacheEntry) - if now.Sub(entry.inserted) >= c.ttl { - // Stale; treat as miss. Drop entry now so a concurrent writer - // doesn't collide; the inner result (or its absence) will - // repopulate below. - c.lru.Remove(elem) - delete(c.index, account) - missing = append(missing, account) - continue - } - if elem != c.lru.Front() { - c.lru.MoveToFront(elem) - } - hits = append(hits, entry.user) - } - c.mu.Unlock() - - if len(missing) == 0 { - return hits, nil - } - - fresh, err := c.inner.FindUsersByAccounts(ctx, missing) - if err != nil { - return hits, fmt.Errorf("cached find users by accounts: %w", err) - } - - c.mu.Lock() - for i := range fresh { - c.addLocked(&fresh[i], now) - } - c.mu.Unlock() - - return append(hits, fresh...), nil -} - -// addLocked inserts or refreshes a cache entry. The caller must hold c.mu. 
-func (c *CachedUserStore) addLocked(u *model.User, now time.Time) { - if existing, ok := c.index[u.Account]; ok { - existing.Value = &userCacheEntry{account: u.Account, user: *u, inserted: now} - c.lru.MoveToFront(existing) - return - } - entry := &userCacheEntry{account: u.Account, user: *u, inserted: now} - elem := c.lru.PushFront(entry) - c.index[u.Account] = elem - if c.lru.Len() > c.maxSize { - if lruElem := c.lru.Back(); lruElem != nil { - lruEntry := lruElem.Value.(*userCacheEntry) - c.lru.Remove(lruElem) - delete(c.index, lruEntry.account) - } - } -} diff --git a/broadcast-worker/usercache_test.go b/broadcast-worker/usercache_test.go deleted file mode 100644 index e9fcfa386..000000000 --- a/broadcast-worker/usercache_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package main - -import ( - "context" - "errors" - "strconv" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/hmchangw/chat/pkg/model" - "github.com/hmchangw/chat/pkg/userstore" -) - -// fakeUserStore is a minimal userstore.UserStore that records calls and -// returns preconfigured users. Tests assert call counts to confirm cache -// hits do not reach the inner store. 
-type fakeUserStore struct { - mu sync.Mutex - calls [][]string - byAccount map[string]model.User - err error -} - -func newFakeUserStore(users ...model.User) *fakeUserStore { - f := &fakeUserStore{byAccount: make(map[string]model.User, len(users))} - for i := range users { - f.byAccount[users[i].Account] = users[i] - } - return f -} - -func (f *fakeUserStore) FindUserByID(_ context.Context, id string) (*model.User, error) { - f.mu.Lock() - defer f.mu.Unlock() - if u, ok := f.byAccount[id]; ok { - return &u, nil - } - return nil, errors.New("not found") -} - -func (f *fakeUserStore) FindUsersByAccounts(_ context.Context, accounts []string) ([]model.User, error) { - f.mu.Lock() - defer f.mu.Unlock() - f.calls = append(f.calls, append([]string(nil), accounts...)) - if f.err != nil { - return nil, f.err - } - out := make([]model.User, 0, len(accounts)) - for _, a := range accounts { - if u, ok := f.byAccount[a]; ok { - out = append(out, u) - } - } - return out, nil -} - -func (f *fakeUserStore) callCount() int { - f.mu.Lock() - defer f.mu.Unlock() - return len(f.calls) -} - -func (f *fakeUserStore) lastCall() []string { - f.mu.Lock() - defer f.mu.Unlock() - if len(f.calls) == 0 { - return nil - } - return f.calls[len(f.calls)-1] -} - -var _ userstore.UserStore = (*CachedUserStore)(nil) - -func TestNewCachedUserStore_ConstructsEmpty(t *testing.T) { - inner := newFakeUserStore() - c := NewCachedUserStore(inner, 10, time.Minute) - require.NotNil(t, c) - // A fresh cache doesn't call inner until asked. 
- assert.Equal(t, 0, inner.callCount()) - assert.Nil(t, inner.lastCall()) -} - -func TestCachedUserStore_MissCallsInner(t *testing.T) { - alice := model.User{ID: "u1", Account: "alice", EngName: "Alice"} - inner := newFakeUserStore(alice) - c := NewCachedUserStore(inner, 10, time.Minute) - - users, err := c.FindUsersByAccounts(context.Background(), []string{"alice"}) - require.NoError(t, err) - require.Len(t, users, 1) - assert.Equal(t, alice, users[0]) - assert.Equal(t, 1, inner.callCount(), "miss should call inner") - assert.Equal(t, []string{"alice"}, inner.lastCall()) -} - -func TestCachedUserStore_HitServedFromCache(t *testing.T) { - alice := model.User{ID: "u1", Account: "alice", EngName: "Alice"} - inner := newFakeUserStore(alice) - c := NewCachedUserStore(inner, 10, time.Minute) - - _, _ = c.FindUsersByAccounts(context.Background(), []string{"alice"}) // prime - users, err := c.FindUsersByAccounts(context.Background(), []string{"alice"}) - require.NoError(t, err) - require.Len(t, users, 1) - assert.Equal(t, alice, users[0]) - assert.Equal(t, 1, inner.callCount(), "hit should not call inner") -} - -func TestCachedUserStore_PartialHitCallsInnerWithOnlyMissing(t *testing.T) { - alice := model.User{ID: "u1", Account: "alice"} - bob := model.User{ID: "u2", Account: "bob"} - inner := newFakeUserStore(alice, bob) - c := NewCachedUserStore(inner, 10, time.Minute) - - _, _ = c.FindUsersByAccounts(context.Background(), []string{"alice"}) // prime alice only - - users, err := c.FindUsersByAccounts(context.Background(), []string{"alice", "bob"}) - require.NoError(t, err) - require.Len(t, users, 2) - assert.Equal(t, 2, inner.callCount(), "partial hit still calls inner for misses") - assert.Equal(t, []string{"bob"}, inner.lastCall(), "inner called only with missing accounts") -} - -func TestCachedUserStore_EmptyInputReturnsNil(t *testing.T) { - inner := newFakeUserStore() - c := NewCachedUserStore(inner, 10, time.Minute) - - users, err := 
c.FindUsersByAccounts(context.Background(), nil) - require.NoError(t, err) - assert.Nil(t, users) - assert.Equal(t, 0, inner.callCount()) -} - -func TestCachedUserStore_MissingUserNotCached(t *testing.T) { - inner := newFakeUserStore() // no users registered - c := NewCachedUserStore(inner, 10, time.Minute) - - // First call: inner returns no users for "ghost". - users, err := c.FindUsersByAccounts(context.Background(), []string{"ghost"}) - require.NoError(t, err) - assert.Empty(t, users) - - // Add ghost later to simulate the user being created. - inner.byAccount["ghost"] = model.User{ID: "u-ghost", Account: "ghost"} - - // Second call: the negative result must NOT be cached — inner must be called again. - users2, err := c.FindUsersByAccounts(context.Background(), []string{"ghost"}) - require.NoError(t, err) - require.Len(t, users2, 1) - assert.Equal(t, 2, inner.callCount(), "missing accounts must not be cached as negatives") -} - -func TestCachedUserStore_InnerErrorPropagated(t *testing.T) { - innerErr := errors.New("boom") - inner := newFakeUserStore() - inner.err = innerErr - c := NewCachedUserStore(inner, 10, time.Minute) - - _, err := c.FindUsersByAccounts(context.Background(), []string{"alice"}) - require.Error(t, err) - assert.ErrorIs(t, err, innerErr, "inner error should be wrapped, not swallowed") -} - -func TestCachedUserStore_TTLExpiredReFetches(t *testing.T) { - alice := model.User{ID: "u1", Account: "alice"} - inner := newFakeUserStore(alice) - c := NewCachedUserStore(inner, 10, 1*time.Second) - - // Freeze "now" at a known value. - base := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) - c.now = func() time.Time { return base } - - _, err := c.FindUsersByAccounts(context.Background(), []string{"alice"}) - require.NoError(t, err) - assert.Equal(t, 1, inner.callCount()) - - // Advance past TTL. 
- c.now = func() time.Time { return base.Add(2 * time.Second) } - - _, err = c.FindUsersByAccounts(context.Background(), []string{"alice"}) - require.NoError(t, err) - assert.Equal(t, 2, inner.callCount(), "stale entry should force re-fetch") -} - -func TestCachedUserStore_LRUEvictionOnOverflow(t *testing.T) { - // maxSize=2: inserting 3 distinct accounts must evict the oldest. - a := model.User{ID: "u1", Account: "alice"} - b := model.User{ID: "u2", Account: "bob"} - c := model.User{ID: "u3", Account: "carol"} - inner := newFakeUserStore(a, b, c) - store := NewCachedUserStore(inner, 2, time.Minute) - - ctx := context.Background() - _, _ = store.FindUsersByAccounts(ctx, []string{"alice"}) - _, _ = store.FindUsersByAccounts(ctx, []string{"bob"}) - _, _ = store.FindUsersByAccounts(ctx, []string{"carol"}) // should evict alice - // alice should now be a miss again. - _, _ = store.FindUsersByAccounts(ctx, []string{"alice"}) - - // Inner calls: alice, bob, carol, alice → 4 total - assert.Equal(t, 4, inner.callCount(), "alice must be re-fetched after eviction") -} - -func TestCachedUserStore_AccessPromotesToMRU(t *testing.T) { - // maxSize=2: after priming alice + bob, accessing alice makes bob the - // LRU. Inserting carol should evict bob, not alice. - a := model.User{ID: "u1", Account: "alice"} - b := model.User{ID: "u2", Account: "bob"} - c := model.User{ID: "u3", Account: "carol"} - inner := newFakeUserStore(a, b, c) - store := NewCachedUserStore(inner, 2, time.Minute) - - ctx := context.Background() - _, _ = store.FindUsersByAccounts(ctx, []string{"alice"}) - _, _ = store.FindUsersByAccounts(ctx, []string{"bob"}) - // Touch alice to mark MRU. - _, _ = store.FindUsersByAccounts(ctx, []string{"alice"}) - // Insert carol: bob should be evicted. - _, _ = store.FindUsersByAccounts(ctx, []string{"carol"}) - - before := inner.callCount() - // alice is still cached. 
- _, _ = store.FindUsersByAccounts(ctx, []string{"alice"}) - assert.Equal(t, before, inner.callCount(), "alice should still be cached") - // bob is not. - _, _ = store.FindUsersByAccounts(ctx, []string{"bob"}) - assert.Equal(t, before+1, inner.callCount(), "bob should have been evicted") -} - -func TestCachedUserStore_ConcurrentSafe(t *testing.T) { - // Many goroutines hit overlapping account sets. No race, no panic. - const ( - goroutines = 32 - iterations = 200 - accounts = 50 - ) - users := make([]model.User, accounts) - for i := range users { - users[i] = model.User{ - ID: "u-" + strconv.Itoa(i), - Account: "acct-" + strconv.Itoa(i), - } - } - inner := newFakeUserStore(users...) - store := NewCachedUserStore(inner, 32, time.Minute) - - ctx := context.Background() - var wg sync.WaitGroup - wg.Add(goroutines) - for g := 0; g < goroutines; g++ { - go func(seed int) { - defer wg.Done() - for i := 0; i < iterations; i++ { - idx := (seed*iterations + i) % accounts - _, err := store.FindUsersByAccounts(ctx, []string{"acct-" + strconv.Itoa(idx)}) - require.NoError(t, err) - } - }(g) - } - wg.Wait() -} - -func TestCachedUserStore_FindUserByIDDelegates(t *testing.T) { - // Keyed on account in the fake; for this test reuse the account as the ID. - alice := model.User{ID: "alice", Account: "alice", EngName: "Alice"} - inner := newFakeUserStore(alice) - c := NewCachedUserStore(inner, 10, time.Minute) - - u, err := c.FindUserByID(context.Background(), "alice") - require.NoError(t, err) - require.NotNil(t, u) - assert.Equal(t, "alice", u.Account) - - _, err = c.FindUserByID(context.Background(), "ghost") - require.Error(t, err, "inner store's not-found error should propagate") -} - -func TestCachedUserStore_DedupesDuplicateAccounts(t *testing.T) { - alice := model.User{ID: "u1", Account: "alice"} - inner := newFakeUserStore(alice) - c := NewCachedUserStore(inner, 10, time.Minute) - - // Cold cache: both duplicates would otherwise hit the inner store. 
- // After dedup, inner sees alice exactly once and the return has one user. - users, err := c.FindUsersByAccounts(context.Background(), []string{"alice", "alice"}) - require.NoError(t, err) - assert.Len(t, users, 1) - assert.Equal(t, []string{"alice"}, inner.lastCall()) - - // Warm cache: still one user returned regardless of how many dupes are asked. - users2, err := c.FindUsersByAccounts(context.Background(), []string{"alice", "alice", "alice"}) - require.NoError(t, err) - assert.Len(t, users2, 1) - assert.Equal(t, 1, inner.callCount(), "warm-cache dupe lookup must not call inner") -} diff --git a/docs/client-api.md b/docs/client-api.md index cab68bda2..fb9455dad 100644 --- a/docs/client-api.md +++ b/docs/client-api.md @@ -928,7 +928,7 @@ See [Error envelope](#6-error-envelope-reference). Common errors: ##### Behaviour notes -- **Notification delivery:** `notification-worker` does **not** yet consult `muted` before sending. End-to-end mute behaviour is wired only as far as the persisted flag; honouring it in fan-out is a follow-up. +- **Notification delivery:** `notification-worker` respects `muted` flags when deciding whether to send mobile push notifications (see [Notification fan-out](#notification-fan-out-mobile-push-only) below). --- @@ -2518,6 +2518,28 @@ When validation fails, the gatekeeper publishes the error envelope to `chat.user { "code": "bad_request", "error": "content must not be empty" } ``` +#### Notification fan-out (mobile push only) + +`notification-worker` no longer publishes `chat.user.{account}.notification` +on core NATS. Mobile pushes are emitted on the server-only JetStream subject +`chat.server.notification.push.{siteID}.send` and forwarded by the internal +push-notification service. Desktop banners are computed client-side from the +broadcast-worker room-event stream — no server-side desktop publish exists. + +The worker filters recipients per message: + +- Skips the sender. +- Skips members with `muted: true` on their subscription. 
+- Skips members whose `historySharedSince` postdates the message (for a + thread-only reply the parent's `createdAt` is used instead). +- For a thread reply with `tshow: false`, skips non-followers who are not + mentioned. +- In rooms with more than `LARGE_ROOM_THRESHOLD` members (default 500), + pushes only to mentioned recipients (`@user`, `@all`); `@here` is parsed but does not trigger a push. +- Bots never receive a mobile push. +- Presence-busy / in-call recipients are not pushed; everyone else + (online, offline, away, missing) receives one. + --- ## 5. Room Encryption diff --git a/docs/nats-subject-naming.md b/docs/nats-subject-naming.md index 83b6f7e38..36d316a24 100644 --- a/docs/nats-subject-naming.md +++ b/docs/nats-subject-naming.md @@ -26,7 +26,7 @@ On connect, every client subscribes to `chat.user.{account}.>`. This single wild | Subject | Direction | Publisher | Purpose | |---------|-----------|-----------|---------| | `chat.user.{account}.stream.msg` | Server → Client | broadcast-worker | DM message delivery | -| `chat.user.{account}.notification` | Server → Client | notification-worker | Desktop banner notification (new message alert) | +| `chat.user.{account}.notification` | Server → Client | _(removed — see PUSH_NOTIFICATIONS stream below)_ | _(deprecated)_ | | `chat.user.{account}.event.subscription.update` | Server → Client | room-worker, inbox-worker | Room added/removed from user's list | | `chat.user.{account}.event.room.metadata.update` | Server → Client | room-worker | Room metadata changed (for rooms in sidebar) | | `chat.user.{account}.response.{requestID}` | Server → Client | various services | Response to a client request | @@ -89,9 +89,9 @@ When offline, clients miss messages on non-active sidebar rooms. To restore ment 2. **Subscription list response** (`chat.user.{account}.request.rooms.list`) — includes `mentionCountSinceLastSeen` per room, allowing the client to restore `@` badges without fetching message history for every sidebar room 3. 
**Mark as read** — when the user opens a room, the client sends a read-position update (advancing `lastSeenAt`); the server resets `mentionCountSinceLastSeen` to `0` for that user+room -#### Desktop Banner Notifications +#### Desktop Banner Notifications (Mobile Push) -notification-worker sends a `NotificationEvent` to `chat.user.{account}.notification` for immediate desktop banners (including mention notifications). This is an interrupt-style notification, separate from the persistent badge state above. +notification-worker publishes a `PushNotificationEvent` to `chat.server.notification.push.{siteID}.send` (captured by the `PUSH_NOTIFICATIONS_{siteID}` JetStream stream) for each eligible recipient. The push service consumes this stream and delivers the notification to the recipient's mobile device. The old per-user NATS core subject `chat.user.{account}.notification` is no longer used. #### Reconnect Badge Restoration @@ -181,6 +181,16 @@ Stream wildcard: `chat.user.*.request.room.*.{siteID}.member.>` Stream wildcard: `outbox.{siteID}.>` +### PUSH_NOTIFICATIONS Stream (`PUSH_NOTIFICATIONS_{siteID}`) + +| Subject Pattern | Publisher | Consumer | Purpose | +|-----------------|-----------|----------|---------| +| `chat.server.notification.push.{siteID}.send` | notification-worker | push service | Per-recipient mobile push event | + +Stream wildcard: `chat.server.notification.push.{siteID}.>` (wildcard accommodates future `.silent`, `.priority` siblings) + +This is a server-only, backend stream. Clients never interact with it. + ### INBOX Stream (`INBOX_{siteID}`) Sourced from remote sites' OUTBOX streams. Processed by `inbox-worker`. 
@@ -215,7 +225,7 @@ All client publishes — message sends, member invites, room CRUD requests, typi | `MsgHistory(account, roomID, siteID)` | `chat.user.{account}.request.room.{roomID}.{siteID}.msg.history` | | `SubscriptionUpdate(account)` | `chat.user.{account}.event.subscription.update` | | `RoomMetadataChanged(account)` | `chat.user.{account}.event.room.metadata.update` | -| `Notification(account)` | `chat.user.{account}.notification` | +| `Notification(account)` | `chat.user.{account}.notification` _(deprecated; use `PushNotification(siteID)` for mobile push)_ | | `RoomsCreate(account)` | `chat.user.{account}.request.rooms.create` | | `RoomsList(account)` | `chat.user.{account}.request.rooms.list` | | `RoomsGet(account, roomID)` | `chat.user.{account}.request.rooms.get.{roomID}` | @@ -273,9 +283,10 @@ Client A (sender) NATS Client B (rece | | | | notification-worker | | | | - | |--- pub: chat.user.B | - | | .notification -------->| - | | (desktop banner) | + | |--- pub: chat.server. | + | | notification.push. | + | | {siteID}.send | + | | (PUSH_NOTIFICATIONS stream) | | | | |--- pub: chat.user.A | | | .room.R1.typing -------->| | diff --git a/docs/notification-worker-downstream-contracts.md b/docs/notification-worker-downstream-contracts.md new file mode 100644 index 000000000..bc7e75499 --- /dev/null +++ b/docs/notification-worker-downstream-contracts.md @@ -0,0 +1,309 @@ +# notification-worker — Downstream Contracts + +This document specifies the contracts the `notification-worker` overhaul +([PR #237](https://github.com/hmchangw/chat/pull/237)) establishes for two +**internal-codebase** services (the push-notification service and the presence +service) plus the ops/IaC provisioning required to run the worker in +production. + +`notification-worker` is the **producer** for both contracts. It does not +implement either consumer. Until the consumers land, the worker runs in a +safe degraded mode (see each section). + +--- + +## 1. 
Push-notification service (mobile push delivery) + +**Status:** required for any mobile push to be delivered. Until the push +service consumes the stream, push events accumulate / are dropped per the +stream's retention policy — the worker publishes and moves on. + +### Transport + +| Property | Value | +|---|---| +| Stream | `PUSH_NOTIFICATIONS_{siteID}` | +| Bound subject filter | `chat.server.notification.push.{siteID}.>` | +| Publish subject (current leaf) | `chat.server.notification.push.{siteID}.send` | +| Namespace | `chat.server.*` — server-only; client JWTs have no subscribe permission | +| Delivery model | fire-and-forget async publish; durability via JetStream PubAck | +| Granularity | one event per **batch of up to `PUSH_RECIPIENT_BATCH_SIZE`** recipients (default `100`, configurable per deploy) | +| Payload encoding | JSON, **gzip-compressed**; consumers must read `Content-Encoding: gzip` header and decompress before `json.Unmarshal` | +| Stream storage compression | `S2` — transparent server-side, layered with gzip on top for inter-replica + on-disk savings | + +The `.send` leaf is the only current event type; the `.>` filter leaves room +for future siblings (`.silent`, `.priority`) without restructuring the stream. + +### Event schema + +`PushNotificationEvent` (JSON; `pkg/model/push.go`). 
The wire payload is gzip-compressed +(see § Payload decoding); the shape after decompression is: + +```json +{ + "id": "{messageId}-b{batchIndex}", + "accounts": ["alice", "bob", "carol"], + "title": "", + "body": "the message content", + "data": { + "roomId": "r123", + "messageId": "m456", + "type": "c", + "sender": { "account": "bob", "userId": "u-bob", "displayName": "Bob Chen 陳大寶" }, + "threadMessageId": "", + "fileName": "", + "fileType": "", + "parentRoomId": "", + "pushTime": "2026-05-28T00:00:00Z", + "alsoSendToChannel": false + }, + "roomId": "r123", + "timestamp": 1700000000000 +} +``` + +Field notes: + +- **`id`** = `{messageId}-b{batchIndex}` (zero-based). Also set as the `Nats-Msg-Id` header — see Dedup. `batchIndex` is stable across redeliveries because the worker sorts survivors lexicographically before chunking. +- **`accounts`** = recipient accounts in this batch, lexicographically sorted, capped by `PUSH_RECIPIENT_BATCH_SIZE` (default 100). The push service iterates this list, resolves device tokens per account, and is expected to use the provider's native multicast (e.g. FCM `send_each_for_multicast` — up to 500 tokens per call) so one batch becomes one outbound HTTP request. +- **`title`** is resolved by the worker so push-service needs no DB lookup. The rule mirrors the legacy implementation: `room.Name` if present, otherwise the sender's account (DM rooms have no name). Room metadata is served from an LRU+TTL cache (`pkg/roommetacache`) sized via `ROOM_META_CACHE_SIZE` / `ROOM_META_CACHE_TTL`. +- **`body`** is the raw message content, **untruncated**. The push service should truncate to the APNs/FCM payload limit (~4 KB total) before delivery. (Truncation/PII-scrubbing on the worker side is tracked as follow-up.) +- **`data.type`** is the short room type: `"c"` channel, `"d"` DM/botDM, `"p"` discussion. +- **`data.sender`** is a `Participant` carrying `account`, `userId`, and `displayName`. 
**`displayName` is pre-composed by `message-gatekeeper`** at canonical-message write time via `pkg/displayfmt.CombineWithFallback(engName, chineseName, account)` (same helper already used by `room-worker/sysmsg.go`, `room-service/store_mongo.go`, and reaction rendering — one source of truth for display formatting across the system). The composition happens once per message regardless of downstream consumer count and never on the push hot path; push-service renders `sender.displayName` verbatim. Empty `displayName` (legacy in-flight canonical messages predating the field) falls back to `sender.account` in `notification-worker`. `engName` / `chineseName` are deliberately not propagated on the push event since the composed string is the only render-time input. +- **`timestamp`** is event publish time (UnixMilli); **`data.pushTime`** is the RFC3339 domain send time. They are distinct fields. + +### Payload decoding + +The publisher sets `Content-Encoding: gzip` and `Content-Type: application/json` on +every event. Consumers must read the header and gunzip before `json.Unmarshal`. The +shared helper `pkg/natsutil.DecodePayload(*nats.Msg)` implements this — it returns +`msg.Data` verbatim for absent/`identity` encoding, gunzips for `gzip`, and errors +loudly on any other encoding to keep silent mis-parses out. + +### Payload size cap + +The wire payload (after gzip) is bounded by the broker's `max_payload`. The worker +reads `NATS_MAX_PAYLOAD_BYTES` (default `262144` = 256 KiB) and **rejects any +gzipped batch larger than the cap before publishing** — the emitter surfaces a +clear `exceeds NATS max_payload` error instead of letting the broker NACK with a +less informative one. The `PUSH_RECIPIENT_BATCH_SIZE=100` default leaves a wide +margin under 256 KiB for typical recipient/metadata sizes; the cap exists as a +last-resort guard against pathological events (huge bodies, oversized metadata). 
+ +Set `NATS_MAX_PAYLOAD_BYTES` to match your broker's configured `max_payload`. The +push service should decode with `natsutil.DecodePayloadWithLimit(msg, <NATS_MAX_PAYLOAD_BYTES>)` +(or rely on the default 256 KiB) so the gzip-bomb defense matches the producer's +commitment. + +### Routing predicate notes + +- **`@here` is not a push trigger.** The worker treats `@all` as the broad-mention + signal; `@here` is parsed but not acted upon, because the current frontend does + not render `@here` mentions. A large-room message containing only `@here` will + result in zero push events. +- **`@all`** still bypasses the large-room throttle and the thread-follower gate. + +### Schema departures from the legacy push payload + +The push service must read the new tag names (one coordinated cutover — there +is no dual old/new support, since `PUSH_NOTIFICATIONS_{siteID}` is a new stream +with no prior consumer): + +| Legacy | New | +|---|---| +| `rid` | `roomId` | +| `tmid` | `threadMessageId` | +| `prid` | `parentRoomId` | +| flat `chineseName` / `engName` | nested `sender` object (`Participant`) — push-service reads `sender.displayName` (pre-composed at message-gatekeeper via `pkg/displayfmt.CombineWithFallback`) and renders it verbatim. `sender.account` remains as the final fallback when `displayName` is empty (only possible for legacy in-flight messages predating the field). | +| (none) | `timestamp` (event-level UnixMilli) added | + +### Dedup + +Dedup here protects against **upstream re-emit only** — push-service uses +`MaxDeliver=1` and ack-first (see § Consumer guidance), so it never causes +redelivery itself. The case dedup covers: `notification-worker` NAKs a +canonical message (emit error after retries), JetStream redelivers the +canonical event, the worker re-runs fan-out and re-publishes push events +with the same content. + +The worker sets the JetStream `Nats-Msg-Id` header to `{messageId}-b{batchIndex}`. 
+Batches are content-stable across redeliveries because the worker sorts +survivors before chunking, so the same canonical message always produces the +same `Nats-Msg-Id` set and JetStream drops the duplicates at the stream. For +this to suppress duplicate pushes, the **stream's dedup window must be ≥ the +canonical consumer's redelivery horizon**: + +```text +dedup_window ≥ AckWait × MaxDeliver = 30s × 5 = 150s (defaults) +``` + +Set the `PUSH_NOTIFICATIONS_{siteID}` `Duplicates` window to a safe margin +above 150s (e.g. 5 min). If the window is shorter, a canonical-message +redelivery (after a worker NAK) can produce a duplicate push. + +### Consumer guidance + +**Delivery semantics: at-most-once.** Push-service MUST ack the JetStream +message **on receipt**, before any provider HTTP call, and MUST NOT NAK or +trigger redelivery on provider failure. Rationale: a duplicate push is +user-visible spam; a missed push on transient provider failure is invisible +and bounded by the per-recipient HTTP retry below. + +- Use a durable consumer named after the push service. +- **Ack first.** Call `msg.Ack()` immediately after the payload decodes + cleanly — before fanning out to FCM/APNs. Provider outcomes do not affect + ack. +- **Set `MaxDeliver=1`** on the durable consumer. There is no upstream retry + semantics worth preserving here; the stable `Nats-Msg-Id` already protects + against `notification-worker` re-emit on canonical redelivery (see § Dedup). +- **`AckWait` can be tight** (e.g. `5s`). Ack happens within milliseconds of + receipt because no I/O blocks it; the wider default just causes slow + shutdowns on stuck pods. +- **HTTP retry per recipient: up to 2 attempts** with exponential backoff + (e.g. `100ms`, `400ms`). On terminal failure, **log and drop** — no + bookkeeping, no DLQ, no provider-side state machine. 
A structured log line + with `account`, `provider`, `status_code`, `error`, `messageId`, `batchId` + is enough for ops triage; aggregate alarming should fire on **error rate**, + not individual misses. +- Treat each event as a fan-out unit: iterate `accounts`, resolve device + tokens, and prefer a single multicast HTTP request per batch over + per-recipient calls (FCM `send_each_for_multicast` accepts up to 500 tokens; + one batch = one HTTP request). +- A push for a bot account never arrives (the worker filters bots), so no + bot-device handling is needed. +- Decode the payload via `natsutil.DecodePayload(msg)` (or equivalent + gzip-aware decoder); never `json.Unmarshal(msg.Data, …)` directly. + +**Why no NAK / no MaxDeliver > 1**: the only failure modes that would benefit +from JetStream redelivery are (a) the push-service pod crashing before ack — +solved by acking immediately, and (b) the provider being down — best handled by +a per-recipient HTTP retry that's bounded in wall time, not by re-running the +entire push fan-out which would duplicate pushes for recipients that did +succeed on the first pass. + +--- + +## 2. Presence service (DND gating) + +**Status:** optional but recommended. The worker ships with +`PRESENCE_RPC_ENABLED=false` and a no-op snapshotter, so **every push-eligible +recipient is pushed regardless of presence** (fail-open). Implementing this RPC +enables busy/in-call (DND) suppression. Flip `PRESENCE_RPC_ENABLED=true` once +it's live. 
+ +### Transport + +| Property | Value | +|---|---| +| Subject | `chat.presence.{siteID}.request.snapshot` | +| Pattern | NATS request/reply | +| Cardinality | one request per canonical message (the worker chunks large account sets — see below) | + +### Request / reply schema + +`PresenceSnapshotRequest` → `PresenceSnapshotReply` (`pkg/model/presence.go`): + +```json +// request +{ "accounts": ["alice", "bob", "carol"] } + +// reply +{ + "presences": { + "alice": { "aggregatedStatus": "online" }, + "bob": { "aggregatedStatus": "busy" } + } +} +``` + +- **`aggregatedStatus`** is the single field the worker reads. The presence + service must **fold manual user overrides into this field** (the worker does + no override logic). One of: `online`, `offline`, `away`, `busy`, `in-call`. +- An account **absent from the reply map** is treated fail-open (pushed). +- On error, reply with the repo-standard `model.ErrorResponse` + (`{"error": "...", "code": "..."}`) via `natsutil.ReplyError`. The worker + detects this envelope, logs it, and fails open for that chunk. + +### Status → push decision (worker-side, for reference) + +| `aggregatedStatus` | Push? | Rationale | +|---|:--:|---| +| `online` | yes | multi-device — push fires alongside the client desktop banner | +| `offline` | yes | not connected — reach by push | +| `away` | yes | idle, not DND — fail-open | +| `busy` | **no** | Do-Not-Disturb | +| `in-call` | **no** | treated as DND (mirrors Teams in-meeting muting) | +| absent / RPC error | yes | fail-open — never drop on a presence gap | + +### Chunking / sizing + +For an `@all` to a very large room the survivor set can be thousands of +accounts. The worker splits the request across several concurrent RPCs at +`PRESENCE_BATCH_SIZE` (default 512) so each request/reply stays under the NATS +max message size, then merges replies. 
The presence service should size its +handler to answer a single ~512-account request comfortably; it does **not** +need to handle one giant request. + +The worker does **not** read the presence service's storage directly — the RPC +is the only coupling, so the presence service's Valkey/storage migration is +invisible to the worker. + +--- + +## 3. Ops / IaC provisioning + +Required before a production rollout: + +1. **Provision `PUSH_NOTIFICATIONS_{siteID}`** (the worker only bootstraps it + in dev via `BOOTSTRAP_STREAMS=true`; in prod `BOOTSTRAP_STREAMS=false` and + the worker only publishes). Set: + - Subjects: `chat.server.notification.push.{siteID}.>` + - `Duplicates` (dedup) window ≥ ~5 min (see §1 Dedup) + - Retention/limits per the push service's drain rate +2. **`subscriptions.roomType`** — already populated by `room-service`; the + worker reads it for routing. No action unless a site predates the field. +3. **`thread_subscriptions` `(parentMessageId, userAccount)` index** — the + worker ensures it idempotently at startup (bounded by + `INDEX_ENSURE_TIMEOUT`, default 2 min). On a large existing collection, + pre-create it so the first boot isn't slowed; otherwise no action. +4. 
**New env vars** (see `notification-worker/deploy/docker-compose.yml` for + dev values): + - `VALKEY_ADDRS` (**required**, comma-separated cluster seeds), `VALKEY_PASSWORD` + - `ROOMSUBCACHE_TTL` (default `5m`) — TTL for the Valkey room-member cache; no in-process L1 (per-pod memory bounded against very large rooms) + - `LARGE_ROOM_THRESHOLD` (default `500` — same knob as message-gatekeeper) + - `PUSH_RECIPIENT_BATCH_SIZE` (default `100` — recipients per push event; tune toward provider multicast caps) + - `NATS_MAX_PAYLOAD_BYTES` (default `262144` = 256 KiB — must match broker `max_payload`; see §1 Payload size cap) + - `ROOM_META_CACHE_SIZE` (default `10000`), `ROOM_META_CACHE_TTL` (default `2m`) — fronts `rooms` collection lookups for title resolution + - `PUSH_ASYNC_MAX_PENDING` (default `1024`) + + `message-gatekeeper` owns the sender display-name resolution; configure its + `USER_CACHE_SIZE` / `USER_CACHE_TTL` (defaults 10000 / 5m) there. + `notification-worker` does **no** users-collection lookups under this design. + - `INDEX_ENSURE_TIMEOUT` (default `2m`) + - `PRESENCE_RPC_ENABLED` (default `false`), `PRESENCE_BATCH_SIZE` (`512`), `PRESENCE_RPC_TIMEOUT` (`2s`) + +--- + +## 4. Optional — veto hook + +The worker exposes an in-process `Vetoer` (Stage 2, suppress-only) that ships +as `noopVetoer` (allows all). If the team has notification-suppression rules, +implement a real `Vetoer`: + +- Signature: `Allow(ctx, *model.Message, roomsubcache.Member) (bool, error)` +- It runs **once per recipient in-process** — any external data it needs must + be **batch-loaded once per message** before the per-recipient loop, never + fetched per recipient. +- On error the worker logs and fails open (allows). + +--- + +## Rollout sequencing (suggested) + +1. Land this PR; deploy the worker with `PRESENCE_RPC_ENABLED=false`. No pushes + are delivered yet (push service not consuming) — safe. +2. Provision the `PUSH_NOTIFICATIONS_{siteID}` stream (§3.1). +3. 
Ship the push service consumer (§1). Mobile push now flows; presence gating + is still fail-open (everyone eligible is pushed). +4. Ship the presence RPC handler (§2), then flip `PRESENCE_RPC_ENABLED=true`. + DND suppression now active. diff --git a/docs/superpowers/plans/2026-05-27-notification-worker-cache-and-mobile.md b/docs/superpowers/plans/2026-05-27-notification-worker-cache-and-mobile.md new file mode 100644 index 000000000..b130497e9 --- /dev/null +++ b/docs/superpowers/plans/2026-05-27-notification-worker-cache-and-mobile.md @@ -0,0 +1,2923 @@ +# Notification Worker — Cache, Routing & Mobile Push Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace the existing blanket fan-out in `notification-worker` with a cached, mention-gated, presence-aware mobile-push pipeline (per the 2026-05-22 spec). + +**Architecture:** Per canonical message, parse mentions + large-room flag once, load members through a Valkey-backed `roomsubcache` (with single-flight + a small in-process L1), apply ordered stages — Stage 1 exclusion filters (sender / mute / restricted / thread-non-follower), Stage 2 in-process hook veto, Stage 3 pure-CPU routing predicate, Stage 4 one bulk presence RPC — and emit one async JetStream push per surviving recipient on `chat.server.notification.push.{siteID}.send`. No desktop emit leg. + +**Tech Stack:** Go 1.25, NATS + JetStream (`nats.go`, raw `jetstream.New` for async publish), Valkey via `pkg/valkeyutil` + `pkg/roomsubcache`, MongoDB (`go.mongodb.org/mongo-driver/v2`), `golang.org/x/sync/singleflight`, `hashicorp/golang-lru/v2/expirable`, `stretchr/testify`. + +**Reference spec:** `docs/superpowers/specs/2026-05-22-notification-worker-cache-and-mobile-design.md`. 
+ +--- + +## File Structure + +The notification-worker keeps the flat `package main` repo convention. Existing files (`main.go`, `handler.go`, `bootstrap.go`, `integration_test.go`, etc.) are modified; the new responsibilities are split into focused files: + +| Path | Status | Responsibility | +|---|---|---| +| `pkg/roomsubcache/roomsubcache.go` | modify | Widen `Member` projection (IsBot, ChineseName, EngName, Muted, HistorySharedSince) | +| `pkg/roomsubcache/roomsubcache_test.go` | modify | JSON round-trip + omitempty assertions for new fields | +| `pkg/mention/mention.go` | modify | Add `MentionHere bool` to `ParseResult` | +| `pkg/mention/mention_test.go` | modify | Cases for `@here` | +| `pkg/subject/subject.go` | modify | Add `PushNotification`, `PresenceSnapshot`, `SubscriptionUpdateWildcard`, `ParseSubscriptionUpdateAccount` | +| `pkg/subject/subject_test.go` | modify | Tests for the new builders / parser | +| `pkg/model/push.go` | new | `PushNotificationEvent` + `PushNotificationData` | +| `pkg/model/presence.go` | new | `PresenceSnapshotRequest` / `PresenceSnapshotReply` / `Presence` | +| `pkg/model/model_test.go` | modify | round-trip tests for the new payload structs | +| `notification-worker/routing.go` | new | Pure routing predicate (Stage 3) | +| `notification-worker/routing_test.go` | new | Exhaustive table-driven tests | +| `notification-worker/members.go` | new | `cachedMemberLookup` (Valkey cache + Mongo loader + single-flight + L1 LRU) | +| `notification-worker/members_test.go` | new | Hit / miss-then-populate / cache-error / L1 hit / single-flight collapse | +| `notification-worker/threads.go` | new | Thread-follower lookup + `parentMessageId` index ensure | +| `notification-worker/threads_test.go` | new | Lookup happy + empty + error paths (against a fake collection) | +| `notification-worker/hook.go` | new | `Hook` interface + `noopHook` | +| `notification-worker/hook_test.go` | new | `noopHook` always allows | +| 
`notification-worker/presence.go` | new | `PresenceSource` interface, no-op default, bulk-RPC impl, status→push map | +| `notification-worker/presence_test.go` | new | Status table, chunking, fail-open behaviour | +| `notification-worker/emit.go` | new | `mobileEmitter` (async JS publish, dedup header, bounded in-flight) | +| `notification-worker/emit_test.go` | new | Subject + dedup header + payload assertions; drain semantics | +| `notification-worker/handler.go` | rewrite | Orchestrates Stages 1–4, emits push; defines all consumer interfaces | +| `notification-worker/handler_test.go` | rewrite | Table-driven per-stage tests (sender, mute, restricted, thread, hook, routing, presence, emit) | +| `notification-worker/main.go` | modify | Valkey wiring, raw JS for async push, config additions, pipeline assembly, EnsureIndexes, invalidator subscription, drain on shutdown | +| `notification-worker/integration_test.go` | modify | Real Valkey + Mongo cover the cache path and end-to-end mobile-push subject | +| `notification-worker/deploy/docker-compose.yml` | modify | Add `VALKEY_ADDRS`, `LARGE_ROOM_THRESHOLD` env, depend on valkey | +| `docs/client-api.md` | modify | Note mute/restricted exclusions; remove the legacy `notification` event description | + +--- + +## Task 1: Extend `roomsubcache.Member` projection (TDD) + +**Files:** +- Modify: `pkg/roomsubcache/roomsubcache.go:29-35` +- Modify: `pkg/roomsubcache/roomsubcache_test.go` + +- [ ] **Step 1: Write the failing tests** + +Append to `pkg/roomsubcache/roomsubcache_test.go`: + +```go +func TestMember_JSONRoundTrip_NewFields(t *testing.T) { + hss := int64(1700000000000) + in := Member{ + ID: "u1", + Account: "alice", + IsBot: true, + ChineseName: "張三", + EngName: "Alice", + Muted: true, + HistorySharedSince: &hss, + } + data, err := json.Marshal(in) + require.NoError(t, err) + + var out Member + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} + +func TestMember_OmitemptyOnZeroValues(t 
*testing.T) { + in := Member{ID: "u1", Account: "alice"} + data, err := json.Marshal(in) + require.NoError(t, err) + got := string(data) + + // Only id + account on the wire; no zero-valued booleans / strings / pointers. + assert.JSONEq(t, `{"id":"u1","account":"alice"}`, got) +} +``` + +Add imports if missing: + +```go +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=roomsubcache` +Expected: FAIL — `Member` does not have the new fields. + +- [ ] **Step 3: Widen the projection** + +Replace the `Member` type in `pkg/roomsubcache/roomsubcache.go` (line ~29): + +```go +// Member is the projection of model.Subscription that notification-worker's +// fan-out path actually needs. Fields beyond {ID, Account} drive routing +// (IsBot), exclusion (Muted, HistorySharedSince), and push payload rendering +// (ChineseName, EngName for the message-author Sender). All extra fields +// use omitempty so a plain member's blob stays {id, account}. +type Member struct { + ID string `json:"id"` + Account string `json:"account"` + IsBot bool `json:"isBot,omitempty"` + ChineseName string `json:"chineseName,omitempty"` + EngName string `json:"engName,omitempty"` + Muted bool `json:"muted,omitempty"` + HistorySharedSince *int64 `json:"historySharedSince,omitempty"` +} +``` + +Also update the package doc comment at the top: + +```go +// The cache stores the fan-out path's per-member input set — see Member. +// Entries are written with a caller-supplied TTL and may be eagerly +// invalidated via Invalidate; staleness is otherwise bounded by the TTL. +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=roomsubcache && make lint` +Expected: PASS, no lint errors. 
+ +- [ ] **Step 5: Commit** + +```bash +git add pkg/roomsubcache/roomsubcache.go pkg/roomsubcache/roomsubcache_test.go +git commit -m "feat(roomsubcache): widen Member projection for notification routing" +``` + +--- + +## Task 2: Add `@here` to `pkg/mention.ParseResult` (TDD) + +**Files:** +- Modify: `pkg/mention/mention.go` +- Modify: `pkg/mention/mention_test.go` + +- [ ] **Step 1: Write the failing tests** + +Append cases to the table in `TestParse`: + +```go +{name: "@here lowercase", content: "hey @here check this", accounts: nil, mentionAll: false, mentionHere: true}, +{name: "@Here mixed case", content: "@Here folks", accounts: nil, mentionAll: false, mentionHere: true}, +{name: "@all and @here", content: "@all then @here", accounts: nil, mentionAll: true, mentionHere: true}, +{name: "word@here not mention", content: "say hi here@all team", accounts: nil, mentionAll: false, mentionHere: false}, +``` + +Extend the case struct + assertions at the top of the test (do this once, edit existing cases to include the new field as zero-valued where appropriate): + +```go +tests := []struct { + name string + content string + accounts []string + mentionAll bool + mentionHere bool +}{ + // ... existing cases with mentionHere: false appended ... +} + +// In the t.Run body, add: +assert.Equal(t, tt.mentionHere, got.MentionHere, "MentionHere") +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=mention` +Expected: FAIL — `ParseResult` has no `MentionHere` field. 
+ +- [ ] **Step 3: Implement** + +In `pkg/mention/mention.go`, extend `ParseResult`: + +```go +type ParseResult struct { + Accounts []string + MentionAll bool + MentionHere bool +} +``` + +In `Parse`, replace the special-case for `"all"`: + +```go +switch account { +case "all": + result.MentionAll = true + continue +case "here": + result.MentionHere = true + continue +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=mention && make lint` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add pkg/mention/mention.go pkg/mention/mention_test.go +git commit -m "feat(mention): surface @here in ParseResult" +``` + +--- + +## Task 3: New subject builders (TDD) + +**Files:** +- Modify: `pkg/subject/subject.go` +- Modify: `pkg/subject/subject_test.go` + +- [ ] **Step 1: Write the failing tests** + +Append to `pkg/subject/subject_test.go`: + +```go +func TestPushNotification(t *testing.T) { + assert.Equal(t, + "chat.server.notification.push.site-a.send", + subject.PushNotification("site-a")) +} + +func TestPresenceSnapshot(t *testing.T) { + assert.Equal(t, + "chat.presence.site-a.request.snapshot", + subject.PresenceSnapshot("site-a")) +} + +func TestSubscriptionUpdateWildcard(t *testing.T) { + assert.Equal(t, + "chat.user.*.event.subscription.update", + subject.SubscriptionUpdateWildcard()) +} + +func TestParseSubscriptionUpdateAccount(t *testing.T) { + acct, ok := subject.ParseSubscriptionUpdateAccount("chat.user.alice.event.subscription.update") + assert.True(t, ok) + assert.Equal(t, "alice", acct) + + _, ok = subject.ParseSubscriptionUpdateAccount("chat.user.alice.event.room.update") + assert.False(t, ok) + + _, ok = subject.ParseSubscriptionUpdateAccount("chat.user.*.event.subscription.update") + assert.False(t, ok) // wildcard token rejected +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=subject` +Expected: FAIL — undefined references. 
+ +- [ ] **Step 3: Implement** + +Append to `pkg/subject/subject.go` (group with the existing builders): + +```go +// PushNotification is the single per-recipient mobile-push subject the +// notification-worker publishes to. Lives under chat.server.* so client +// JWTs cannot subscribe. Bound (by ops/IaC) to the PUSH_NOTIFICATIONS_{siteID} +// stream via filter "chat.server.notification.push.{siteID}.>" so additional +// leaves (e.g. .silent, .priority) can be added without restructuring. +func PushNotification(siteID string) string { + return fmt.Sprintf("chat.server.notification.push.%s.send", siteID) +} + +// PresenceSnapshot is the bulk presence RPC subject — one request per +// canonical message carrying the survivor account list, one reply with +// each account's aggregated status. +func PresenceSnapshot(siteID string) string { + return fmt.Sprintf("chat.presence.%s.request.snapshot", siteID) +} + +// SubscriptionUpdateWildcard matches every subscription.update fanout +// (chat.user.*.event.subscription.update). Used by notification-worker for +// eager cache invalidation. +func SubscriptionUpdateWildcard() string { + return "chat.user.*.event.subscription.update" +} + +// ParseSubscriptionUpdateAccount extracts the account token from a concrete +// subscription.update subject. Returns ok=false on wildcard or malformed +// input. +func ParseSubscriptionUpdateAccount(s string) (account string, ok bool) { + parts := strings.Split(s, ".") + if len(parts) != 6 { + return "", false + } + if parts[0] != "chat" || parts[1] != "user" || parts[3] != "event" || + parts[4] != "subscription" || parts[5] != "update" { + return "", false + } + if !isValidAccountToken(parts[2]) { + return "", false + } + return parts[2], true +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=subject && make lint` +Expected: PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add pkg/subject/subject.go pkg/subject/subject_test.go +git commit -m "feat(subject): add push, presence-snapshot, and subscription-update wildcard builders" +``` + +--- + +## Task 4: New payload models — push + presence (TDD) + +**Files:** +- Create: `pkg/model/push.go` +- Create: `pkg/model/presence.go` +- Modify: `pkg/model/model_test.go` + +- [ ] **Step 1: Write the failing tests** + +Append to `pkg/model/model_test.go` (use the existing `roundTrip` helper pattern; if it lives in another file, follow that file's convention): + +```go +func TestPushNotificationEvent_RoundTrip(t *testing.T) { + in := model.PushNotificationEvent{ + ID: "m1-alice", + Account: "alice", + Title: "general", + Body: "hello", + RoomID: "r1", + Data: model.PushNotificationData{ + RoomID: "r1", + MessageID: "m1", + Type: "c", + Sender: &model.Participant{Account: "bob", ChineseName: "張三", EngName: "Bob"}, + PushTime: "2026-05-27T00:00:00Z", + }, + Timestamp: 1700000000000, + } + data, err := json.Marshal(in) + require.NoError(t, err) + var out model.PushNotificationEvent + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} + +func TestPresenceSnapshot_RoundTrip(t *testing.T) { + in := model.PresenceSnapshotReply{ + Presences: map[string]model.Presence{ + "alice": {AggregatedStatus: "online"}, + "bob": {AggregatedStatus: "busy"}, + }, + } + data, err := json.Marshal(in) + require.NoError(t, err) + var out model.PresenceSnapshotReply + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=model` +Expected: FAIL — undefined types. + +- [ ] **Step 3: Implement `pkg/model/push.go`** + +```go +package model + +// PushNotificationEvent is the per-recipient envelope notification-worker +// hands off to the internal push-notification service via the +// PUSH_NOTIFICATIONS_{siteID} stream. 
ID is "{messageID}-{account}" — also +// used as the Nats-Msg-Id for JetStream dedup so a same-message redelivery +// on MESSAGES_CANONICAL does not produce duplicate pushes. +type PushNotificationEvent struct { + ID string `json:"id" bson:"id"` + Account string `json:"account" bson:"account"` + Title string `json:"title" bson:"title"` + Body string `json:"body" bson:"body"` + Data PushNotificationData `json:"data" bson:"data"` + RoomID string `json:"roomId" bson:"roomId"` + Timestamp int64 `json:"timestamp" bson:"timestamp"` +} + +// PushNotificationData mirrors the legacy push-service payload with two +// repo-convention departures: cryptic tags (rid/tmid/prid) spelled out to +// camelCase, and the flat chineseName/engName fields collapsed into a +// *Participant Sender (matches ClientMessage.Sender). +type PushNotificationData struct { + RoomID string `json:"roomId" bson:"roomId"` + MessageID string `json:"messageId" bson:"messageId"` + Type string `json:"type" bson:"type"` + Sender *Participant `json:"sender,omitempty" bson:"sender,omitempty"` + ThreadMessageID string `json:"threadMessageId,omitempty" bson:"threadMessageId,omitempty"` + FileName string `json:"fileName,omitempty" bson:"fileName,omitempty"` + FileType string `json:"fileType,omitempty" bson:"fileType,omitempty"` + ParentRoomID string `json:"parentRoomId,omitempty" bson:"parentRoomId,omitempty"` + PushTime string `json:"pushTime" bson:"pushTime"` + AlsoSendToChannel bool `json:"alsoSendToChannel,omitempty" bson:"alsoSendToChannel,omitempty"` +} +``` + +- [ ] **Step 4: Implement `pkg/model/presence.go`** + +```go +package model + +// PresenceSnapshotRequest is the request payload of the bulk presence RPC. +// One request per canonical message carrying the push-eligible account set. +type PresenceSnapshotRequest struct { + Accounts []string `json:"accounts" bson:"accounts"` +} + +// PresenceSnapshotReply is the reply payload. 
Accounts absent from the map +// (or any RPC error) are treated fail-open by notification-worker — the +// push fires. +type PresenceSnapshotReply struct { + Presences map[string]Presence `json:"presences" bson:"presences"` +} + +// Presence is a single account's aggregated status. The presence service +// folds manual user overrides (e.g. busy) into AggregatedStatus so it is +// the sole field routing needs. +// +// Known values: "online", "offline", "away", "busy", "in-call". +// Only "busy" and "in-call" suppress the push (DND). +type Presence struct { + AggregatedStatus string `json:"aggregatedStatus" bson:"aggregatedStatus"` +} +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `make test SERVICE=model && make lint` +Expected: PASS. + +- [ ] **Step 6: Commit** + +```bash +git add pkg/model/push.go pkg/model/presence.go pkg/model/model_test.go +git commit -m "feat(model): add push notification + bulk presence RPC payloads" +``` + +--- + +## Task 5: Routing predicate (TDD) + +**Files:** +- Create: `notification-worker/routing.go` +- Create: `notification-worker/routing_test.go` + +- [ ] **Step 1: Write the failing tests** + +Create `notification-worker/routing_test.go`: + +```go +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +func TestEligibleForPush(t *testing.T) { + tests := []struct { + name string + member roomsubcache.Member + roomType model.RoomType + isLarge bool + mentioned bool + want bool + }{ + {name: "dm always", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeDM, want: true}, + {name: "botdm always", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeBotDM, want: true}, + {name: "small channel non-mention", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: false, mentioned: false, want: true}, + {name: "small channel mention", member: 
roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: false, mentioned: true, want: true}, + {name: "large channel non-mention dropped", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: true, mentioned: false, want: false}, + {name: "large channel mention pushed", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: true, mentioned: true, want: true}, + {name: "bot never", member: roomsubcache.Member{Account: "bot", IsBot: true}, roomType: model.RoomTypeDM, want: false}, + {name: "bot in mention dropped", member: roomsubcache.Member{Account: "bot", IsBot: true}, roomType: model.RoomTypeChannel, mentioned: true, want: false}, + {name: "discussion small non-mention", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeDiscussion, want: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := EligibleForPush(tt.member, tt.roomType, tt.isLarge, tt.mentioned) + assert.Equal(t, tt.want, got) + }) + } +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — `EligibleForPush` undefined. + +- [ ] **Step 3: Implement `notification-worker/routing.go`** + +```go +package main + +import ( + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// EligibleForPush is Stage 3 of the fan-out pipeline. Pure CPU — no I/O, +// no dependencies — so it is exhaustively unit-testable. The recipient is +// eligible when: (a) the room is a DM/botDM, OR mentioned, OR not large; +// AND (b) the recipient is not a bot. A "large" room is one whose member +// count exceeds LARGE_ROOM_THRESHOLD (computed once per message in the +// handler). 
+func EligibleForPush(m roomsubcache.Member, roomType model.RoomType, isLargeRoom, mentioned bool) bool { + if m.IsBot { + return false + } + if isDirect(roomType) { + return true + } + if mentioned { + return true + } + return !isLargeRoom +} + +func isDirect(t model.RoomType) bool { + return t == model.RoomTypeDM || t == model.RoomTypeBotDM +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for `TestEligibleForPush` (other tests in the package may still fail at this point — that's fine). + +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/routing.go notification-worker/routing_test.go +git commit -m "feat(notification-worker): pure routing predicate (Stage 3)" +``` + +--- + +## Task 6: Hook interface + no-op (TDD) + +**Files:** +- Create: `notification-worker/hook.go` +- Create: `notification-worker/hook_test.go` + +- [ ] **Step 1: Write the failing test** + +```go +package main + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +func TestNoopHook_AlwaysAllows(t *testing.T) { + h := noopHook{} + allow, err := h.Allow(context.Background(), &model.Message{}, roomsubcache.Member{Account: "a"}) + assert.NoError(t, err) + assert.True(t, allow) +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — `noopHook` undefined. + +- [ ] **Step 3: Implement `notification-worker/hook.go`** + +```go +package main + +import ( + "context" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// Hook is the Stage-2 in-process suppress-only veto. Allow returns true to +// keep the recipient; false to drop. It must never perform a per-recipient +// external call — any data the real impl needs must be batch-loaded once +// per message by Handler before the per-recipient loop. 
+// +// Errors are treated fail-open by the handler (logged + allow), so a hook +// outage never silently drops notifications. +type Hook interface { + Allow(ctx context.Context, msg *model.Message, member roomsubcache.Member) (bool, error) +} + +type noopHook struct{} + +func (noopHook) Allow(context.Context, *model.Message, roomsubcache.Member) (bool, error) { + return true, nil +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for `TestNoopHook_AlwaysAllows`. + +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/hook.go notification-worker/hook_test.go +git commit -m "feat(notification-worker): hook interface + no-op default" +``` + +--- + +## Task 7: Presence source — no-op + bulk RPC + status map (TDD) + +**Files:** +- Create: `notification-worker/presence.go` +- Create: `notification-worker/presence_test.go` + +- [ ] **Step 1: Write the failing tests** + +```go +package main + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/model" +) + +func TestNoopPresence_EmptySnapshot(t *testing.T) { + p := noopPresenceSource{} + snap, err := p.Snapshot(context.Background(), []string{"alice", "bob"}) + require.NoError(t, err) + assert.Empty(t, snap) +} + +func TestShouldPush(t *testing.T) { + tests := []struct { + status string + want bool + }{ + {"online", true}, + {"offline", true}, + {"away", true}, + {"busy", false}, + {"in-call", false}, + {"", true}, // missing → fail-open + {"unknown", true}, // unknown → fail-open + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + assert.Equal(t, tt.want, shouldPush(model.Presence{AggregatedStatus: tt.status})) + }) + } +} + +// Stub requester implementing the presenceRequester interface so we can +// drive bulkPresence without a real NATS connection. 
+type stubRequester struct { + calls int + gotReqs []model.PresenceSnapshotRequest + reply func(req model.PresenceSnapshotRequest) (model.PresenceSnapshotReply, error) +} + +func (s *stubRequester) Request(_ context.Context, _ string, data []byte, _ time.Duration) (*nats.Msg, error) { + s.calls++ + var req model.PresenceSnapshotRequest + if err := json.Unmarshal(data, &req); err != nil { + return nil, err + } + s.gotReqs = append(s.gotReqs, req) + reply, err := s.reply(req) + if err != nil { + return nil, err + } + out, _ := json.Marshal(reply) + return &nats.Msg{Data: out}, nil +} + +func TestBulkPresence_Chunks(t *testing.T) { + accounts := make([]string, 1500) + for i := range accounts { + accounts[i] = "u" + } + // Distinct accounts so the map merge is observable. + for i := range accounts { + accounts[i] = string(rune('a'+i%26)) + "-" + string(rune('a'+i/26%26)) + } + stub := &stubRequester{reply: func(req model.PresenceSnapshotRequest) (model.PresenceSnapshotReply, error) { + out := model.PresenceSnapshotReply{Presences: map[string]model.Presence{}} + for _, a := range req.Accounts { + out.Presences[a] = model.Presence{AggregatedStatus: "online"} + } + return out, nil + }} + + src := newBulkPresenceSource(stub, "site-a", 500, time.Second) + got, err := src.Snapshot(context.Background(), accounts) + require.NoError(t, err) + assert.Equal(t, 3, stub.calls, "expect ceil(1500/500) chunks") + assert.Len(t, got, len(uniqueStrings(accounts))) +} + +func TestBulkPresence_FailOpenOnError(t *testing.T) { + stub := &stubRequester{reply: func(model.PresenceSnapshotRequest) (model.PresenceSnapshotReply, error) { + return model.PresenceSnapshotReply{}, errors.New("nats: timeout") + }} + src := newBulkPresenceSource(stub, "site-a", 100, 50*time.Millisecond) + got, err := src.Snapshot(context.Background(), []string{"a", "b"}) + require.NoError(t, err) // fail-open: error is swallowed and snapshot is empty + assert.Empty(t, got) +} + +func uniqueStrings(in []string) 
map[string]struct{} { + out := map[string]struct{}{} + for _, s := range in { + out[s] = struct{}{} + } + return out +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — types undefined. + +- [ ] **Step 3: Implement `notification-worker/presence.go`** + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/nats-io/nats.go" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/subject" +) + +// PresenceSource is the Stage-4 dependency. Snapshot returns presence for +// each push-eligible account in one batched read (potentially split across +// several bulk RPCs for huge rooms). Errors are swallowed and surfaced as +// an empty snapshot — every recipient then fails open to a push. +type PresenceSource interface { + Snapshot(ctx context.Context, accounts []string) (map[string]model.Presence, error) +} + +// noopPresenceSource ships when the bulk presence RPC handler is not yet +// available on the presence service (see spec Open Question B). An empty +// snapshot makes every push-eligible recipient receive a push. +type noopPresenceSource struct{} + +func (noopPresenceSource) Snapshot(context.Context, []string) (map[string]model.Presence, error) { + return map[string]model.Presence{}, nil +} + +// presenceRequester is the minimal NATS surface bulkPresenceSource depends +// on — kept narrow so tests can substitute without a real connection. 
+type presenceRequester interface { + Request(ctx context.Context, subj string, data []byte, timeout time.Duration) (*nats.Msg, error) +} + +type bulkPresenceSource struct { + req presenceRequester + siteID string + batchSize int + timeout time.Duration +} + +func newBulkPresenceSource(req presenceRequester, siteID string, batchSize int, timeout time.Duration) *bulkPresenceSource { + if batchSize <= 0 { + batchSize = 512 + } + if timeout <= 0 { + timeout = 2 * time.Second + } + return &bulkPresenceSource{req: req, siteID: siteID, batchSize: batchSize, timeout: timeout} +} + +func (b *bulkPresenceSource) Snapshot(ctx context.Context, accounts []string) (map[string]model.Presence, error) { + if len(accounts) == 0 { + return map[string]model.Presence{}, nil + } + subj := subject.PresenceSnapshot(b.siteID) + chunks := chunkStrings(accounts, b.batchSize) + + var ( + mu sync.Mutex + out = make(map[string]model.Presence, len(accounts)) + wg sync.WaitGroup + ) + for _, ch := range chunks { + ch := ch + wg.Add(1) + go func() { + defer wg.Done() + data, err := json.Marshal(model.PresenceSnapshotRequest{Accounts: ch}) + if err != nil { + slog.Warn("presence marshal failed", "error", err) + return + } + msg, err := b.req.Request(ctx, subj, data, b.timeout) + if err != nil { + slog.Warn("presence rpc failed", "error", err, "chunk", len(ch)) + return + } + var reply model.PresenceSnapshotReply + if err := json.Unmarshal(msg.Data, &reply); err != nil { + slog.Warn("presence unmarshal failed", "error", err) + return + } + mu.Lock() + for k, v := range reply.Presences { + out[k] = v + } + mu.Unlock() + }() + } + wg.Wait() + return out, nil +} + +func chunkStrings(in []string, size int) [][]string { + if size <= 0 || len(in) <= size { + return [][]string{in} + } + out := make([][]string, 0, (len(in)+size-1)/size) + for i := 0; i < len(in); i += size { + end := i + size + if end > len(in) { + end = len(in) + } + out = append(out, in[i:end]) + } + return out +} + +// shouldPush maps 
an aggregated presence status to a push-or-not decision. +// Fail-open on unknown / missing — never drop a notification on a presence +// gap. +func shouldPush(p model.Presence) bool { + switch p.AggregatedStatus { + case "busy", "in-call": + return false + default: + return true + } +} + +// natsPresenceRequester adapts the production NATS connection to the +// presenceRequester interface. +type natsPresenceRequester struct { + nc *nats.Conn +} + +func (n *natsPresenceRequester) Request(ctx context.Context, subj string, data []byte, timeout time.Duration) (*nats.Msg, error) { + rctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + msg, err := n.nc.RequestWithContext(rctx, subj, data) + if err != nil { + return nil, fmt.Errorf("presence request: %w", err) + } + return msg, nil +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for the new presence tests. + +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/presence.go notification-worker/presence_test.go +git commit -m "feat(notification-worker): bulk presence RPC + no-op default + status table" +``` + +--- + +## Task 8: Cached member lookup with single-flight + L1 (TDD) + +**Files:** +- Create: `notification-worker/members.go` +- Create: `notification-worker/members_test.go` + +- [ ] **Step 1: Write the failing tests** + +```go +package main + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +// fakeCache implements roomsubcache.Cache in memory. 
+type fakeCache struct { + mu sync.Mutex + data map[string][]roomsubcache.Member +} + +func newFakeCache() *fakeCache { return &fakeCache{data: map[string][]roomsubcache.Member{}} } + +func (f *fakeCache) Get(_ context.Context, roomID string) ([]roomsubcache.Member, error) { + f.mu.Lock() + defer f.mu.Unlock() + v, ok := f.data[roomID] + if !ok { + return nil, valkeyutil.ErrCacheMiss + } + return v, nil +} +func (f *fakeCache) Set(_ context.Context, roomID string, members []roomsubcache.Member, _ time.Duration) error { + f.mu.Lock() + defer f.mu.Unlock() + cp := make([]roomsubcache.Member, len(members)) + copy(cp, members) + f.data[roomID] = cp + return nil +} +func (f *fakeCache) Invalidate(_ context.Context, roomID string) error { + f.mu.Lock() + defer f.mu.Unlock() + delete(f.data, roomID) + return nil +} + +// fakeLoader counts loader invocations. +type fakeLoader struct { + calls atomic.Int32 + out []roomsubcache.Member + err error + delay time.Duration +} + +func (f *fakeLoader) Load(_ context.Context, _ string) ([]roomsubcache.Member, error) { + f.calls.Add(1) + if f.delay > 0 { + time.Sleep(f.delay) + } + return f.out, f.err +} + +func TestCachedMemberLookup_HitFromValkey(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + _ = cache.Set(context.Background(), "r1", loader.out, time.Minute) + + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 0, 0) + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, loader.out, got) + assert.Equal(t, int32(0), loader.calls.Load(), "loader must not be called on hit") +} + +func TestCachedMemberLookup_MissThenPopulate(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 0, 0) + + got, err := lookup.GetMembers(context.Background(), "r1") + 
require.NoError(t, err) + assert.Equal(t, loader.out, got) + + // Second call hits the cache. + _, _ = lookup.GetMembers(context.Background(), "r1") + assert.Equal(t, int32(1), loader.calls.Load()) +} + +func TestCachedMemberLookup_CacheErrorFallsThrough(t *testing.T) { + cache := &erroringCache{err: errors.New("valkey down")} + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 0, 0) + + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, loader.out, got) + assert.Equal(t, int32(1), loader.calls.Load()) +} + +func TestCachedMemberLookup_SingleFlightCollapsesMisses(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{ + out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}, + delay: 50 * time.Millisecond, + } + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 0, 0) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = lookup.GetMembers(context.Background(), "r1") + }() + } + wg.Wait() + assert.Equal(t, int32(1), loader.calls.Load(), "single-flight collapses concurrent misses") +} + +func TestCachedMemberLookup_L1ServesRepeats(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + // L1 size 10, TTL 5s + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 10, 5*time.Second) + + for i := 0; i < 50; i++ { + _, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + } + // First fetch populates both Valkey and L1; subsequent calls hit L1. 
+ assert.LessOrEqual(t, loader.calls.Load(), int32(1)) +} + +func TestCachedMemberLookup_InvalidateDropsL1(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 10, time.Minute) + + _, _ = lookup.GetMembers(context.Background(), "r1") + lookup.Invalidate(context.Background(), "r1") + loader.out = []roomsubcache.Member{{ID: "u2", Account: "bob"}} + got, _ := lookup.GetMembers(context.Background(), "r1") + + assert.Equal(t, loader.out, got, "after Invalidate the next read must reload") +} + +type erroringCache struct{ err error } + +func (e *erroringCache) Get(context.Context, string) ([]roomsubcache.Member, error) { + return nil, e.err +} +func (e *erroringCache) Set(context.Context, string, []roomsubcache.Member, time.Duration) error { + return nil +} +func (e *erroringCache) Invalidate(context.Context, string) error { return nil } +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — `cachedMemberLookup` undefined. + +- [ ] **Step 3: Implement `notification-worker/members.go`** + +```go +package main + +import ( + "context" + "errors" + "log/slog" + "time" + + lru "github.com/hashicorp/golang-lru/v2/expirable" + "golang.org/x/sync/singleflight" + + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +// memberLoader reads the canonical (Mongo) member list for a room. The +// closure shape decouples cachedMemberLookup from the concrete +// mongoMemberLookup so tests can substitute trivially. +type memberLoader func(ctx context.Context, roomID string) ([]roomsubcache.Member, error) + +// cachedMemberLookup is notification-worker's MemberLookup implementation. +// Order of resolution per call: +// 1. In-process L1 LRU (decoded slices, short TTL) +// 2. Valkey via roomsubcache (multi-MB blob, JSON decode) +// 3. 
Mongo via the loader (cold start / TTL expiry) +// Single-flight guards stages 2→3 so a TTL-expiry stampede on a hot room +// collapses to one query. +type cachedMemberLookup struct { + cache roomsubcache.Cache + load memberLoader + ttl time.Duration + sf singleflight.Group + l1 *lru.LRU[string, []roomsubcache.Member] +} + +// newCachedMemberLookup wires the lookup. l1Size <= 0 disables the L1. +func newCachedMemberLookup(cache roomsubcache.Cache, load memberLoader, ttl time.Duration, l1Size int, l1TTL time.Duration) *cachedMemberLookup { + c := &cachedMemberLookup{cache: cache, load: load, ttl: ttl} + if l1Size > 0 { + c.l1 = lru.NewLRU[string, []roomsubcache.Member](l1Size, nil, l1TTL) + } + return c +} + +// GetMembers returns the member list for roomID, populating Valkey + L1 on +// a miss. The returned slice may be shared with the L1 cache and with +// concurrent callers, so it must be treated as read-only — callers must not +// mutate it. +func (c *cachedMemberLookup) GetMembers(ctx context.Context, roomID string) ([]roomsubcache.Member, error) { + if c.l1 != nil { + if v, ok := c.l1.Get(roomID); ok { + return v, nil + } + } + members, err, _ := c.sf.Do(roomID, func() (any, error) { + if c.l1 != nil { + if v, ok := c.l1.Get(roomID); ok { + return v, nil + } + } + got, err := c.cache.Get(ctx, roomID) + if err == nil { + c.populateL1(roomID, got) + return got, nil + } + if !errors.Is(err, valkeyutil.ErrCacheMiss) { + slog.Warn("roomsubcache get failed, falling back to mongo", "error", err, "roomId", roomID) + } + loaded, lerr := c.load(ctx, roomID) + if lerr != nil { + return nil, lerr + } + if setErr := c.cache.Set(ctx, roomID, loaded, c.ttl); setErr != nil { + slog.Warn("roomsubcache set failed", "error", setErr, "roomId", roomID) + } + c.populateL1(roomID, loaded) + return loaded, nil + }) + if err != nil { + return nil, err + } + return members.([]roomsubcache.Member), nil +} + +// Invalidate drops the room from both the L1 and the Valkey cache. Called +// by the subscription-update fan-out subscriber on every membership change.
+func (c *cachedMemberLookup) Invalidate(ctx context.Context, roomID string) { + if c.l1 != nil { + c.l1.Remove(roomID) + } + if err := c.cache.Invalidate(ctx, roomID); err != nil { + slog.Warn("roomsubcache invalidate failed", "error", err, "roomId", roomID) + } +} + +func (c *cachedMemberLookup) populateL1(roomID string, members []roomsubcache.Member) { + if c.l1 == nil { + return + } + c.l1.Add(roomID, members) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for `TestCachedMemberLookup_*`. + +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/members.go notification-worker/members_test.go +git commit -m "feat(notification-worker): valkey-backed member lookup with single-flight and L1" +``` + +--- + +## Task 9: Thread-follower lookup (TDD) + +**Files:** +- Create: `notification-worker/threads.go` +- Create: `notification-worker/threads_test.go` + +- [ ] **Step 1: Write the failing test** + +```go +package main + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type stubThreadLookup struct { + out []string + err error +} + +func (s *stubThreadLookup) followers(_ context.Context, _ string) (map[string]struct{}, error) { + if s.err != nil { + return nil, s.err + } + set := make(map[string]struct{}, len(s.out)) + for _, a := range s.out { + set[a] = struct{}{} + } + return set, nil +} + +func TestThreadFollowers_Resolve(t *testing.T) { + s := &stubThreadLookup{out: []string{"alice", "bob"}} + got, err := s.followers(context.Background(), "parent-1") + require.NoError(t, err) + assert.Contains(t, got, "alice") + assert.Contains(t, got, "bob") + assert.NotContains(t, got, "carol") +} + +func TestThreadFollowers_PropagatesError(t *testing.T) { + s := &stubThreadLookup{err: errors.New("mongo down")} + _, err := s.followers(context.Background(), "parent-1") + assert.Error(t, err) +} +``` + +- [ ] **Step 
2: Run test to verify it fails** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — types missing. + +- [ ] **Step 3: Implement `notification-worker/threads.go`** + +```go +package main + +import ( + "context" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// ThreadFollowers returns the set of accounts subscribed to the thread +// rooted at parentMessageID. Backed by an indexed read on +// thread_subscriptions (parentMessageId, userAccount). Empty set on no +// followers. +type ThreadFollowers interface { + Followers(ctx context.Context, parentMessageID string) (map[string]struct{}, error) +} + +type mongoThreadFollowers struct { + col *mongo.Collection +} + +func newMongoThreadFollowers(col *mongo.Collection) *mongoThreadFollowers { + return &mongoThreadFollowers{col: col} +} + +func (m *mongoThreadFollowers) Followers(ctx context.Context, parentMessageID string) (map[string]struct{}, error) { + if parentMessageID == "" { + return map[string]struct{}{}, nil + } + opts := options.Find().SetProjection(bson.M{"userAccount": 1, "_id": 0}) + cur, err := m.col.Find(ctx, bson.M{"parentMessageId": parentMessageID}, opts) + if err != nil { + return nil, fmt.Errorf("find thread followers: %w", err) + } + defer cur.Close(ctx) + + out := map[string]struct{}{} + for cur.Next(ctx) { + var r struct { + UserAccount string `bson:"userAccount"` + } + if err := cur.Decode(&r); err != nil { + return nil, fmt.Errorf("decode thread subscription: %w", err) + } + if r.UserAccount != "" { + out[r.UserAccount] = struct{}{} + } + } + if err := cur.Err(); err != nil { + return nil, fmt.Errorf("iterate thread followers: %w", err) + } + return out, nil +} + +// EnsureThreadSubscriptionIndex creates the (parentMessageId, userAccount) +// index notification-worker reads. Idempotent — safe to call on every +// startup. 
The room-service also ensures this in its own EnsureIndexes; +// duplicating it here keeps notification-worker self-sufficient against a +// fresh database. +func EnsureThreadSubscriptionIndex(ctx context.Context, col *mongo.Collection) error { + _, err := col.Indexes().CreateOne(ctx, mongo.IndexModel{ + Keys: bson.D{{Key: "parentMessageId", Value: 1}, {Key: "userAccount", Value: 1}}, + }) + if err != nil { + return fmt.Errorf("ensure thread_subscriptions (parentMessageId, userAccount) index: %w", err) + } + return nil +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS. + +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/threads.go notification-worker/threads_test.go +git commit -m "feat(notification-worker): thread-follower lookup + index ensure" +``` + +--- + +## Task 10: Mobile emitter — async JS publish + dedup header (TDD) + +**Files:** +- Create: `notification-worker/emit.go` +- Create: `notification-worker/emit_test.go` + +- [ ] **Step 1: Write the failing tests** + +```go +package main + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/model" +) + +type recordedPublish struct { + subject string + msgID string + payload []byte +} + +type fakeAsyncPublisher struct { + mu sync.Mutex + records []recordedPublish + failNext error +} + +func (f *fakeAsyncPublisher) PublishMsgAsync(msg *nats.Msg) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.failNext != nil { + err := f.failNext + f.failNext = nil + return err + } + f.records = append(f.records, recordedPublish{ + subject: msg.Subject, + msgID: msg.Header.Get("Nats-Msg-Id"), + payload: append([]byte(nil), msg.Data...), + }) + return nil +} + +func (f *fakeAsyncPublisher) drain(context.Context) {} + +func TestMobileEmitter_PublishesPerRecipient(t *testing.T) 
{ + pub := &fakeAsyncPublisher{} + em := newMobileEmitter(pub, "site-a") + evt := model.PushNotificationEvent{ + ID: "m1-bob", + Account: "bob", + RoomID: "r1", + } + require.NoError(t, em.Emit(context.Background(), evt)) + + require.Len(t, pub.records, 1) + r := pub.records[0] + assert.Equal(t, "chat.server.notification.push.site-a.send", r.subject) + assert.Equal(t, "m1-bob", r.msgID) + + var got model.PushNotificationEvent + require.NoError(t, json.Unmarshal(r.payload, &got)) + assert.Equal(t, evt, got) +} + +func TestMobileEmitter_PropagatesError(t *testing.T) { + pub := &fakeAsyncPublisher{failNext: errors.New("nats: full")} + em := newMobileEmitter(pub, "site-a") + err := em.Emit(context.Background(), model.PushNotificationEvent{ID: "m1-bob", Account: "bob"}) + assert.Error(t, err) +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — types undefined. + +- [ ] **Step 3: Implement `notification-worker/emit.go`** + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/subject" +) + +// asyncPublisher is the narrow JetStream surface mobileEmitter needs. +// Defined here so emit_test.go can substitute without a real NATS +// connection. drain blocks until every in-flight ack completes (or ctx +// elapses). +type asyncPublisher interface { + PublishMsgAsync(msg *nats.Msg) error + drain(ctx context.Context) +} + +// Emitter is the single mobile-push emit leg. The handler calls Emit for +// each surviving recipient. Errors are per-recipient; the handler logs and +// moves on (the canonical message still acks). 
+type Emitter interface { + Emit(ctx context.Context, evt model.PushNotificationEvent) error +} + +type mobileEmitter struct { + pub asyncPublisher + siteID string +} + +func newMobileEmitter(pub asyncPublisher, siteID string) *mobileEmitter { + return &mobileEmitter{pub: pub, siteID: siteID} +} + +func (e *mobileEmitter) Emit(_ context.Context, evt model.PushNotificationEvent) error { + data, err := json.Marshal(evt) + if err != nil { + return fmt.Errorf("marshal push event for %s: %w", evt.Account, err) + } + msg := &nats.Msg{ + Subject: subject.PushNotification(e.siteID), + Header: nats.Header{}, + Data: data, + } + // Per-recipient dedup so a redelivery of the canonical message never + // produces duplicate pushes. Window is the stream's MsgID dedup window + // (owned by the push service / ops). + msg.Header.Set("Nats-Msg-Id", evt.ID) + if err := e.pub.PublishMsgAsync(msg); err != nil { + return fmt.Errorf("publish push for %s: %w", evt.Account, err) + } + return nil +} + +// jsAsyncPublisher adapts a raw jetstream.JetStream to the asyncPublisher +// interface. It holds no in-flight cap of its own; any limit on pending +// async publishes must be configured on the JetStream context passed in. +// Async publish is required for v1 — sync PublishMsg in a 10k-member +// fan-out would serialise ack round-trips. +// +// drain blocks on jetstream's PublishAsyncComplete or the provided ctx so +// graceful shutdown does not lose in-flight pushes. +type jsAsyncPublisher struct { + js jetstream.JetStream +} + +func newJSAsyncPublisher(js jetstream.JetStream) *jsAsyncPublisher { + return &jsAsyncPublisher{js: js} +} + +func (j *jsAsyncPublisher) PublishMsgAsync(msg *nats.Msg) error { + _, err := j.js.PublishMsgAsync(msg) + return err +} + +func (j *jsAsyncPublisher) drain(ctx context.Context) { + select { + case <-j.js.PublishAsyncComplete(): + case <-ctx.Done(): + } +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for `TestMobileEmitter_*`.
+ +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/emit.go notification-worker/emit_test.go +git commit -m "feat(notification-worker): async mobile-push emitter with per-recipient dedup" +``` + +--- + +## Task 11: Rewrite the handler — Stages 1–4 + push payload (TDD) + +This is the heart of the change. The existing `handler.go` is replaced wholesale; tests are replaced to match the new behaviour. + +**Files:** +- Modify: `notification-worker/handler.go` +- Modify: `notification-worker/handler_test.go` + +- [ ] **Step 1: Write the failing handler tests** (replaces the existing file) + +Replace the entire contents of `notification-worker/handler_test.go` with: + +```go +package main + +import ( + "context" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// --- Stubs --- + +type stubMembers struct { + out map[string][]roomsubcache.Member +} + +func (s *stubMembers) GetMembers(_ context.Context, roomID string) ([]roomsubcache.Member, error) { + return s.out[roomID], nil +} + +type stubFollowers struct { + out map[string]map[string]struct{} +} + +func (s *stubFollowers) Followers(_ context.Context, parentID string) (map[string]struct{}, error) { + if v, ok := s.out[parentID]; ok { + return v, nil + } + return map[string]struct{}{}, nil +} + +type stubPresence struct { + out map[string]model.Presence +} + +func (s *stubPresence) Snapshot(_ context.Context, _ []string) (map[string]model.Presence, error) { + return s.out, nil +} + +type rejectHook struct{} + +func (rejectHook) Allow(context.Context, *model.Message, roomsubcache.Member) (bool, error) { + return false, nil +} + +type recordingEmitter struct { + mu sync.Mutex + emitted []model.PushNotificationEvent +} + +func (r *recordingEmitter) Emit(_ context.Context, evt model.PushNotificationEvent) error { + r.mu.Lock() + defer 
r.mu.Unlock() + r.emitted = append(r.emitted, evt) + return nil +} + +func (r *recordingEmitter) accounts() []string { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]string, 0, len(r.emitted)) + for _, e := range r.emitted { + out = append(out, e.Account) + } + return out +} + +// --- Helpers --- + +func newTestHandler(members MemberLookup, followers ThreadFollowers, presence PresenceSource, hook Hook, emit Emitter) *Handler { + return NewHandler(HandlerDeps{ + Members: members, + Followers: followers, + Presence: presence, + Hook: hook, + Emitter: emit, + LargeRoomThreshold: 500, + }) +} + +func msgEvent(m model.Message) []byte { + data, _ := json.Marshal(model.MessageEvent{Message: m, SiteID: "site-a"}) + return data +} + +// --- Stage 1: exclusion filters --- + +func TestHandle_SkipsSender(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + CreatedAt: time.Now(), + }))) + assert.Equal(t, []string{"bob"}, emit.accounts()) +} + +func TestHandle_SkipsMuted(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob", Muted: true}, + {ID: "carol", Account: "carol"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts(), "muted bob is skipped") +} + +func 
TestHandle_SkipsRestrictedBeforeWindow(t *testing.T) { + createdAt := time.Unix(0, 1700000000000*int64(time.Millisecond)) + afterWindow := int64(1700000000001) // 1ms after createdAt: the message is strictly before this member's window + beforeWindow := int64(1699999999999) // earlier than the message: member sees it + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob", HistorySharedSince: &afterWindow}, // joined after message → skip + {ID: "carol", Account: "carol", HistorySharedSince: &beforeWindow}, // joined before → include + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: createdAt, + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts()) +} + +// --- Stage 1: thread non-follower --- + +func TestHandle_ThreadOnlyReply_SkipsNonFollowerNonMention(t *testing.T) { + parentCreatedAt := time.Unix(0, 1700000000000*int64(time.Millisecond)) + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + followers := &stubFollowers{out: map[string]map[string]struct{}{ + "parent-1": {"bob": {}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, followers, noopPresenceSource{}, noopHook{}, emit) + + msg := model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + ThreadParentMessageCreatedAt: &parentCreatedAt, + TShow: false, + Content: "thread reply", + } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(msg))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts(), "only thread follower receives") +} + +func
TestHandle_ThreadReply_TShow_TreatedAsChannelMessage(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + msg := model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + TShow: true, + Content: "shared with channel", + } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(msg))) + assert.ElementsMatch(t, []string{"bob", "carol"}, emit.accounts()) +} + +// --- Stage 2: hook veto --- + +func TestHandle_HookVeto_DropsAll(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, rejectHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + assert.Empty(t, emit.accounts()) +} + +// --- Stage 3: routing (large room) --- + +func TestHandle_LargeRoomNonMention_DropsAll(t *testing.T) { + roomMembers := make([]roomsubcache.Member, 600) + for i := range roomMembers { + roomMembers[i] = roomsubcache.Member{ID: "u", Account: "u" + string(rune(i))} + } + roomMembers[0] = roomsubcache.Member{ID: "alice", Account: "alice"} + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, + Followers: &stubFollowers{}, + Presence: noopPresenceSource{}, + Hook: noopHook{}, + Emitter: emit, + LargeRoomThreshold: 500, + }) + + require.NoError(t, 
h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", Content: "no mentions", + CreatedAt: time.Now(), + }))) + assert.Empty(t, emit.accounts(), "large room non-mention drops all") +} + +func TestHandle_LargeRoomMention_OnlyMentionedPushed(t *testing.T) { + roomMembers := []roomsubcache.Member{ + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + } + // pad to large + for i := 0; i < 600; i++ { + roomMembers = append(roomMembers, roomsubcache.Member{ID: "u" + string(rune(i)), Account: "u" + string(rune(i))}) + } + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hey @bob check this", CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts()) +} + +func TestHandle_LargeRoomAtAll_PushesAllNonSender(t *testing.T) { + roomMembers := []roomsubcache.Member{ + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + } + for i := 0; i < 500; i++ { + roomMembers = append(roomMembers, roomsubcache.Member{ID: "u", Account: "u" + string(rune(i))}) + } + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "@all heads up", CreatedAt: time.Now(), + }))) + assert.Contains(t, emit.accounts(), "bob") + assert.Contains(t, emit.accounts(), "carol") + assert.NotContains(t, 
emit.accounts(), "alice") +} + +// --- Stage 4: presence --- + +func TestHandle_PresenceBusyDropsPush(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + presence := &stubPresence{out: map[string]model.Presence{ + "bob": {AggregatedStatus: "busy"}, + "carol": {AggregatedStatus: "online"}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, presence, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts()) +} + +// --- Payload shape --- + +func TestHandle_PushPayloadSenderFromMemberRecord(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice", ChineseName: "張三", EngName: "Alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hello", + CreatedAt: time.Unix(0, 1700000000000*int64(time.Millisecond)), + }))) + require.Len(t, emit.emitted, 1) + got := emit.emitted[0] + assert.Equal(t, "m1-bob", got.ID, "dedup-stable ID") + assert.Equal(t, "bob", got.Account) + assert.Equal(t, "r1", got.RoomID) + require.NotNil(t, got.Data.Sender) + assert.Equal(t, "alice", got.Data.Sender.Account) + assert.Equal(t, "張三", got.Data.Sender.ChineseName) + assert.Equal(t, "Alice", got.Data.Sender.EngName) + assert.Equal(t, "m1", got.Data.MessageID) + assert.NotEmpty(t, got.Data.PushTime) + assert.Greater(t, got.Timestamp, int64(0)) +} + +func TestHandle_InvalidJSON(t 
*testing.T) { + emit := &recordingEmitter{} + h := newTestHandler(&stubMembers{}, &stubFollowers{}, noopPresenceSource{}, noopHook{}, emit) + err := h.HandleMessage(context.Background(), []byte("not json")) + assert.Error(t, err) +} +``` + +- [ ] **Step 2: Run the tests to confirm they fail** + +Run: `make test SERVICE=notification-worker` +Expected: FAIL — `HandlerDeps`, `MemberLookup` (new interface name), payload shape all undefined. + +- [ ] **Step 3: Replace `notification-worker/handler.go`** with the new orchestrator + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "time" + + "github.com/hmchangw/chat/pkg/mention" + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// MemberLookup returns the cached/canonical member list for a room. Slices +// are treated read-only by the handler. +type MemberLookup interface { + GetMembers(ctx context.Context, roomID string) ([]roomsubcache.Member, error) +} + +// HandlerDeps groups the handler's collaborators. Defined as a struct so +// adding a new collaborator does not churn the constructor signature. +type HandlerDeps struct { + Members MemberLookup + Followers ThreadFollowers + Presence PresenceSource + Hook Hook + Emitter Emitter + LargeRoomThreshold int +} + +// Handler runs the per-message fan-out pipeline: +// Stage 1 — exclusion filters (sender / mute / restricted / thread-non-follower) +// Stage 2 — in-process hook veto (suppress-only, fail-open on error) +// Stage 3 — pure routing predicate (EligibleForPush) +// Stage 4 — one bulk presence RPC, then per-account shouldPush +// followed by one Emitter.Emit per surviving recipient. 
+type Handler struct { + deps HandlerDeps +} + +func NewHandler(deps HandlerDeps) *Handler { + if deps.LargeRoomThreshold <= 0 { + deps.LargeRoomThreshold = 500 + } + return &Handler{deps: deps} +} + +func (h *Handler) HandleMessage(ctx context.Context, data []byte) error { + var evt model.MessageEvent + if err := json.Unmarshal(data, &evt); err != nil { + return fmt.Errorf("unmarshal message event: %w", err) + } + msg := evt.Message + + members, err := h.deps.Members.GetMembers(ctx, msg.RoomID) + if err != nil { + return fmt.Errorf("get members for room %s: %w", msg.RoomID, err) + } + if len(members) == 0 { + return nil + } + + // --- Once-per-message inputs --- + mentionInfo := mention.Parse(msg.Content) + mentionedAccounts := mentionedSet(mentionInfo, msg.Mentions) + mentionsAllOrHere := mentionInfo.MentionAll || mentionInfo.MentionHere + isLargeRoom := len(members) > h.deps.LargeRoomThreshold + isThreadOnlyReply := msg.ThreadParentMessageID != "" && !msg.TShow + + var followers map[string]struct{} + if isThreadOnlyReply { + f, ferr := h.deps.Followers.Followers(ctx, msg.ThreadParentMessageID) + if ferr != nil { + slog.Warn("thread followers lookup failed, treating as empty", + "error", ferr, "parentMessageId", msg.ThreadParentMessageID) + f = map[string]struct{}{} + } + followers = f + } + + roomType := deriveRoomType(members) + + // Author display info — taken from the loaded member list (no separate lookup). 
+ var sender *model.Participant + for i := range members { + if members[i].ID == msg.UserID { + sender = &model.Participant{ + Account: members[i].Account, + ChineseName: members[i].ChineseName, + EngName: members[i].EngName, + } + break + } + } + + // --- Stages 1–3: build the push-eligible recipient set --- + type candidate struct { + member roomsubcache.Member + } + candidates := make([]candidate, 0, len(members)) + for _, m := range members { + // Stage 1.1 — sender + if m.ID == msg.UserID { + continue + } + // Stage 1.2 — mute + if m.Muted { + continue + } + // Stage 1.3 — restricted room + if isRestricted(m, msg, isThreadOnlyReply) { + continue + } + + mentioned := mentionsAllOrHere || mentionedAccounts[m.Account] + + // Stage 1.4 — thread non-follower (only for thread-only replies) + if isThreadOnlyReply { + _, follows := followers[m.Account] + if !follows && !mentioned { + continue + } + } + + // Stage 2 — hook veto (fail-open on error) + allow, herr := h.deps.Hook.Allow(ctx, &msg, m) + if herr != nil { + slog.Warn("hook errored, allowing", "error", herr, "account", m.Account) + allow = true + } + if !allow { + continue + } + + // Stage 3 — routing predicate + if !EligibleForPush(m, roomType, isLargeRoom, mentioned) { + continue + } + + candidates = append(candidates, candidate{member: m}) + } + if len(candidates) == 0 { + return nil + } + + // --- Stage 4: presence snapshot for the eligible set --- + accounts := make([]string, len(candidates)) + for i, c := range candidates { + accounts[i] = c.member.Account + } + snapshot, _ := h.deps.Presence.Snapshot(ctx, accounts) // fail-open: error → empty + + // --- Emit --- + nowMs := time.Now().UTC().UnixMilli() + pushTime := time.Now().UTC().Format(time.RFC3339) + for _, c := range candidates { + if !shouldPush(snapshot[c.member.Account]) { + continue + } + evt := model.PushNotificationEvent{ + ID: msg.ID + "-" + c.member.Account, + Account: c.member.Account, + RoomID: msg.RoomID, + Title: "", // population 
deferred — room name lives off-msg; see Future work + Body: msg.Content, + Data: model.PushNotificationData{ + RoomID: msg.RoomID, + MessageID: msg.ID, + Type: shortRoomType(roomType), + Sender: sender, + ThreadMessageID: msg.ThreadParentMessageID, + PushTime: pushTime, + AlsoSendToChannel: msg.TShow, + }, + Timestamp: nowMs, + } + if err := h.deps.Emitter.Emit(ctx, evt); err != nil { + slog.Error("emit push failed", "error", err, "account", c.member.Account, "messageId", msg.ID) + } + } + return nil +} + +// mentionedSet returns the union of (a) accounts parsed from message +// content and (b) explicit Mentions on the canonical message. Accounts are +// stored as given (no case folding here). Map lookup is O(1) on the per-recipient loop. +func mentionedSet(parsed mention.ParseResult, explicit []model.Participant) map[string]bool { + out := make(map[string]bool, len(parsed.Accounts)+len(explicit)) + for _, a := range parsed.Accounts { + out[a] = true + } + for _, p := range explicit { + if p.Account != "" { + out[p.Account] = true + } + } + return out +} + +// isRestricted returns true when the member should be filtered out because +// they joined the room after the relevant timestamp. For a thread-only +// reply the relevant ts is the parent's CreatedAt (history-service rule); +// for a channel message it's the message's own CreatedAt. A nil parent ts +// on a thread reply is treated conservatively as "no access" (legacy +// thread replies). +func isRestricted(m roomsubcache.Member, msg model.Message, isThreadOnlyReply bool) bool { + if m.HistorySharedSince == nil { + return false + } + if isThreadOnlyReply { + if msg.ThreadParentMessageCreatedAt == nil { + return true + } + return msg.ThreadParentMessageCreatedAt.UnixMilli() < *m.HistorySharedSince + } + return msg.CreatedAt.UnixMilli() < *m.HistorySharedSince +} + +// deriveRoomType returns a usable RoomType for routing. 
The member-list +// projection no longer carries the room type per-member (it's per-room), +// but msg.Type doesn't carry it either. For routing we only need to +// distinguish DM/botDM from channel/discussion — derive from member count +// (≤2 = DM/botDM-shaped) as a safe default until the projection carries +// room type. TODO: thread to RoomMetadataCache when wired. +func deriveRoomType(members []roomsubcache.Member) model.RoomType { + if len(members) <= 2 { + return model.RoomTypeDM + } + return model.RoomTypeChannel +} + +func shortRoomType(t model.RoomType) string { + switch t { + case model.RoomTypeDM, model.RoomTypeBotDM: + return "d" + case model.RoomTypeDiscussion: + return "p" + default: + return "c" + } +} +``` + +- [ ] **Step 4: Run the tests to verify they pass** + +Run: `make test SERVICE=notification-worker` +Expected: PASS for all handler tests. + +- [ ] **Step 5: Run lint** + +Run: `make lint` +Expected: PASS — fix any import / unused-variable issues. Remove the `_ = time.Second` workaround in `emit.go` if lint complains about unused imports; restructure as needed. + +- [ ] **Step 6: Commit** + +```bash +git add notification-worker/handler.go notification-worker/handler_test.go +git commit -m "feat(notification-worker): mention-gated routing pipeline with mobile push payload" +``` + +--- + +## Task 12: Wire main.go — Valkey, raw JS, EnsureIndexes, invalidator, drain + +**Files:** +- Modify: `notification-worker/main.go` + +The existing `main.go` builds only the (now-deleted) Publisher and the basic handler. It needs a full wiring rewrite — connect Valkey, build the cache + lookup + emitter + presence source + hook, ensure the thread_subscriptions index, set up the eager-invalidation core-NATS subscription, and drain async-publish on shutdown. 
+ +- [ ] **Step 1: Confirm the existing main no longer builds against the new handler** + +After Task 11, `main.go` still references the old `NewHandler(memberLookup, publisher)` signature, so the package no longer compiles; Step 2 of this task fixes the build so the tests run. + +Run: `go build ./notification-worker/...` +Expected: FAIL — wrong call shape. + +- [ ] **Step 2: Replace `main.go`** with the wired version + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "sync" + "time" + + "github.com/caarlos0/env/v11" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + + "github.com/Marz32onE/instrumentation-go/otel-nats/oteljetstream" + + "github.com/hmchangw/chat/pkg/mongoutil" + "github.com/hmchangw/chat/pkg/natsutil" + "github.com/hmchangw/chat/pkg/otelutil" + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/shutdown" + "github.com/hmchangw/chat/pkg/stream" + "github.com/hmchangw/chat/pkg/subject" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +type config struct { + NatsURL string `env:"NATS_URL" envDefault:"nats://localhost:4222"` + NatsCredsFile string `env:"NATS_CREDS_FILE" envDefault:""` + SiteID string `env:"SITE_ID" envDefault:"default"` + MongoURI string `env:"MONGO_URI" envDefault:"mongodb://localhost:27017"` + MongoDB string `env:"MONGO_DB" envDefault:"chat"` + MongoUsername string `env:"MONGO_USERNAME" envDefault:""` + MongoPassword string `env:"MONGO_PASSWORD" envDefault:""` + MaxWorkers int `env:"MAX_WORKERS" envDefault:"100"` + LargeRoomThreshold int `env:"LARGE_ROOM_THRESHOLD" envDefault:"500"` + ValkeyAddrs []string `env:"VALKEY_ADDRS" envSeparator:","` + ValkeyPassword string `env:"VALKEY_PASSWORD" envDefault:""` + RoomSubCacheTTL time.Duration `env:"ROOMSUBCACHE_TTL" envDefault:"5m"` + L1MemberCacheSize int 
`env:"L1_MEMBER_CACHE_SIZE" envDefault:"1000"` + L1MemberCacheTTL 
time.Duration `env:"L1_MEMBER_CACHE_TTL" envDefault:"5s"` + PresenceBatchSize int `env:"PRESENCE_BATCH_SIZE" envDefault:"512"` + PresenceRPCTimeout time.Duration `env:"PRESENCE_RPC_TIMEOUT" envDefault:"2s"` + PresenceEnabled bool `env:"PRESENCE_RPC_ENABLED" envDefault:"false"` // false → noop while presence-service PR lands + Consumer stream.ConsumerSettings `envPrefix:"CONSUMER_"` + Bootstrap bootstrapConfig `envPrefix:"BOOTSTRAP_"` +} + +// mongoMemberLoader implements the memberLoader closure: it reads +// subscriptions for the room and projects to roomsubcache.Member. +type mongoMemberLoader struct { + col *mongo.Collection +} + +func (m *mongoMemberLoader) Load(ctx context.Context, roomID string) ([]roomsubcache.Member, error) { + projection := bson.M{ + "u._id": 1, + "u.account": 1, + "u.isBot": 1, + "u.chineseName": 1, + "u.engName": 1, + "muted": 1, + "historySharedSince": 1, + } + cur, err := m.col.Find(ctx, bson.M{"roomId": roomID}, options.Find().SetProjection(projection)) + if err != nil { + return nil, fmt.Errorf("find subscriptions for %s: %w", roomID, err) + } + defer cur.Close(ctx) + + var out []roomsubcache.Member + for cur.Next(ctx) { + var doc struct { + User struct { + ID string `bson:"_id"` + Account string `bson:"account"` + IsBot bool `bson:"isBot"` + ChineseName string `bson:"chineseName"` + EngName string `bson:"engName"` + } `bson:"u"` + Muted bool `bson:"muted"` + HistorySharedSince *time.Time `bson:"historySharedSince"` + } + if err := cur.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode subscription: %w", err) + } + var hssMs *int64 + if doc.HistorySharedSince != nil { + ms := doc.HistorySharedSince.UnixMilli() + hssMs = &ms + } + out = append(out, roomsubcache.Member{ + ID: doc.User.ID, + Account: doc.User.Account, + IsBot: doc.User.IsBot, + ChineseName: doc.User.ChineseName, + EngName: doc.User.EngName, + Muted: doc.Muted, + HistorySharedSince: hssMs, + }) + } + if err := cur.Err(); err != nil { + return nil, 
fmt.Errorf("iterate subscriptions: %w", err) + } + return out, nil +} + +func main() { + slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil))) + + cfg, err := env.ParseAs[config]() + if err != nil { + slog.Error("parse config", "error", err) + os.Exit(1) + } + if len(cfg.ValkeyAddrs) == 0 { + slog.Error("VALKEY_ADDRS required") + os.Exit(1) + } + + ctx := context.Background() + + tracerShutdown, err := otelutil.InitTracer(ctx, "notification-worker") + if err != nil { + slog.Error("init tracer failed", "error", err) + os.Exit(1) + } + + mongoClient, err := mongoutil.Connect(ctx, cfg.MongoURI, cfg.MongoUsername, cfg.MongoPassword) + if err != nil { + slog.Error("mongo connect failed", "error", err) + os.Exit(1) + } + db := mongoClient.Database(cfg.MongoDB) + subCol := db.Collection("subscriptions") + threadSubCol := db.Collection("thread_subscriptions") + + ensureCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + if err := EnsureThreadSubscriptionIndex(ensureCtx, threadSubCol); err != nil { + cancel() + slog.Error("ensure thread_subscriptions index", "error", err) + os.Exit(1) + } + cancel() + + valkeyClient, err := valkeyutil.ConnectCluster(ctx, cfg.ValkeyAddrs, cfg.ValkeyPassword) + if err != nil { + slog.Error("valkey connect failed", "error", err) + os.Exit(1) + } + cache := roomsubcache.NewValkeyCache(valkeyClient) + loader := &mongoMemberLoader{col: subCol} + memberLookup := newCachedMemberLookup(cache, loader.Load, cfg.RoomSubCacheTTL, + cfg.L1MemberCacheSize, cfg.L1MemberCacheTTL) + + nc, err := natsutil.Connect(cfg.NatsURL, cfg.NatsCredsFile) + if err != nil { + slog.Error("nats connect failed", "error", err) + os.Exit(1) + } + + otelJS, err := oteljetstream.New(nc) + if err != nil { + slog.Error("jetstream init failed", "error", err) + os.Exit(1) + } + // Raw jetstream.JetStream for async publish (oteljetstream is sync-only). 
+ rawJS, err := jetstream.New(nc.NatsConn()) + if err != nil { + slog.Error("raw jetstream init failed", "error", err) + os.Exit(1) + } + + if err := bootstrapStreams(ctx, otelJS, cfg.SiteID, cfg.Bootstrap.Enabled); err != nil { + slog.Error("bootstrap streams failed", "error", err) + os.Exit(1) + } + + canonicalCfg := stream.MessagesCanonical(cfg.SiteID) + cons, err := otelJS.CreateOrUpdateConsumer(ctx, canonicalCfg.Name, buildConsumerConfig(cfg.Consumer)) + if err != nil { + slog.Error("create consumer failed", "error", err) + os.Exit(1) + } + + asyncPub := newJSAsyncPublisher(rawJS) + emitter := newMobileEmitter(asyncPub, cfg.SiteID) + + var presence PresenceSource = noopPresenceSource{} + if cfg.PresenceEnabled { + presence = newBulkPresenceSource( + &natsPresenceRequester{nc: nc.NatsConn()}, + cfg.SiteID, + cfg.PresenceBatchSize, + cfg.PresenceRPCTimeout, + ) + } + + handler := NewHandler(HandlerDeps{ + Members: memberLookup, + Followers: newMongoThreadFollowers(threadSubCol), + Presence: presence, + Hook: noopHook{}, + Emitter: emitter, + LargeRoomThreshold: cfg.LargeRoomThreshold, + }) + + // --- Eager cache invalidation on subscription.update fan-out --- + // Two payload shapes exist on this subject: + // - SubscriptionUpdateEvent (added/role_updated/mute_toggled) carries + // a full Subscription. + // - SubscriptionRemovedEvent (removed) carries the lean + // RemovedSubscriptionRef. + // Both shapes include a top-level "subscription" object with "roomId" — + // decoding into a minimal envelope sidesteps the type branching. 
+ invalSub, err := nc.NatsConn().Subscribe(subject.SubscriptionUpdateWildcard(), func(msg *nats.Msg) { + var env struct { + Subscription struct { + RoomID string `json:"roomId"` + } `json:"subscription"` + } + if err := json.Unmarshal(msg.Data, &env); err != nil { + slog.Warn("subscription.update decode failed", "error", err) + return + } + if env.Subscription.RoomID == "" { + return + } + memberLookup.Invalidate(context.Background(), env.Subscription.RoomID) + }) + if err != nil { + slog.Error("subscribe subscription.update failed", "error", err) + os.Exit(1) + } + + iter, err := cons.Messages(jetstream.PullMaxMessages(2 * cfg.MaxWorkers)) + if err != nil { + slog.Error("messages failed", "error", err) + os.Exit(1) + } + + sem := make(chan struct{}, cfg.MaxWorkers) + var wg sync.WaitGroup + + go func() { + for { + msgCtx, msg, err := iter.Next() + if err != nil { + return + } + sem <- struct{}{} + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + handlerCtx := natsutil.ContextWithRequestIDFromHeaders(msgCtx, msg.Headers()) + if err := handler.HandleMessage(handlerCtx, msg.Data()); err != nil { + slog.Error("handle message failed", "error", err, "request_id", natsutil.RequestIDFromContext(handlerCtx)) + if err := msg.Nak(); err != nil { + slog.Error("failed to nak message", "error", err) + } + return + } + if err := msg.Ack(); err != nil { + slog.Error("failed to ack message", "error", err) + } + }() + } + }() + + slog.Info("notification-worker started", + "site", cfg.SiteID, + "large_room_threshold", cfg.LargeRoomThreshold, + "valkey_addrs", cfg.ValkeyAddrs, + "presence_enabled", cfg.PresenceEnabled, + ) + + shutdown.Wait(ctx, 25*time.Second, + func(ctx context.Context) error { + iter.Stop() + return nil + }, + func(ctx context.Context) error { + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + return nil + case <-ctx.Done(): + return fmt.Errorf("worker drain timed out: %w", ctx.Err()) + } + }, + 
func(ctx context.Context) error { + asyncPub.drain(ctx) + return nil + }, + func(_ context.Context) error { + if invalSub != nil { + return invalSub.Unsubscribe() + } + return nil + }, + func(ctx context.Context) error { return tracerShutdown(ctx) }, + func(_ context.Context) error { return nc.Drain() }, + func(ctx context.Context) error { mongoutil.Disconnect(ctx, mongoClient); return nil }, + func(_ context.Context) error { valkeyutil.Disconnect(valkeyClient); return nil }, + ) +} + +// buildConsumerConfig returns the durable consumer config for +// notification-worker. Centralised so it is unit-testable without NATS. +func buildConsumerConfig(s stream.ConsumerSettings) jetstream.ConsumerConfig { + cc := stream.DurableConsumerDefaults(s) + cc.Durable = "notification-worker" + return cc +} +``` + +Note: the import list above already includes `encoding/json` for the subscription.update envelope decode and drops `pkg/model` from `main.go` (the file no longer references it directly). + +- [ ] **Step 3: Build** + +Run: `go build ./notification-worker/...` +Expected: PASS. + +- [ ] **Step 4: Run all notification-worker tests** + +Run: `make test SERVICE=notification-worker && make lint` +Expected: PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add notification-worker/main.go +git commit -m "feat(notification-worker): wire valkey cache, async publisher, presence, and cache invalidation" +``` + +--- + +## Task 13: Integration test — end-to-end against Valkey + Mongo + NATS + +**Files:** +- Modify: `notification-worker/integration_test.go` + +- [ ] **Step 1: Rewrite the integration test** to exercise the cache + emit path + +Replace the file with: + +```go +//go:build integration + +package main + +import ( + "context" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/mongo" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/subject" + "github.com/hmchangw/chat/pkg/testutil" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +func TestMain(m *testing.M) { testutil.RunTests(m) } + +func TestNotificationWorker_CacheBackedFanOut(t *testing.T) { + db := testutil.MongoDB(t, "notification_worker_test") + valkeyClient := testutil.SharedValkeyCluster(t) + t.Cleanup(func() { testutil.FlushValkey(t) }) + natsURL := testutil.NATS(t) + + ctx := context.Background() + subCol := db.Collection("subscriptions") + threadCol := db.Collection("thread_subscriptions") + require.NoError(t, EnsureThreadSubscriptionIndex(ctx, threadCol)) + + seedSubscriptions(t, ctx, subCol) + + cache := roomsubcache.NewValkeyCache(valkeyutil.WrapClusterClient(valkeyClient)) + loader := &mongoMemberLoader{col: subCol} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute, 100, 5*time.Second) + + nc, err := nats.Connect(natsURL) + require.NoError(t, err) + t.Cleanup(func() { _ = nc.Drain() }) + + // Capture pushes off the NATS bus (subscribe before publishing). 
+ pushSub := subscribePush(t, nc, "site-a") + + emitter := newMobileEmitter(&directNATSAsyncPub{nc: nc}, "site-a") + handler := NewHandler(HandlerDeps{ + Members: lookup, + Followers: newMongoThreadFollowers(threadCol), + Presence: noopPresenceSource{}, + Hook: noopHook{}, + Emitter: emitter, + LargeRoomThreshold: 500, + }) + + evt := model.MessageEvent{ + SiteID: "site-a", + Message: model.Message{ + ID: "m1", + RoomID: "r1", + UserID: "alice", + UserAccount: "alice", + Content: "hello", + CreatedAt: time.Now(), + }, + } + data, _ := json.Marshal(evt) + require.NoError(t, handler.HandleMessage(ctx, data)) + + got := pushSub.collect(t, 2*time.Second, 2) + assert.ElementsMatch(t, []string{"bob", "carol"}, got) +} + +func seedSubscriptions(t *testing.T, ctx context.Context, col *mongo.Collection) { + t.Helper() + _, err := col.InsertMany(ctx, []any{ + model.Subscription{ID: "s1", RoomID: "r1", User: model.SubscriptionUser{ID: "alice", Account: "alice"}}, + model.Subscription{ID: "s2", RoomID: "r1", User: model.SubscriptionUser{ID: "bob", Account: "bob"}}, + model.Subscription{ID: "s3", RoomID: "r1", User: model.SubscriptionUser{ID: "carol", Account: "carol"}}, + }) + require.NoError(t, err) +} + +type pushCollector struct { + mu sync.Mutex + gotAcct []string + got chan struct{} +} + +func subscribePush(t *testing.T, nc *nats.Conn, siteID string) *pushCollector { + t.Helper() + c := &pushCollector{got: make(chan struct{}, 256)} + sub, err := nc.Subscribe(subject.PushNotification(siteID), func(msg *nats.Msg) { + var evt model.PushNotificationEvent + if err := json.Unmarshal(msg.Data, &evt); err != nil { + t.Logf("decode push: %v", err) + return + } + c.mu.Lock() + c.gotAcct = append(c.gotAcct, evt.Account) + c.mu.Unlock() + c.got <- struct{}{} + }) + require.NoError(t, err) + t.Cleanup(func() { _ = sub.Unsubscribe() }) + return c +} + +func (c *pushCollector) collect(t *testing.T, timeout time.Duration, want int) []string { + t.Helper() + deadline := 
time.After(timeout) + for { + c.mu.Lock() + if len(c.gotAcct) >= want { + out := append([]string(nil), c.gotAcct...) + c.mu.Unlock() + return out + } + c.mu.Unlock() + select { + case <-c.got: + case <-deadline: + c.mu.Lock() + defer c.mu.Unlock() + t.Fatalf("collect timeout: got %v want %d", c.gotAcct, want) + return nil + } + } +} + +// directNATSAsyncPub bypasses JetStream — integration test publishes on +// core NATS so we can observe the subject without standing up the +// PUSH_NOTIFICATIONS stream. The production wiring uses jsAsyncPublisher. +type directNATSAsyncPub struct{ nc *nats.Conn } + +func (d *directNATSAsyncPub) PublishMsgAsync(msg *nats.Msg) error { return d.nc.PublishMsg(msg) } +func (d *directNATSAsyncPub) drain(context.Context) {} +``` + +- [ ] **Step 2: Run the integration tests** + +Run: `make test-integration SERVICE=notification-worker` +Expected: PASS. + +- [ ] **Step 3: Commit** + +```bash +git add notification-worker/integration_test.go +git commit -m "test(notification-worker): integration coverage for cache-backed mobile push fan-out" +``` + +--- + +## Task 14: docker-compose — add Valkey + new env vars + +**Files:** +- Modify: `notification-worker/deploy/docker-compose.yml` + +- [ ] **Step 1: Update the compose file** to depend on the shared local Valkey + thread the new env vars + +```yaml +name: notification-worker + +services: + notification-worker: + build: + context: ../.. 
+ dockerfile: notification-worker/deploy/Dockerfile + environment: + - NATS_URL=nats://nats:4222 + - NATS_CREDS_FILE=/etc/nats/backend.creds + - SITE_ID=site-local + - MONGO_URI=mongodb://mongodb:27017 + - MONGO_DB=chat + - VALKEY_ADDRS=valkey:6379 + - ROOMSUBCACHE_TTL=5m + - L1_MEMBER_CACHE_SIZE=1000 + - L1_MEMBER_CACHE_TTL=5s + - LARGE_ROOM_THRESHOLD=500 + - PRESENCE_RPC_ENABLED=false + - BOOTSTRAP_STREAMS=true + volumes: + - ../../docker-local/backend.creds:/etc/nats/backend.creds:ro + networks: + - chat-local + +networks: + chat-local: + external: true +``` + +- [ ] **Step 2: Commit** + +```bash +git add notification-worker/deploy/docker-compose.yml +git commit -m "chore(notification-worker): wire valkey + cache env in local compose" +``` + +--- + +## Task 15: Update `docs/client-api.md` + +**Files:** +- Modify: `docs/client-api.md` + +- [ ] **Step 1: Locate the legacy `notification` event description** + +Run: `grep -n "notification" docs/client-api.md | head -20` + +- [ ] **Step 2: Replace its description** + +Edit the section that documents `chat.user.{account}.notification` to reflect the new behaviour: + +```markdown +### Notification fan-out (mobile push only) + +`notification-worker` no longer publishes `chat.user.{account}.notification` +on core NATS. Mobile pushes are emitted on the server-only JetStream subject +`chat.server.notification.push.{siteID}.send` and forwarded by the internal +push-notification service. Desktop banners are computed client-side from the +broadcast-worker room-event stream — no server-side desktop publish exists. + +The worker filters recipients per message: + +- Skips the sender. +- Skips members with `muted: true` on their subscription. +- Skips members whose `historySharedSince` postdates the message (for a + thread-only reply the parent's `createdAt` is used instead). +- For a thread reply with `tshow: false`, skips non-followers who are not + mentioned. 
+- In rooms with more than `LARGE_ROOM_THRESHOLD` members (default 500), + pushes only to mentioned recipients (`@user`, `@all`, `@here`). +- Bots never receive a mobile push. +- Presence-busy / in-call recipients are not pushed; everyone else + (online, offline, away, missing) receives one. +``` + +- [ ] **Step 3: Commit** + +```bash +git add docs/client-api.md +git commit -m "docs(client-api): document new notification-worker routing rules" +``` + +--- + +## Task 16: Repository-wide validation + +- [ ] **Step 1: Verify build** + +Run: `go build ./...` +Expected: PASS. + +- [ ] **Step 2: Verify lint clean** + +Run: `make lint` +Expected: PASS. + +- [ ] **Step 3: Verify unit tests pass with race** + +Run: `make test` +Expected: PASS. + +- [ ] **Step 4: Verify integration tests** + +Run: `make test-integration SERVICE=notification-worker` +Expected: PASS. + +- [ ] **Step 5: Verify coverage ≥ 80%** + +Run: + +```bash +go test -race -coverprofile=cov.out ./notification-worker/... +go tool cover -func=cov.out | tail -1 +``` + +Expected: total coverage line ≥ 80%. If lower, add cases in `handler_test.go` (likely the under-covered paths are the hook error branch and the restricted-thread legacy-nil branch). + +- [ ] **Step 6: Verify SAST is clean** + +Run: `make sast` +Expected: PASS — no medium+ findings introduced by this change. + +- [ ] **Step 7: Commit any coverage fill-in tests separately** + +```bash +git add notification-worker/handler_test.go +git commit -m "test(notification-worker): cover hook-error and legacy nil parent-ts paths" +``` + +(Skip this commit if step 5 already passed without changes.) 
+ +--- + +## Spec-coverage Self-Review (run after Task 16) + +| Spec requirement | Covered by | +|---|---| +| `roomsubcache.Member` projection extension | Task 1 | +| `@here` handling | Task 2 | +| `PushNotification`/`PresenceSnapshot`/`SubscriptionUpdateWildcard` subjects | Task 3 | +| `PushNotificationEvent`/`PushNotificationData` payload | Task 4 | +| `PresenceSnapshotRequest`/`Reply`/`Presence` types | Task 4 | +| Routing predicate (DM/mention/large-room/bot) | Task 5 | +| Hook interface + no-op default | Task 6 | +| `PresenceSource` interface, no-op, bulk RPC, chunking, status→push, fail-open | Task 7 | +| Cached member lookup + single-flight + L1 LRU | Task 8 | +| Thread-follower lookup by `parentMessageId` | Task 9 | +| `EnsureThreadSubscriptionIndex` | Task 9 / wired in Task 12 | +| Async mobile emitter + dedup `Nats-Msg-Id` `{messageId}-{account}` | Task 10 | +| Stage 1 exclusions: sender, mute, restricted, thread-non-follower | Task 11 | +| Stage 2 hook veto, fail-open on error | Task 11 | +| Stage 3 routing predicate call | Task 11 | +| Stage 4 presence snapshot + per-account `shouldPush` | Task 11 | +| Push payload Sender from member record | Task 11 | +| Restricted check uses parent `CreatedAt` for thread-only replies | Task 11 | +| Legacy nil parent-ts → treat as no access | Task 11 | +| Valkey wiring + new env vars | Task 12 | +| Raw JetStream for async publish | Task 12 | +| Eager cache invalidation on `subscription.update` | Task 12 | +| Async drain on graceful shutdown | Task 12 | +| Docker compose updates | Task 14 | +| `docs/client-api.md` update | Task 15 | + +**Deliberately out of scope (per spec):** highlight keywords, threadsubcache, encrypted-room pushes, PII audit, per-user rate limiting, presence-service implementation, push-service implementation. These are all logged as Future work in the spec. + +**Known approximations:** + +- `deriveRoomType` infers DM vs channel from member count (≤2 → DM). 
The + spec assumed `Member` would carry room type per-room; that field is not + yet in the projection. The approximation is safe: it sends pushes to + DMs (correct) and to small channels (correct), and only mis-routes a + hypothetical 2-member channel as a DM — at small scale, no push + difference. Threading room type through to the cache via a separate + `room:{id}:type` Valkey entry is a follow-up. +- `Title` on `PushNotificationEvent` is left empty in v1 (room name lives + off the message). The push service can render it from `Data.RoomID` or + the spec's follow-up `roommetacache` integration can fill it in. This + is captured as Future work in the spec; v1 sends what we have. + +--- + +**Plan complete and saved to `docs/superpowers/plans/2026-05-27-notification-worker-cache-and-mobile.md`.** + +Two execution options: + +1. **Subagent-Driven (recommended)** — I dispatch a fresh subagent per task, review between tasks, fast iteration. +2. **Inline Execution** — Execute tasks in this session using executing-plans, batch execution with checkpoints. + +Which approach? diff --git a/message-gatekeeper/deploy/docker-compose.yml b/message-gatekeeper/deploy/docker-compose.yml index 56394d57f..5a5a64d8f 100644 --- a/message-gatekeeper/deploy/docker-compose.yml +++ b/message-gatekeeper/deploy/docker-compose.yml @@ -12,6 +12,12 @@ services: - MONGO_URI=mongodb://mongodb:27017 - MONGO_DB=chat - CHAT_BASE_URL=http://localhost:3000 + # User cache fronts users-collection lookups so the per-message sender + # display-name composition stays off the hot Mongo path. In-process LRU + # (the pkg/userstore.Cache shared with broadcast-worker + message-worker); + # entries are tiny so per-pod memory stays in the low MB. 
+ - USER_CACHE_SIZE=10000 + - USER_CACHE_TTL=5m - BOOTSTRAP_STREAMS=true volumes: - ../../docker-local/backend.creds:/etc/nats/backend.creds:ro diff --git a/message-gatekeeper/handler.go b/message-gatekeeper/handler.go index 0fdeace5f..ffe28b0b5 100644 --- a/message-gatekeeper/handler.go +++ b/message-gatekeeper/handler.go @@ -13,6 +13,7 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" + "github.com/hmchangw/chat/pkg/displayfmt" "github.com/hmchangw/chat/pkg/errcode" "github.com/hmchangw/chat/pkg/errcode/errnats" "github.com/hmchangw/chat/pkg/idgen" @@ -30,10 +31,17 @@ type replyFunc func(ctx context.Context, msg *nats.Msg) error // publishFunc is the function signature for publishing to JetStream. type publishFunc func(ctx context.Context, msg *nats.Msg, opts ...jetstream.PublishOpt) (*jetstream.PubAck, error) +// UserGetter is the narrow user-record surface gatekeeper needs for sender +// display-name resolution. *userstore.Cache satisfies this; tests stub it. +type UserGetter interface { + FindUserByID(ctx context.Context, id string) (*model.User, error) +} + // Handler processes messages from the MESSAGES stream and validates them // before publishing to MESSAGES_CANONICAL. type Handler struct { store Store + users UserGetter publish publishFunc reply replyFunc siteID string @@ -42,9 +50,12 @@ type Handler struct { } // NewHandler constructs a new Handler with the given dependencies. -func NewHandler(store Store, publish publishFunc, reply replyFunc, siteID string, parentFetcher ParentMessageFetcher, largeRoomThreshold int) *Handler { +// users may be nil; when nil, the sender user lookup is skipped and +// Message.UserDisplayName is set to the sender's account. 
+func NewHandler(store Store, users UserGetter, publish publishFunc, reply replyFunc, siteID string, parentFetcher ParentMessageFetcher, largeRoomThreshold int) *Handler { return &Handler{ store: store, + users: users, publish: publish, reply: reply, siteID: siteID, @@ -168,6 +179,10 @@ func (h *Handler) processMessage(ctx context.Context, account, roomID, siteID st return nil, errcode.BadRequest(fmt.Sprintf("invalid requestId %q: must be a hyphenated UUID", req.RequestID)) } + // Payload requestId is the canonical source for X-Request-ID — upstream publishers may + // or may not set the NATS header, so overwrite ctx unconditionally before any downstream publish. + ctx = natsutil.WithRequestID(ctx, req.RequestID) + // Validate ID is a valid 20-char base62 message ID if !idgen.IsValidMessageID(req.ID) { return nil, errcode.BadRequest(fmt.Sprintf("invalid message ID %q: must be a 20-char base62 string", req.ID)) @@ -244,11 +259,29 @@ func (h *Handler) processMessage(ctx context.Context, account, roomID, siteID st return nil, err } + // Compose the sender's render-ready display name once at write time so every + // downstream consumer (notification-worker, future search-sync-worker) reads + // from the canonical message instead of doing its own user lookup. The lookup + // is best-effort — on miss/error displayName keeps its initial value, the + // sender's account (sub.User.Account); message validation already passed + // the sender check so missing display data does not warrant blocking the post. 
+ displayName := sub.User.Account + if h.users != nil { + u, uerr := h.users.FindUserByID(ctx, sub.User.ID) + if uerr == nil && u != nil { + displayName = displayfmt.CombineWithFallback(u.EngName, u.ChineseName, sub.User.Account) + } else if uerr != nil { + slog.Warn("sender user-meta lookup failed, display name falls back to account", + "error", uerr, "userId", sub.User.ID, "account", sub.User.Account, "messageId", req.ID) + } + } + msg := model.Message{ ID: req.ID, RoomID: roomID, UserID: sub.User.ID, UserAccount: sub.User.Account, + UserDisplayName: displayName, Content: req.Content, CreatedAt: now, ThreadParentMessageID: req.ThreadParentMessageID, diff --git a/message-gatekeeper/handler_test.go b/message-gatekeeper/handler_test.go index e8dd0496b..31b2c47c9 100644 --- a/message-gatekeeper/handler_test.go +++ b/message-gatekeeper/handler_test.go @@ -757,7 +757,7 @@ func TestHandler_processMessage_RejectsInvalidThreadParentMessageID(t *testing.T return &jetstream.PubAck{}, nil } reply := func(ctx context.Context, msg *nats.Msg) error { return nil } - h := NewHandler(store, pub, reply, "site1", nil, 500) + h := NewHandler(store, nil, pub, reply, "site1", nil, 500) parentTs := int64(1000) req := model.SendMessageRequest{ @@ -787,15 +787,122 @@ func TestHandler_processMessage_PropagatesRequestIDOnCanonicalPublish(t *testing } reply := func(ctx context.Context, msg *nats.Msg) error { return nil } - h := NewHandler(store, pub, reply, "site1", nil, 500) + h := NewHandler(store, nil, pub, reply, "site1", nil, 500) - ctx := natsutil.WithRequestID(context.Background(), "req-mg-test-id") - req := model.SendMessageRequest{ID: idgen.GenerateMessageID(), Content: "hello", RequestID: "01970a4f-8c2d-7c9a-abcd-e0123456789f"} + // The JSON-payload requestId is the canonical source — it wins over any + // header-derived value already in ctx. Seed ctx with a stale "header" value + // to prove the bridge overwrites it with the payload value. 
+ ctx := natsutil.WithRequestID(context.Background(), "stale-header-id") + const payloadReqID = "01970a4f-8c2d-7c9a-abcd-e0123456789f" + req := model.SendMessageRequest{ID: idgen.GenerateMessageID(), Content: "hello", RequestID: payloadReqID} _, err := h.processMessage(ctx, "alice", "room-1", "site1", &req) require.NoError(t, err) - require.NotNil(t, capturedHeader, "publish must propagate header from ctx") - assert.Equal(t, "req-mg-test-id", capturedHeader.Get(natsutil.RequestIDHeader)) + require.NotNil(t, capturedHeader, "publish must carry X-Request-ID header") + assert.Equal(t, payloadReqID, capturedHeader.Get(natsutil.RequestIDHeader), + "payload requestId must win over the value already in ctx") +} + +// Inbound MESSAGES stream messages from non-Go clients (and from loadgen) may +// not set X-Request-ID in the NATS header. The bridge inside processMessage +// pulls the requestId from the JSON payload into ctx unconditionally, so the +// canonical publish carries it downstream regardless of inbound header state. +func TestHandler_processMessage_BridgesPayloadRequestIDWhenCtxHasNone(t *testing.T) { + ctrl := gomock.NewController(t) + store := NewMockStore(ctrl) + store.EXPECT().GetSubscription(gomock.Any(), "alice", "room-1"). + Return(&model.Subscription{User: model.SubscriptionUser{ID: "u-alice", Account: "alice"}}, nil) + store.EXPECT().GetRoomMeta(gomock.Any(), "room-1"). 
+ Return(roommetacache.Meta{ID: "room-1", UserCount: 1}, nil) + + var capturedHeader nats.Header + pub := func(ctx context.Context, msg *nats.Msg, opts ...jetstream.PublishOpt) (*jetstream.PubAck, error) { + capturedHeader = msg.Header + return &jetstream.PubAck{}, nil + } + reply := func(ctx context.Context, msg *nats.Msg) error { return nil } + + h := NewHandler(store, nil, pub, reply, "site1", nil, 500) + + const payloadReqID = "01970a4f-8c2d-7c9a-abcd-e0123456789f" + req := model.SendMessageRequest{ID: idgen.GenerateMessageID(), Content: "hello", RequestID: payloadReqID} + + // ctx has no request ID — simulates an inbound MESSAGES message with no X-Request-ID header. + _, err := h.processMessage(context.Background(), "alice", "room-1", "site1", &req) + require.NoError(t, err) + require.NotNil(t, capturedHeader, "publish must carry X-Request-ID header") + assert.Equal(t, payloadReqID, capturedHeader.Get(natsutil.RequestIDHeader)) +} + +// stubUserGetter is a minimal UserGetter for sender-display-name tests. +type stubUserGetter struct { + users map[string]*model.User + err error +} + +func (s *stubUserGetter) FindUserByID(_ context.Context, id string) (*model.User, error) { + if s.err != nil { + return nil, s.err + } + return s.users[id], nil +} + +func TestHandler_processMessage_PopulatesUserDisplayName(t *testing.T) { + ctrl := gomock.NewController(t) + store := NewMockStore(ctrl) + store.EXPECT().GetSubscription(gomock.Any(), "alice", "room-1"). + Return(&model.Subscription{User: model.SubscriptionUser{ID: "u-alice", Account: "alice"}}, nil) + store.EXPECT().GetRoomMeta(gomock.Any(), "room-1"). 
+ Return(roommetacache.Meta{ID: "room-1", UserCount: 1}, nil) + + users := &stubUserGetter{users: map[string]*model.User{ + "u-alice": {ID: "u-alice", Account: "alice", EngName: "Alice Wang", ChineseName: "愛麗絲"}, + }} + + var captured publishedMsg + pub := func(_ context.Context, msg *nats.Msg, _ ...jetstream.PublishOpt) (*jetstream.PubAck, error) { + captured = publishedMsg{subject: msg.Subject, data: msg.Data} + return &jetstream.PubAck{}, nil + } + reply := func(_ context.Context, _ *nats.Msg) error { return nil } + h := NewHandler(store, users, pub, reply, "site1", nil, 500) + + req := model.SendMessageRequest{ID: idgen.GenerateMessageID(), Content: "hi", RequestID: "01970a4f-8c2d-7c9a-abcd-e0123456789f"} + _, err := h.processMessage(context.Background(), "alice", "room-1", "site1", &req) + require.NoError(t, err) + + var evt model.MessageEvent + require.NoError(t, json.Unmarshal(captured.data, &evt)) + assert.Equal(t, "Alice Wang 愛麗絲", evt.Message.UserDisplayName, + "gatekeeper must populate UserDisplayName via model.DisplayName(engName, chineseName, account)") +} + +func TestHandler_processMessage_FallsBackToAccountWhenUserLookupFails(t *testing.T) { + ctrl := gomock.NewController(t) + store := NewMockStore(ctrl) + store.EXPECT().GetSubscription(gomock.Any(), "alice", "room-1"). + Return(&model.Subscription{User: model.SubscriptionUser{ID: "u-alice", Account: "alice"}}, nil) + store.EXPECT().GetRoomMeta(gomock.Any(), "room-1"). 
+ Return(roommetacache.Meta{ID: "room-1", UserCount: 1}, nil) + + users := &stubUserGetter{err: errors.New("mongo timeout")} + + var captured publishedMsg + pub := func(_ context.Context, msg *nats.Msg, _ ...jetstream.PublishOpt) (*jetstream.PubAck, error) { + captured = publishedMsg{subject: msg.Subject, data: msg.Data} + return &jetstream.PubAck{}, nil + } + reply := func(_ context.Context, _ *nats.Msg) error { return nil } + h := NewHandler(store, users, pub, reply, "site1", nil, 500) + + req := model.SendMessageRequest{ID: idgen.GenerateMessageID(), Content: "hi", RequestID: "01970a4f-8c2d-7c9a-abcd-e0123456789f"} + _, err := h.processMessage(context.Background(), "alice", "room-1", "site1", &req) + require.NoError(t, err, "user-meta lookup failure must not block message publish") + + var evt model.MessageEvent + require.NoError(t, json.Unmarshal(captured.data, &evt)) + assert.Equal(t, "alice", evt.Message.UserDisplayName, + "on lookup error, fall back to account so downstream still gets a usable display name") } func TestHandler_ProcessMessage_WithQuote(t *testing.T) { @@ -1221,7 +1328,7 @@ func TestHandler_sendReply(t *testing.T) { *captured = append(*captured, msg) return nil } - return NewHandler(nil, nil, reply, "site-a", nil, 500) + return NewHandler(nil, nil, nil, reply, "site-a", nil, 500) } mk := func(requestID string) *model.SendMessageRequest { @@ -1289,7 +1396,7 @@ func TestHandleJetStreamMsg_MalformedBody_Acks(t *testing.T) { captured = append(captured, m) return nil } - h := NewHandler(nil, nil, reply, "site-A", nil, 500) + h := NewHandler(nil, nil, nil, reply, "site-A", nil, 500) msg := &fakeJSMsg{ subject: "chat.user.alice.room.r1.site-A.msg.send", @@ -1304,7 +1411,7 @@ func TestHandleJetStreamMsg_MalformedBody_Acks(t *testing.T) { // Invalid subject Acks (not retryable) and sends a best-effort reply. 
func TestHandleJetStreamMsg_InvalidSubject_Acks(t *testing.T) { - h := NewHandler(nil, nil, func(context.Context, *nats.Msg) error { return nil }, "site-A", nil, 500) + h := NewHandler(nil, nil, nil, func(context.Context, *nats.Msg) error { return nil }, "site-A", nil, 500) msg := &fakeJSMsg{ subject: "chat.garbage", data: []byte(`{}`), diff --git a/message-gatekeeper/integration_test.go b/message-gatekeeper/integration_test.go new file mode 100644 index 000000000..8ea94145b --- /dev/null +++ b/message-gatekeeper/integration_test.go @@ -0,0 +1,100 @@ +//go:build integration + +package main + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/mongo" + + "github.com/hmchangw/chat/pkg/idgen" + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/testutil" + "github.com/hmchangw/chat/pkg/userstore" +) + +// TestProcessMessage_PopulatesDisplayName_Integration walks the wiring that +// gatekeeper's main.go does in prod against real Mongo: +// +// users coll → userstore.NewMongoStore → userstore.NewCache → Handler +// +// and asserts the published canonical event has UserDisplayName composed via +// displayfmt.CombineWithFallback. +// +// Scope is deliberately narrow: it proves the Mongo wiring (collection name, +// BSON field tags, projection drift) and the end-to-end composition. The +// composition *rule itself* (eng+zh dedupe, account fallback) is exhaustively +// covered by pkg/displayfmt/combine_test.go and Handler unit tests; re-testing +// every variant here would just duplicate that coverage with slow tests. 
+func TestProcessMessage_PopulatesDisplayName_Integration(t *testing.T) { + db := testutil.MongoDB(t, "message_gatekeeper_test") + ctx := context.Background() + + user := model.User{ID: "u-alice", Account: "alice", EngName: "Alice Wang", ChineseName: "愛麗絲"} + seedUserAndSubscription(t, ctx, db, user, "r1") + + handler, getCaptured := buildHandlerWithCapture(t, db) + + req := model.SendMessageRequest{ + ID: idgen.GenerateMessageID(), + Content: "hello", + RequestID: "01970a4f-8c2d-7c9a-abcd-e0123456789f", + } + + _, perr := handler.processMessage(ctx, user.Account, "r1", "site-a", &req) + require.NoError(t, perr) + + captured := getCaptured() + require.NotNil(t, captured, "canonical event was never published") + var evt model.MessageEvent + require.NoError(t, json.Unmarshal(captured.Data, &evt)) + assert.Equal(t, "Alice Wang 愛麗絲", evt.Message.UserDisplayName, + "gatekeeper must compose display name via displayfmt.CombineWithFallback") + assert.Equal(t, user.Account, evt.Message.UserAccount) + assert.Equal(t, user.ID, evt.Message.UserID) +} + +// seedUserAndSubscription inserts the minimal docs gatekeeper needs to accept a +// message from u into room roomID: a users row (read by userstore.Cache for +// display-name composition), a subscription so GetSubscription returns it, and +// a room doc so room-meta lookup succeeds. 
+func seedUserAndSubscription(t *testing.T, ctx context.Context, db *mongo.Database, u model.User, roomID string) { + t.Helper() + _, err := db.Collection("users").InsertOne(ctx, u) + require.NoError(t, err) + _, err = db.Collection("subscriptions").InsertOne(ctx, model.Subscription{ + ID: "sub-" + u.ID, RoomID: roomID, SiteID: "site-a", + User: model.SubscriptionUser{ID: u.ID, Account: u.Account}, + }) + require.NoError(t, err) + _, err = db.Collection("rooms").InsertOne(ctx, model.Room{ + ID: roomID, Name: "general", Type: model.RoomTypeChannel, + SiteID: "site-a", UserCount: 1, + }) + require.NoError(t, err) +} + +// buildHandlerWithCapture mirrors main.go's gatekeeper wiring against the test +// Mongo and returns the Handler plus a getter that yields the canonical event +// published by processMessage (nil until the publish fires). +func buildHandlerWithCapture(t *testing.T, db *mongo.Database) (*Handler, func() *nats.Msg) { + t.Helper() + users, err := userstore.NewCache(userstore.NewMongoStore(db.Collection("users")), 100, time.Minute) + require.NoError(t, err) + + var captured *nats.Msg + pub := func(_ context.Context, msg *nats.Msg, _ ...jetstream.PublishOpt) (*jetstream.PubAck, error) { + captured = msg + return &jetstream.PubAck{}, nil + } + reply := func(_ context.Context, _ *nats.Msg) error { return nil } + return NewHandler(NewMongoStore(db), users, pub, reply, "site-a", nil, 500), + func() *nats.Msg { return captured } +} diff --git a/message-gatekeeper/main.go b/message-gatekeeper/main.go index 713793224..54b08a924 100644 --- a/message-gatekeeper/main.go +++ b/message-gatekeeper/main.go @@ -19,6 +19,7 @@ import ( "github.com/hmchangw/chat/pkg/otelutil" "github.com/hmchangw/chat/pkg/shutdown" "github.com/hmchangw/chat/pkg/stream" + "github.com/hmchangw/chat/pkg/userstore" ) type config struct { @@ -36,6 +37,8 @@ type config struct { SubCacheTTL time.Duration `env:"GATEKEEPER_SUB_CACHE_TTL" envDefault:"2m"` RoomMetaCacheSize int 
`env:"ROOM_META_CACHE_SIZE" envDefault:"10000"` RoomMetaCacheTTL time.Duration `env:"ROOM_META_CACHE_TTL" envDefault:"2m"` + UserCacheSize int `env:"USER_CACHE_SIZE" envDefault:"10000"` + UserCacheTTL time.Duration `env:"USER_CACHE_TTL" envDefault:"5m"` Consumer stream.ConsumerSettings `envPrefix:"CONSUMER_"` Bootstrap bootstrapConfig `envPrefix:"BOOTSTRAP_"` } @@ -86,9 +89,16 @@ func main() { slog.Error("init subscription cache failed", "error", err) os.Exit(1) } + users, err := userstore.NewCache(userstore.NewMongoStore(db.Collection("users")), + cfg.UserCacheSize, cfg.UserCacheTTL) + if err != nil { + slog.Error("init user meta cache failed", "error", err) + os.Exit(1) + } slog.Info("gatekeeper caches enabled", "sub_cache_size", cfg.SubCacheSize, "sub_cache_ttl", cfg.SubCacheTTL, "room_meta_cache_size", cfg.RoomMetaCacheSize, "room_meta_cache_ttl", cfg.RoomMetaCacheTTL, + "user_cache_size", cfg.UserCacheSize, "user_cache_ttl", cfg.UserCacheTTL, ) pub := func(ctx context.Context, msg *nats.Msg, opts ...jetstream.PublishOpt) (*jetstream.PubAck, error) { ack, err := js.PublishMsg(ctx, msg, opts...) 
@@ -104,7 +114,7 @@ func main() { return nil } parentFetcher := newHistoryParentFetcher(nc, cfg.ChatBaseURL) - handler := NewHandler(store, pub, reply, cfg.SiteID, parentFetcher, cfg.LargeRoomThreshold) + handler := NewHandler(store, users, pub, reply, cfg.SiteID, parentFetcher, cfg.LargeRoomThreshold) if err := bootstrapStreams(ctx, js, cfg.SiteID, cfg.Bootstrap.Enabled); err != nil { slog.Error("bootstrap streams failed", "error", err) diff --git a/message-gatekeeper/main_test.go b/message-gatekeeper/main_test.go new file mode 100644 index 000000000..937f8531a --- /dev/null +++ b/message-gatekeeper/main_test.go @@ -0,0 +1,11 @@ +//go:build integration + +package main + +import ( + "testing" + + "github.com/hmchangw/chat/pkg/testutil" +) + +func TestMain(m *testing.M) { testutil.RunTests(m) } diff --git a/message-worker/deploy/docker-compose.yml b/message-worker/deploy/docker-compose.yml index f8bf25deb..07b4b0f56 100644 --- a/message-worker/deploy/docker-compose.yml +++ b/message-worker/deploy/docker-compose.yml @@ -50,6 +50,10 @@ services: - CASSANDRA_KEYSPACE=chat - MONGO_URI=mongodb://mongodb:27017 - MONGO_DB=chat + # In-process user cache (pkg/userstore.Cache, shared with message-gatekeeper + # and broadcast-worker). Defaults: 10000 entries, 5m TTL. 
+ - USER_CACHE_SIZE=10000 + - USER_CACHE_TTL=5m - BOOTSTRAP_STREAMS=true - ATREST_ENABLED=true - VAULT_ADDR=http://vault:8200 diff --git a/message-worker/handler.go b/message-worker/handler.go index 8753d909d..c8e339e5d 100644 --- a/message-worker/handler.go +++ b/message-worker/handler.go @@ -171,6 +171,12 @@ func (h *Handler) handleFirstThreadReply(ctx context.Context, msg *model.Message if err := h.threadStore.InsertThreadSubscription(ctx, parentSub); err != nil { return fmt.Errorf("insert parent author thread subscription: %w", err) } + // Parent author joins the thread's replyAccounts set so they appear as a + // follower in notification-worker and history-service's "following" feed, + // even before they reply themselves. $addToSet dedups against the replier seed. + if err := h.threadStore.AddReplyAccounts(ctx, threadRoomID, []string{parentSender.Account}); err != nil { + return fmt.Errorf("add parent author to thread room replyAccounts: %w", err) + } // Outbox publish is gated on parentOwnerSite — if the parent user is missing // from userStore, we can't route the cross-site copy, but the local Insert // above is independent of that and still happens. @@ -265,7 +271,15 @@ func (h *Handler) handleSubsequentThreadReply(ctx context.Context, msg *model.Me return "", fmt.Errorf("get parent message sender: %w", err) } - if err := h.threadStore.UpdateThreadRoomLastMessage(ctx, existingRoom.ID, msg.ID, msg.UserAccount, now); err != nil { + // Update lastMsg pointer AND merge replier + parent author into replyAccounts in one write. + // Folding the parent-author $addToSet here (vs a separate AddReplyAccounts call) halves the + // per-reply Mongo round-trips and also covers the migration for thread_rooms created before + // the parent author was seeded. 
+ replyAccounts := []string{msg.UserAccount} + if parentFound { + replyAccounts = append(replyAccounts, parentSender.Account) + } + if err := h.threadStore.UpdateThreadRoomLastMessage(ctx, existingRoom.ID, msg.ID, replyAccounts, now); err != nil { return "", fmt.Errorf("update thread room last message: %w", err) } @@ -337,11 +351,14 @@ func (h *Handler) buildThreadSubscription(msg *model.Message, threadRoomID, user } // markThreadMentions flips hasMention=true on the thread subscription of every -// @account mentionee in msg (auto-creating the subscription if absent). The -// sender is excluded, and @all is ignored at the thread level. Subscription.SiteID -// is the room's site (eventSiteID); the mentionee's home site (Participant.SiteID) -// is used only for the cross-site outbox routing. +// @account mentionee in msg (auto-creating the subscription if absent), and +// also adds them to thread_rooms.replyAccounts so they appear as thread followers +// for notification fan-out and the "following threads" feed. The sender is +// excluded and @all is ignored at the thread level. Subscription.SiteID is the +// room's site (eventSiteID); the mentionee's home site (Participant.SiteID) is +// used only for the cross-site outbox routing. 
func (h *Handler) markThreadMentions(ctx context.Context, msg *model.Message, threadRoomID, eventSiteID string) error { + var mentionedAccounts []string for i := range msg.Mentions { p := &msg.Mentions[i] if p.Account == "all" { @@ -358,6 +375,12 @@ func (h *Handler) markThreadMentions(ctx context.Context, msg *model.Message, th if err := h.publishThreadSubOutboxIfRemote(ctx, sub, p.SiteID, msg.ID); err != nil { return fmt.Errorf("publish thread mention outbox for user %s: %w", p.UserID, err) } + mentionedAccounts = append(mentionedAccounts, p.Account) + } + if len(mentionedAccounts) > 0 { + if err := h.threadStore.AddReplyAccounts(ctx, threadRoomID, mentionedAccounts); err != nil { + return fmt.Errorf("add mentioned accounts to thread room replyAccounts: %w", err) + } } return nil } diff --git a/message-worker/handler_test.go b/message-worker/handler_test.go index 48f9bad04..90e95e1e7 100644 --- a/message-worker/handler_test.go +++ b/message-worker/handler_test.go @@ -224,7 +224,7 @@ func TestHandler_ProcessMessage(t *testing.T) { Return(&model.User{ID: "u-parent", Account: "parent-user", SiteID: "site-a"}, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-1", "msg-2", "alice", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-1", "msg-2", gomock.Any(), now).Return(nil) // SaveThreadMessage receives the resolved threadRoomID. 
store.EXPECT().SaveThreadMessage(gomock.Any(), &threadMsg, &expectedSender, "site-a", "tr-1").Return(nil) }, @@ -244,7 +244,7 @@ func TestHandler_ProcessMessage(t *testing.T) { Return(&model.User{ID: "u-parent", Account: "parent-user", SiteID: "site-a"}, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-1", "msg-2", "alice", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-1", "msg-2", gomock.Any(), now).Return(nil) store.EXPECT().SaveThreadMessage(gomock.Any(), &threadMsg, &expectedSender, "site-a", "tr-1"). Return(errors.New("cassandra: write timeout")) }, @@ -428,6 +428,7 @@ func TestHandler_ProcessMessage(t *testing.T) { mockStore := NewMockStore(ctrl) mockUserStore := NewMockUserStore(ctrl) mockThreadStore := NewMockThreadStore(ctrl) + mockThreadStore.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() tt.setupMocks(mockStore, mockUserStore, mockThreadStore) h := NewHandler(mockStore, mockUserStore, mockThreadStore, "site-a", func(_ context.Context, _ string, _ []byte, _ string) error { @@ -617,7 +618,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { assert.Nil(t, sub.LastSeenAt, "replier's LastSeenAt should be nil on init") return nil }) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now). + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now). Return(nil) }, extraUserStoreSetup: func(us *MockUserStore) { @@ -649,7 +650,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { assert.Equal(t, "u-parent", sub.UserID) return nil }) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "parent-user", now). 
+ ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now). Return(nil) }, extraUserStoreSetup: func(us *MockUserStore) { @@ -673,7 +674,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { assert.Equal(t, "u-replier", sub.UserID) return nil }) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now). + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now). Return(nil) }, }, @@ -757,7 +758,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { Return(parentSender, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now). + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now). Return(errors.New("mongo: write error")) }, extraUserStoreSetup: func(us *MockUserStore) { @@ -856,7 +857,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { store.EXPECT().GetMessageSender(gomock.Any(), "msg-parent").Return(parentSender, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now).Return(nil) store.EXPECT().UpdateParentMessageThreadRoomID( gomock.Any(), "msg-parent", "r1", now.Add(-5*time.Minute), @@ -887,7 +888,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { store.EXPECT().GetMessageSender(gomock.Any(), "msg-parent").Return(parentSender, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), 
gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now).Return(nil) store.EXPECT().UpdateParentMessageThreadRoomID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(errors.New("cassandra: write timeout")) }, @@ -920,7 +921,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { assert.Equal(t, "u-replier", sub.UserID) return nil }) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now).Return(nil) // UpdateParentMessageThreadRoomID must NOT be called — parent doesn't exist // FindUserByID also not called — short-circuited by errMessageNotFound branch }, @@ -969,7 +970,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { assert.Equal(t, "u-replier", sub.UserID) return nil }) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now). + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now). 
Return(nil) }, extraUserStoreSetup: func(us *MockUserStore) { @@ -1017,6 +1018,7 @@ func TestHandler_HandleThreadRoomAndSubscriptions(t *testing.T) { ctrl := gomock.NewController(t) mockStore := NewMockStore(ctrl) mockThreadStore := NewMockThreadStore(ctrl) + mockThreadStore.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockUserStore := NewMockUserStore(ctrl) tt.setupMocks(mockStore, mockThreadStore) if tt.extraUserStoreSetup != nil { @@ -1212,6 +1214,7 @@ func TestHandler_FirstReply_OutboxPublishes(t *testing.T) { store := NewMockStore(ctrl) us := NewMockUserStore(ctrl) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() store.EXPECT().GetMessageSender(gomock.Any(), "msg-parent").Return(parentSender, nil) us.EXPECT().FindUserByID(gomock.Any(), "u-parent").Return(tt.parentUser, nil) @@ -1255,6 +1258,7 @@ func TestHandler_FirstReply_OutboxPublishError_NAKs(t *testing.T) { store := NewMockStore(ctrl) us := NewMockUserStore(ctrl) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() store.EXPECT().GetMessageSender(gomock.Any(), "msg-parent"). Return(&cassParticipant{ID: "u-parent", Account: "parent-user"}, nil) @@ -1284,6 +1288,7 @@ func TestHandler_FirstReply_ReplierOutboxPublishError_NAKs(t *testing.T) { store := NewMockStore(ctrl) us := NewMockUserStore(ctrl) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() // Parent at the local site → no parent publish. store.EXPECT().GetMessageSender(gomock.Any(), "msg-parent"). 
@@ -1360,6 +1365,7 @@ func TestHandler_SubsequentReply_OutboxPublishes(t *testing.T) { store := NewMockStore(ctrl) us := NewMockUserStore(ctrl) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() ts.EXPECT().GetThreadRoomByParentMessageID(gomock.Any(), "msg-parent"). Return(&model.ThreadRoom{ID: "tr-existing"}, nil) @@ -1367,7 +1373,7 @@ func TestHandler_SubsequentReply_OutboxPublishes(t *testing.T) { us.EXPECT().FindUserByID(gomock.Any(), "u-parent").Return(tt.parentUser, nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) ts.EXPECT().UpsertThreadSubscription(gomock.Any(), gomock.Any()).Return(nil) - ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", "replier", now).Return(nil) + ts.EXPECT().UpdateThreadRoomLastMessage(gomock.Any(), "tr-existing", "msg-reply", gomock.Any(), now).Return(nil) var publishedDests []string h := NewHandler(store, us, ts, "site-a", func(_ context.Context, _ string, data []byte, _ string) error { @@ -1408,6 +1414,7 @@ func TestHandler_SubsequentReply_OutboxPublishError_NAKs(t *testing.T) { store := NewMockStore(ctrl) us := NewMockUserStore(ctrl) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() ts.EXPECT().GetThreadRoomByParentMessageID(gomock.Any(), "msg-parent"). 
Return(&model.ThreadRoom{ID: "tr-1"}, nil) @@ -1487,6 +1494,7 @@ func TestHandler_MarkThreadMentions_OutboxPublishes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() expectedMarks := 0 for _, p := range tt.mentionees { @@ -1537,6 +1545,7 @@ func TestHandler_MarkThreadMentions_OutboxPublishError_NAKs(t *testing.T) { now := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) ctrl := gomock.NewController(t) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() ts.EXPECT().MarkThreadSubscriptionMention(gomock.Any(), gomock.Any()).Return(nil) boom := errors.New("publish boom") @@ -1557,6 +1566,7 @@ func TestHandler_MarkThreadMentions_HasMentionInPayload(t *testing.T) { now := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) ctrl := gomock.NewController(t) ts := NewMockThreadStore(ctrl) + ts.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() ts.EXPECT().MarkThreadSubscriptionMention(gomock.Any(), gomock.Any()).Return(nil) var captured []byte @@ -1653,6 +1663,7 @@ func TestHandler_HandleJetStreamMsg(t *testing.T) { mockStore := NewMockStore(ctrl) mockUserStore := NewMockUserStore(ctrl) mockThreadStore := NewMockThreadStore(ctrl) + mockThreadStore.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() tt.setupMocks(mockStore, mockUserStore, mockThreadStore) h := NewHandler(mockStore, mockUserStore, mockThreadStore, "site-a", func(_ context.Context, _ string, _ []byte, _ string) error { @@ -1711,6 +1722,7 @@ func TestHandler_ProcessMessage_Quote(t *testing.T) { store := NewMockStore(ctrl) userStore := NewMockUserStore(ctrl) threadStore := NewMockThreadStore(ctrl) + threadStore.EXPECT().AddReplyAccounts(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() 
userStore.EXPECT().FindUserByID(gomock.Any(), "u-1").Return(user, nil) store.EXPECT(). diff --git a/message-worker/integration_test.go b/message-worker/integration_test.go index 5bca2e5ee..f67ebd81d 100644 --- a/message-worker/integration_test.go +++ b/message-worker/integration_test.go @@ -668,6 +668,17 @@ func TestHandler_Integration_ThreadReplyWithMention(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(3), count) }) + + t.Run("thread_rooms.replyAccounts contains replier + parent author + mentioned user", func(t *testing.T) { + var got model.ThreadRoom + err := db.Collection("thread_rooms").FindOne(ctx, bson.M{ + "parentMessageId": "msg-parent-mention", + }).Decode(&got) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"replier", "parent-user", "bob"}, got.ReplyAccounts, + "replyAccounts should match thread_subscriptions members so notification-worker "+ + "can use this single field as the follower set") + }) } func TestThreadStoreMongo_CreateThreadRoom(t *testing.T) { @@ -901,7 +912,7 @@ func TestThreadStoreMongo_UpdateThreadRoomLastMessage(t *testing.T) { require.NoError(t, store.CreateThreadRoom(ctx, room)) later := now.Add(10 * time.Minute) - err := store.UpdateThreadRoomLastMessage(ctx, "tr-update", "msg-5", "bob", later) + err := store.UpdateThreadRoomLastMessage(ctx, "tr-update", "msg-5", []string{"bob"}, later) require.NoError(t, err) got, err := store.GetThreadRoomByParentMessageID(ctx, "msg-parent-update") diff --git a/message-worker/main.go b/message-worker/main.go index ce0e4d46c..279afb687 100644 --- a/message-worker/main.go +++ b/message-worker/main.go @@ -40,6 +40,8 @@ type config struct { MongoDB string `env:"MONGO_DB" envDefault:"chat"` MongoUsername string `env:"MONGO_USERNAME" envDefault:""` MongoPassword string `env:"MONGO_PASSWORD" envDefault:""` + UserCacheSize int `env:"USER_CACHE_SIZE" envDefault:"10000"` + UserCacheTTL time.Duration `env:"USER_CACHE_TTL" envDefault:"5m"` Consumer stream.ConsumerSettings 
`envPrefix:"CONSUMER_"` Bootstrap bootstrapConfig `envPrefix:"BOOTSTRAP_"` Atrest atrest.Config @@ -100,7 +102,13 @@ func main() { os.Exit(1) } db := mongoClient.Database(cfg.MongoDB) - us := userstore.NewMongoStore(db.Collection("users")) + us, err := userstore.NewCache(userstore.NewMongoStore(db.Collection("users")), + cfg.UserCacheSize, cfg.UserCacheTTL) + if err != nil { + slog.Error("init user cache failed", "error", err) + os.Exit(1) + } + slog.Info("user-cache enabled", "size", cfg.UserCacheSize, "ttl", cfg.UserCacheTTL) var ( cipher atrest.Cipher diff --git a/message-worker/mock_store_test.go b/message-worker/mock_store_test.go index 3a11fb1a8..997cd228d 100644 --- a/message-worker/mock_store_test.go +++ b/message-worker/mock_store_test.go @@ -181,17 +181,31 @@ func (mr *MockThreadStoreMockRecorder) MarkThreadSubscriptionMention(ctx, sub an } // UpdateThreadRoomLastMessage mocks base method. -func (m *MockThreadStore) UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID, replierAccount string, lastMsgAt time.Time) error { +func (m *MockThreadStore) UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID string, replyAccounts []string, lastMsgAt time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateThreadRoomLastMessage", ctx, threadRoomID, lastMsgID, replierAccount, lastMsgAt) + ret := m.ctrl.Call(m, "UpdateThreadRoomLastMessage", ctx, threadRoomID, lastMsgID, replyAccounts, lastMsgAt) ret0, _ := ret[0].(error) return ret0 } // UpdateThreadRoomLastMessage indicates an expected call of UpdateThreadRoomLastMessage. 
-func (mr *MockThreadStoreMockRecorder) UpdateThreadRoomLastMessage(ctx, threadRoomID, lastMsgID, replierAccount, lastMsgAt any) *gomock.Call { +func (mr *MockThreadStoreMockRecorder) UpdateThreadRoomLastMessage(ctx, threadRoomID, lastMsgID, replyAccounts, lastMsgAt any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateThreadRoomLastMessage", reflect.TypeOf((*MockThreadStore)(nil).UpdateThreadRoomLastMessage), ctx, threadRoomID, lastMsgID, replierAccount, lastMsgAt) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateThreadRoomLastMessage", reflect.TypeOf((*MockThreadStore)(nil).UpdateThreadRoomLastMessage), ctx, threadRoomID, lastMsgID, replyAccounts, lastMsgAt) +} + +// AddReplyAccounts mocks base method. +func (m *MockThreadStore) AddReplyAccounts(ctx context.Context, threadRoomID string, accounts []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddReplyAccounts", ctx, threadRoomID, accounts) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddReplyAccounts indicates an expected call of AddReplyAccounts. +func (mr *MockThreadStoreMockRecorder) AddReplyAccounts(ctx, threadRoomID, accounts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddReplyAccounts", reflect.TypeOf((*MockThreadStore)(nil).AddReplyAccounts), ctx, threadRoomID, accounts) } // UpsertThreadSubscription mocks base method. 
diff --git a/message-worker/store.go b/message-worker/store.go index 5ec2b6176..c41f69988 100644 --- a/message-worker/store.go +++ b/message-worker/store.go @@ -25,5 +25,12 @@ type ThreadStore interface { InsertThreadSubscription(ctx context.Context, sub *model.ThreadSubscription) error UpsertThreadSubscription(ctx context.Context, sub *model.ThreadSubscription) error MarkThreadSubscriptionMention(ctx context.Context, sub *model.ThreadSubscription) error - UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID, replierAccount string, lastMsgAt time.Time) error + // UpdateThreadRoomLastMessage bumps the last-message pointer and $addToSet-merges + // the supplied accounts (replier + parent author on the subsequent-reply path) into + // replyAccounts in one write. + UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID string, replyAccounts []string, lastMsgAt time.Time) error + // AddReplyAccounts $addToSet-merges accounts into thread_rooms.replyAccounts. + // Used by paths that don't already update lastMsg (first-reply parent author, + // mention-only subscribers) so the field mirrors thread_subscriptions membership. 
+ AddReplyAccounts(ctx context.Context, threadRoomID string, accounts []string) error } diff --git a/message-worker/store_mongo.go b/message-worker/store_mongo.go index 516b6b2c1..784b183da 100644 --- a/message-worker/store_mongo.go +++ b/message-worker/store_mongo.go @@ -129,17 +129,32 @@ func (s *threadStoreMongo) MarkThreadSubscriptionMention(ctx context.Context, su return nil } -func (s *threadStoreMongo) UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID, replierAccount string, lastMsgAt time.Time) error { - _, err := s.threadRooms.UpdateOne(ctx, bson.M{"_id": threadRoomID}, bson.M{ +func (s *threadStoreMongo) UpdateThreadRoomLastMessage(ctx context.Context, threadRoomID, lastMsgID string, replyAccounts []string, lastMsgAt time.Time) error { + update := bson.M{ "$set": bson.M{ "lastMsgAt": lastMsgAt, "lastMsgId": lastMsgID, "updatedAt": lastMsgAt, }, - "$addToSet": bson.M{"replyAccounts": replierAccount}, + } + if len(replyAccounts) > 0 { + update["$addToSet"] = bson.M{"replyAccounts": bson.M{"$each": replyAccounts}} + } + if _, err := s.threadRooms.UpdateOne(ctx, bson.M{"_id": threadRoomID}, update); err != nil { + return fmt.Errorf("update thread room last message: %w", err) + } + return nil +} + +func (s *threadStoreMongo) AddReplyAccounts(ctx context.Context, threadRoomID string, accounts []string) error { + if len(accounts) == 0 { + return nil + } + _, err := s.threadRooms.UpdateOne(ctx, bson.M{"_id": threadRoomID}, bson.M{ + "$addToSet": bson.M{"replyAccounts": bson.M{"$each": accounts}}, }) if err != nil { - return fmt.Errorf("update thread room last message: %w", err) + return fmt.Errorf("add reply accounts to thread room %s: %w", threadRoomID, err) } return nil } diff --git a/notification-worker/bootstrap.go b/notification-worker/bootstrap.go index b63ba7f56..3309db5fa 100644 --- a/notification-worker/bootstrap.go +++ b/notification-worker/bootstrap.go @@ -11,35 +11,19 @@ import ( "github.com/hmchangw/chat/pkg/stream" ) -// 
bootstrapConfig groups every field that is ONLY meaningful when the -// service is being stood up in dev or integration tests against a NATS -// instance where the streams it consumes do not yet exist. In production -// streams are pre-provisioned by ops/IaC and Bootstrap.Enabled must remain -// false; the service only creates its own durable consumer. +// bootstrapConfig gates stream creation to dev/integration; leave Enabled false in production. type bootstrapConfig struct { - // Enabled (BOOTSTRAP_STREAMS) toggles whether the service calls - // CreateOrUpdateStream at startup for the streams it consumes. - // Leave false in production. Enabled bool `env:"STREAMS" envDefault:"false"` } -// streamManager is the minimal JetStream surface bootstrapStreams depends on. -// Kept service-local so we don't pollute pkg/ with a multi-method type and so -// tests can inject a fake without mockgen. +// streamManager is the narrow JetStream surface bootstrapStreams uses, injected by tests. type streamManager interface { CreateOrUpdateStream(ctx context.Context, cfg jetstream.StreamConfig) (oteljetstream.Stream, error) Stream(ctx context.Context, name string) (oteljetstream.Stream, error) } -// bootstrapStreams handles the JetStream MESSAGES_CANONICAL stream this -// service uses. When enabled (dev/integration), it creates the stream via -// CreateOrUpdateStream. When disabled (production), it verifies the stream -// exists via Stream() and returns an error if it doesn't — fail-fast so a -// misprovisioned deploy surfaces at startup rather than at first publish. -// -// Ownership rule: this helper sets only the stream schema (Name + Subjects) -// from pkg/stream.MessagesCanonical. Federation config belongs to ops/IaC and -// is layered on in production. App code never sets it. +// bootstrapStreams creates MESSAGES_CANONICAL + PUSH_NOTIFICATIONS when enabled (dev/integration). +// When disabled it verifies MESSAGES_CANONICAL exists so a misconfigured deploy fails at startup. 
func bootstrapStreams(ctx context.Context, js streamManager, siteID string, enabled bool) error { canonicalCfg := stream.MessagesCanonical(siteID) if enabled { @@ -49,11 +33,20 @@ func bootstrapStreams(ctx context.Context, js streamManager, siteID string, enab }); err != nil { return fmt.Errorf("create MESSAGES_CANONICAL stream: %w", err) } + pushCfg := stream.PushNotifications(siteID) + if _, err := js.CreateOrUpdateStream(ctx, jetstream.StreamConfig{ + Name: pushCfg.Name, + Subjects: pushCfg.Subjects, + // S2 storage compression — transparent to publisher/consumer; ~2× ratio on JSON + // at near-zero CPU. Belt-and-braces alongside the publisher's gzip: gzip shrinks + // inter-replica wire bytes, S2 shrinks on-disk bytes after gzip overhead. + Compression: jetstream.S2Compression, + }); err != nil { + return fmt.Errorf("create PUSH_NOTIFICATIONS stream: %w", err) + } return nil } - // Production path: verify the stream exists. Fail fast if it doesn't — - // ops/IaC owns provisioning, and a missing stream means the deploy is - // broken before the first publish or consume. + // PUSH_NOTIFICATIONS is not verified here: its absence is non-fatal at startup and surfaces as an error on each (synchronous) publish. if _, err := js.Stream(ctx, canonicalCfg.Name); err != nil { return fmt.Errorf("verify MESSAGES_CANONICAL stream: %w", err) } diff --git a/notification-worker/bootstrap_test.go b/notification-worker/bootstrap_test.go index 236829529..03444daf8 100644 --- a/notification-worker/bootstrap_test.go +++ b/notification-worker/bootstrap_test.go @@ -19,7 +19,6 @@ type fakeStreamManager struct { failErr error // error to return when failing } -// Returns nil for the Stream value because bootstrapStreams discards it. 
func (f *fakeStreamManager) CreateOrUpdateStream(_ context.Context, cfg jetstream.StreamConfig) (oteljetstream.Stream, error) { //nolint:gocritic // hugeParam: cfg is passed by value to satisfy the streamManager interface if f.failOn != "" && cfg.Name == f.failOn { return nil, f.failErr @@ -58,10 +57,10 @@ func TestBootstrapStreams(t *testing.T) { wantErrSub: "verify MESSAGES_CANONICAL stream", }, { - name: "enabled - creates MESSAGES_CANONICAL", + name: "enabled - creates MESSAGES_CANONICAL and PUSH_NOTIFICATIONS", enabled: true, existing: map[string]bool{}, - wantCreated: []string{"MESSAGES_CANONICAL_test"}, + wantCreated: []string{"MESSAGES_CANONICAL_test", "PUSH_NOTIFICATIONS_test"}, }, { name: "enabled - wraps MESSAGES_CANONICAL creator error", @@ -71,6 +70,14 @@ func TestBootstrapStreams(t *testing.T) { failErr: errors.New("nats down"), wantErrSub: "create MESSAGES_CANONICAL stream", }, + { + name: "enabled - wraps PUSH_NOTIFICATIONS creator error", + enabled: true, + existing: map[string]bool{}, + failOn: "PUSH_NOTIFICATIONS_test", + failErr: errors.New("nats down"), + wantErrSub: "create PUSH_NOTIFICATIONS stream", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { diff --git a/notification-worker/deploy/docker-compose.yml b/notification-worker/deploy/docker-compose.yml index 492a9e808..8f594b95e 100644 --- a/notification-worker/deploy/docker-compose.yml +++ b/notification-worker/deploy/docker-compose.yml @@ -11,6 +11,22 @@ services: - SITE_ID=site-local - MONGO_URI=mongodb://mongodb:27017 - MONGO_DB=chat + - VALKEY_ADDRS=valkey:6379 + - ROOMSUBCACHE_TTL=5m + - LARGE_ROOM_THRESHOLD=500 + # Recipients per PUSH_NOTIFICATIONS event. Default 100; well under FCM's 500-token + # multicast cap and APNs HTTP/2 connection-frame ceiling. 
+ - PUSH_RECIPIENT_BATCH_SIZE=100 + # Title is resolved here from the rooms collection; sender display name is + # pre-composed by message-gatekeeper and propagated on the canonical message, + # so no users-collection lookup runs in this service. + - ROOM_META_CACHE_SIZE=10000 + - ROOM_META_CACHE_TTL=2m + - PRESENCE_RPC_ENABLED=false + # Must match the broker's max_payload. Emitter rejects gzipped batches + # larger than this before publish so the failure surfaces with a clear + # error instead of a NATS NACK. + - NATS_MAX_PAYLOAD_BYTES=262144 - BOOTSTRAP_STREAMS=true volumes: - ../../docker-local/backend.creds:/etc/nats/backend.creds:ro diff --git a/notification-worker/emit.go b/notification-worker/emit.go new file mode 100644 index 000000000..9deef1775 --- /dev/null +++ b/notification-worker/emit.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/natsutil" + "github.com/hmchangw/chat/pkg/subject" +) + +// publisher is the narrow sync-publish surface mobileEmitter needs. +// Sync semantics let the handler nak on publish failure; {messageId}-b{N} dedup +// protects against duplicate emission of batches that already succeeded. +type publisher interface { + PublishMsg(ctx context.Context, msg *nats.Msg) error +} + +// Emitter dispatches one batched push event per ~RecipientBatchSize recipients. 
+type Emitter interface { + Emit(ctx context.Context, evt model.PushNotificationEvent) error +} + +type mobileEmitter struct { + pub publisher + siteID string + maxPayloadBytes int +} + +func newMobileEmitter(pub publisher, siteID string, maxPayloadBytes int) *mobileEmitter { + return &mobileEmitter{pub: pub, siteID: siteID, maxPayloadBytes: maxPayloadBytes} +} + +func (e *mobileEmitter) Emit(ctx context.Context, evt model.PushNotificationEvent) error { //nolint:gocritic // hugeParam: spec requires value semantics for Emitter interface + data, err := json.Marshal(evt) + if err != nil { + return fmt.Errorf("marshal push batch %s: %w", evt.ID, err) + } + msg, err := natsutil.NewGzipMsg(subject.PushNotification(e.siteID), data, "application/json") + if err != nil { + return fmt.Errorf("encode push batch %s: %w", evt.ID, err) + } + if e.maxPayloadBytes > 0 && len(msg.Data) > e.maxPayloadBytes { + return fmt.Errorf("push batch %s exceeds NATS max_payload: gzipped=%d, cap=%d", evt.ID, len(msg.Data), e.maxPayloadBytes) + } + msg.Header.Set("Nats-Msg-Id", evt.ID) // dedup key — see contract doc § Dedup + if err := e.pub.PublishMsg(ctx, msg); err != nil { + return fmt.Errorf("publish push batch %s: %w", evt.ID, err) + } + return nil +} + +// jsPublisher adapts oteljetstream.JetStream to the publisher interface by discarding the PubAck. 
+type jsPublisher struct { + js interface { + PublishMsg(ctx context.Context, msg *nats.Msg, opts ...jetstream.PublishOpt) (*jetstream.PubAck, error) + } +} + +func (p *jsPublisher) PublishMsg(ctx context.Context, msg *nats.Msg) error { + _, err := p.js.PublishMsg(ctx, msg) + return err +} diff --git a/notification-worker/emit_test.go b/notification-worker/emit_test.go new file mode 100644 index 000000000..e3e990bd6 --- /dev/null +++ b/notification-worker/emit_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/natsutil" +) + +type recordedPublish struct { + subject string + msgID string + headers nats.Header + payload []byte +} + +type fakePublisher struct { + mu sync.Mutex + records []recordedPublish + failNext error +} + +func (f *fakePublisher) PublishMsg(_ context.Context, msg *nats.Msg) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.failNext != nil { + err := f.failNext + f.failNext = nil + return err + } + hdrCopy := nats.Header{} + for k, v := range msg.Header { + hdrCopy[k] = append([]string(nil), v...) 
+ } + f.records = append(f.records, recordedPublish{ + subject: msg.Subject, + msgID: msg.Header.Get("Nats-Msg-Id"), + headers: hdrCopy, + payload: append([]byte(nil), msg.Data...), + }) + return nil +} + +func TestMobileEmitter_PublishesGzippedBatch(t *testing.T) { + pub := &fakePublisher{} + em := newMobileEmitter(pub, "site-a", 0) + evt := model.PushNotificationEvent{ + ID: "m1-b0", + Accounts: []string{"alice", "bob"}, + RoomID: "r1", + Body: "hello", + } + require.NoError(t, em.Emit(context.Background(), evt)) + + require.Len(t, pub.records, 1) + r := pub.records[0] + assert.Equal(t, "chat.server.notification.push.site-a.send", r.subject) + assert.Equal(t, "m1-b0", r.msgID, "Nats-Msg-Id is the batch dedup key") + assert.Equal(t, "gzip", r.headers.Get("Content-Encoding")) + assert.Equal(t, "application/json", r.headers.Get("Content-Type")) + + // Payload must round-trip via the shared natsutil decoder so any consumer can use it. + decoded, err := natsutil.DecodePayload(&nats.Msg{Data: r.payload, Header: r.headers}) + require.NoError(t, err) + var got model.PushNotificationEvent + require.NoError(t, json.Unmarshal(decoded, &got)) + assert.Equal(t, evt, got) +} + +func TestMobileEmitter_PropagatesError(t *testing.T) { + pub := &fakePublisher{failNext: errors.New("nats: full")} + em := newMobileEmitter(pub, "site-a", 0) + err := em.Emit(context.Background(), model.PushNotificationEvent{ID: "m1-b0", Accounts: []string{"bob"}}) + assert.Error(t, err) +} + +func TestMobileEmitter_RejectsOversizedBatch(t *testing.T) { + pub := &fakePublisher{} + em := newMobileEmitter(pub, "site-a", 64) // absurdly low cap to force rejection + err := em.Emit(context.Background(), model.PushNotificationEvent{ + ID: "m1-b0", + Accounts: []string{"alice", "bob", "carol", "dave"}, + Body: "this body plus accounts and headers will gzip larger than 64 bytes", + RoomID: "r1", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds NATS max_payload") + assert.Empty(t, 
pub.records, "oversized batch must not reach the publisher") +} diff --git a/notification-worker/handler.go b/notification-worker/handler.go index ac311877a..803b0bf68 100644 --- a/notification-worker/handler.go +++ b/notification-worker/handler.go @@ -3,71 +3,278 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "log/slog" + "sort" "time" + "github.com/hmchangw/chat/pkg/mention" "github.com/hmchangw/chat/pkg/model" "github.com/hmchangw/chat/pkg/natsutil" - "github.com/hmchangw/chat/pkg/subject" + "github.com/hmchangw/chat/pkg/roommetacache" + "github.com/hmchangw/chat/pkg/roomsubcache" ) -// MemberLookup reads room membership from a data store. -type MemberLookup interface { - ListSubscriptions(ctx context.Context, roomID string) ([]model.Subscription, error) +// defaultRecipientBatchSize mirrors PUSH_RECIPIENT_BATCH_SIZE's envDefault so unit tests don't re-declare it. +const defaultRecipientBatchSize = 100 + +// MemberCache reads the cached member list and supports targeted invalidation. +type MemberCache interface { + GetMembers(ctx context.Context, roomID string) ([]roomsubcache.Member, error) + Invalidate(ctx context.Context, roomID string) +} + +// RoomMetaGetter returns cached room metadata so push-service doesn't hit Mongo. +type RoomMetaGetter interface { + Get(ctx context.Context, roomID string) (roommetacache.Meta, error) } -// Publisher abstracts NATS publishing so the handler is testable. -type Publisher interface { - Publish(ctx context.Context, subject string, data []byte) error +// HandlerDeps groups the handler's collaborators. +type HandlerDeps struct { + Members MemberCache + Followers ThreadFollowerLister + Presence PresenceSnapshotter + Hook Vetoer + Emitter Emitter + RoomMeta RoomMetaGetter // nil → title falls back to sender.Account + LargeRoomThreshold int + RecipientBatchSize int // per-event cap (≥ 1); 0 → defaultRecipientBatchSize } -// Handler processes MESSAGES_CANONICAL messages and sends notifications. 
+// Handler runs the per-message fan-out pipeline: +// +// Stage 1 — exclusion filters (sender / mute / restricted / thread-non-follower) +// Stage 2 — in-process hook veto (suppress-only, fail-open on error) +// Stage 3 — pure routing predicate (EligibleForPush) +// Stage 4 — one bulk presence RPC, then per-account shouldPush +// +// followed by one Emitter.Emit per batch of up to RecipientBatchSize surviving recipients. type Handler struct { - members MemberLookup - pub Publisher + deps HandlerDeps } -func NewHandler(members MemberLookup, pub Publisher) *Handler { - return &Handler{members: members, pub: pub} +func NewHandler(deps HandlerDeps) *Handler { //nolint:gocritic // hugeParam: one-time constructor arg + if deps.LargeRoomThreshold <= 0 { + deps.LargeRoomThreshold = 500 + } + if deps.RecipientBatchSize <= 0 { + deps.RecipientBatchSize = defaultRecipientBatchSize + } + return &Handler{deps: deps} } -// HandleMessage processes a single JetStream message payload. func (h *Handler) HandleMessage(ctx context.Context, data []byte) error { var evt model.MessageEvent if err := json.Unmarshal(data, &evt); err != nil { return fmt.Errorf("unmarshal message event: %w", err) } + msg := evt.Message + + // Member-change sys-messages drive cache invalidation (Option C; safe because room-worker guards add/remove to channels). 
+ if msg.Type != "" { + switch msg.Type { + case model.MessageTypeMembersAdded, model.MessageTypeMemberLeft, model.MessageTypeMemberRemoved: + h.deps.Members.Invalidate(ctx, msg.RoomID) + } + } - subs, err := h.members.ListSubscriptions(ctx, evt.Message.RoomID) + members, err := h.deps.Members.GetMembers(ctx, msg.RoomID) if err != nil { - return fmt.Errorf("list subscriptions for room %s: %w", evt.Message.RoomID, err) + return fmt.Errorf("get members for room %s: %w", msg.RoomID, err) + } + if len(members) == 0 { + return nil } - notif := model.NotificationEvent{ - Type: "new_message", - RoomID: evt.Message.RoomID, - Message: evt.Message, - Timestamp: time.Now().UTC().UnixMilli(), + mentionInfo := mention.Parse(msg.Content) + mentionedAccounts := mentionedSet(mentionInfo) + // @here is deliberately NOT a push trigger — the legacy frontend doesn't render it. + mentionsAll := mentionInfo.MentionAll + isLargeRoom := len(members) > h.deps.LargeRoomThreshold + isThreadOnlyReply := msg.ThreadParentMessageID != "" && !msg.TShow + + var followers map[string]struct{} + if isThreadOnlyReply { + f, ferr := h.deps.Followers.Followers(ctx, msg.ThreadParentMessageID) + if ferr != nil { + slog.Warn("thread followers lookup failed, treating as empty", + "error", ferr, "parentMessageId", msg.ThreadParentMessageID, + "request_id", natsutil.RequestIDFromContext(ctx)) + f = map[string]struct{}{} + } + followers = f } - notifData, err := natsutil.MarshalResponse(notif) - if err != nil { - return fmt.Errorf("marshal notification: %w", err) + roomType := members[0].RoomType + + // Sender display name is composed by message-gatekeeper at write time; no per-message lookup here. 
+ sender := &model.Participant{ + UserID: msg.UserID, + Account: msg.UserAccount, + DisplayName: msg.SenderDisplayName(), } - senderID := evt.Message.UserID + candidates := make([]roomsubcache.Member, 0, len(members)) + accounts := make([]string, 0, len(members)) + for i := range members { + m := members[i] + if m.ID == msg.UserID { + continue + } + if m.Muted { + continue + } + if isRestricted(m, msg, isThreadOnlyReply) { + continue + } + + mentioned := mentionsAll || mentionedAccounts[m.Account] + + if isThreadOnlyReply { + _, follows := followers[m.Account] + if !follows && !mentioned { + continue + } + } - for i := range subs { - if subs[i].User.ID == senderID { + // Stage 2: hook veto (fail-open on error). + allow, herr := h.deps.Hook.Allow(ctx, &msg, m) + if herr != nil { + slog.Warn("hook errored, allowing", "error", herr, "account", m.Account, + "request_id", natsutil.RequestIDFromContext(ctx)) + allow = true + } + if !allow { continue } - subj := subject.Notification(subs[i].User.Account) - if err := h.pub.Publish(ctx, subj, notifData); err != nil { - // account is intentionally logged for operability; do NOT add message body / token fields. - slog.Error("publish notification failed", "error", err, "account", subs[i].User.Account) + + if !EligibleForPush(&m, roomType, isLargeRoom, mentioned) { + continue + } + + candidates = append(candidates, m) + accounts = append(accounts, m.Account) + } + if len(candidates) == 0 { + return nil + } + + snapshot, _ := h.deps.Presence.Snapshot(ctx, accounts) // fail-open: error → empty map + + // Sort survivors so batch N has a deterministic account set across redeliveries — + // required for the {messageID}-b{N} Nats-Msg-Id to dedup correctly. 
+ survivors := make([]string, 0, len(candidates)) + for _, c := range candidates { + if !shouldPush(snapshot[c.Account]) { + continue } + survivors = append(survivors, c.Account) + } + if len(survivors) == 0 { + return nil + } + sort.Strings(survivors) + + now := time.Now().UTC() + // Template carries fields shared across every batch — only ID and Accounts change per batch. + pushEvt := model.PushNotificationEvent{ + RoomID: msg.RoomID, + Title: h.resolveTitle(ctx, msg.RoomID, roomType, sender), + Body: msg.Content, + Data: model.PushNotificationData{ + RoomID: msg.RoomID, + MessageID: msg.ID, + Type: shortRoomType(roomType), + Sender: sender, + ThreadMessageID: msg.ThreadParentMessageID, + PushTime: now.Format(time.RFC3339), + AlsoSendToChannel: msg.TShow, + }, + Timestamp: now.UnixMilli(), } + batchSize := h.deps.RecipientBatchSize + // Aggregate per-batch errors so one bad batch doesn't punish the others; still return + // an error so the caller naks and JetStream redelivers. {messageId}-b{N} dedup protects + // against duplicate emission of batches that already succeeded. 
+ var emitErrs []error + for i, batchIdx := 0, 0; i < len(survivors); i, batchIdx = i+batchSize, batchIdx+1 { + end := i + batchSize + if end > len(survivors) { + end = len(survivors) + } + batchAccounts := make([]string, end-i) + copy(batchAccounts, survivors[i:end]) + + evt := pushEvt + evt.ID = fmt.Sprintf("%s-b%d", msg.ID, batchIdx) + evt.Accounts = batchAccounts + if err := h.deps.Emitter.Emit(ctx, evt); err != nil { + slog.Error("emit push batch failed", "error", err, "batch", batchIdx, + "recipients", len(batchAccounts), "messageId", msg.ID, + "request_id", natsutil.RequestIDFromContext(ctx)) + emitErrs = append(emitErrs, fmt.Errorf("emit push batch %d: %w", batchIdx, err)) + } + } + if len(emitErrs) > 0 { + return fmt.Errorf("emit push batches for message %s: %w", msg.ID, errors.Join(emitErrs...)) + } return nil } + +// mentionedSet returns mentioned accounts as a set for O(1) per-recipient lookup. +// msg.Mentions is not populated by message-gatekeeper, so only Parse output is used. +func mentionedSet(parsed mention.ParseResult) map[string]bool { + out := make(map[string]bool, len(parsed.Accounts)) + for _, a := range parsed.Accounts { + out[a] = true + } + return out +} + +// isRestricted filters members who joined after the relevant message timestamp. +// Thread replies use the parent's CreatedAt; a nil parent ts is "no access" (legacy records). 
+func isRestricted(m roomsubcache.Member, msg model.Message, isThreadOnlyReply bool) bool { //nolint:gocritic // hugeParam: hot loop, pointer indirection adds no benefit + if m.HistorySharedSince == nil { + return false + } + if isThreadOnlyReply { + if msg.ThreadParentMessageCreatedAt == nil { + return true + } + return msg.ThreadParentMessageCreatedAt.UnixMilli() < *m.HistorySharedSince + } + return msg.CreatedAt.UnixMilli() < *m.HistorySharedSince +} + +func shortRoomType(t model.RoomType) string { + switch t { + case model.RoomTypeDM, model.RoomTypeBotDM: + return "d" + case model.RoomTypeDiscussion: + return "p" + default: + return "c" + } +} + +// resolveTitle returns the room name when present, else the sender's account (the legacy rule). +// DM/botDM rooms skip the cache lookup — they never have names. RoomMeta failures fall back to +// the sender so push-service still gets a usable title. +func (h *Handler) resolveTitle(ctx context.Context, roomID string, roomType model.RoomType, sender *model.Participant) string { + if h.deps.RoomMeta != nil && roomType != model.RoomTypeDM && roomType != model.RoomTypeBotDM { + meta, err := h.deps.RoomMeta.Get(ctx, roomID) + switch { + case err == nil && meta.Name != "": + return meta.Name + case err != nil: + slog.Warn("room meta lookup failed, falling back to sender", + "error", err, "roomId", roomID, "request_id", natsutil.RequestIDFromContext(ctx)) + } + } + if sender != nil { + return sender.Account + } + return "" +} diff --git a/notification-worker/handler_test.go b/notification-worker/handler_test.go index ac798f166..0e244af4e 100644 --- a/notification-worker/handler_test.go +++ b/notification-worker/handler_test.go @@ -3,192 +3,779 @@ package main import ( "context" "encoding/json" + "errors" + "fmt" "sync" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roommetacache" + 
"github.com/hmchangw/chat/pkg/roomsubcache" ) -// --- In-memory MemberLookup stub --- +type stubRoomMeta struct { + out map[string]roommetacache.Meta + err error +} + +func (s *stubRoomMeta) Get(_ context.Context, roomID string) (roommetacache.Meta, error) { + if s.err != nil { + return roommetacache.Meta{}, s.err + } + return s.out[roomID], nil +} + +type stubMembers struct { + mu sync.Mutex + out map[string][]roomsubcache.Member + calls []string // recorded in order: "get:" / "inval:" +} + +func (s *stubMembers) GetMembers(_ context.Context, roomID string) ([]roomsubcache.Member, error) { + s.mu.Lock() + s.calls = append(s.calls, "get:"+roomID) + s.mu.Unlock() + return s.out[roomID], nil +} + +func (s *stubMembers) Invalidate(_ context.Context, roomID string) { + s.mu.Lock() + s.calls = append(s.calls, "inval:"+roomID) + s.mu.Unlock() +} + +type stubFollowers struct { + out map[string]map[string]struct{} +} + +func (s *stubFollowers) Followers(_ context.Context, parentID string) (map[string]struct{}, error) { + if v, ok := s.out[parentID]; ok { + return v, nil + } + return map[string]struct{}{}, nil +} -type stubMemberLookup struct { - subs map[string][]model.Subscription +type stubPresence struct { + out map[string]model.Presence } -func (s *stubMemberLookup) ListSubscriptions(ctx context.Context, roomID string) ([]model.Subscription, error) { - return s.subs[roomID], nil +func (s *stubPresence) Snapshot(_ context.Context, _ []string) (map[string]model.Presence, error) { + return s.out, nil } -// --- NATS publish recorder --- +type rejectHook struct{} -type publishRecord struct { - subject string - data []byte +func (rejectHook) Allow(context.Context, *model.Message, roomsubcache.Member) (bool, error) { + return false, nil } -type mockPublisher struct { +type recordingEmitter struct { mu sync.Mutex - records []publishRecord + emitted []model.PushNotificationEvent } -func (m *mockPublisher) Publish(_ context.Context, subject string, data []byte) error { - 
m.mu.Lock() - defer m.mu.Unlock() - m.records = append(m.records, publishRecord{subject: subject, data: data}) +func (r *recordingEmitter) Emit(_ context.Context, evt model.PushNotificationEvent) error { //nolint:gocritic // hugeParam: must match Emitter interface value semantics + r.mu.Lock() + defer r.mu.Unlock() + r.emitted = append(r.emitted, evt) return nil } -func (m *mockPublisher) getRecords() []publishRecord { - m.mu.Lock() - defer m.mu.Unlock() - cp := make([]publishRecord, len(m.records)) - copy(cp, m.records) - return cp +// accounts flattens every recipient across every emitted batch so existing assertions +// can stay account-oriented even though Emit now receives batched events. +func (r *recordingEmitter) accounts() []string { + r.mu.Lock() + defer r.mu.Unlock() + var out []string + for i := range r.emitted { + out = append(out, r.emitted[i].Accounts...) + } + return out +} + +func newTestHandler(members MemberCache, followers ThreadFollowerLister, presence PresenceSnapshotter, hook Vetoer, emit Emitter) *Handler { + return NewHandler(HandlerDeps{ + Members: members, + Followers: followers, + Presence: presence, + Hook: hook, + Emitter: emit, + LargeRoomThreshold: 500, + }) } -// --- Tests --- +func msgEvent(m *model.Message) []byte { //nolint:gocritic // hugeParam: test helper only; pointer avoids copy + data, _ := json.Marshal(model.MessageEvent{Message: *m, SiteID: "site-a"}) + return data +} -func TestHandleMessage_FanOutSkipsSender(t *testing.T) { - lookup := &stubMemberLookup{ - subs: map[string][]model.Subscription{ - "room-1": { - {ID: "s1", User: model.SubscriptionUser{ID: "alice", Account: "account-alice"}, RoomID: "room-1"}, - {ID: "s2", User: model.SubscriptionUser{ID: "bob", Account: "account-bob"}, RoomID: "room-1"}, - {ID: "s3", User: model.SubscriptionUser{ID: "carol", Account: "account-carol"}, RoomID: "room-1"}, - }, +func TestHandle_SkipsSender(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + 
"r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, }, - } - pub := &mockPublisher{} - h := NewHandler(lookup, pub) + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + CreatedAt: time.Now(), + }))) + assert.Equal(t, []string{"bob"}, emit.accounts()) +} - evt := model.MessageEvent{ - SiteID: "site-a", - Message: model.Message{ - ID: "m1", - RoomID: "room-1", - UserID: "alice", // sender - Content: "hello", +func TestHandle_SkipsMuted(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob", Muted: true}, + {ID: "carol", Account: "carol"}, }, - } - evtData, err := json.Marshal(evt) - if err != nil { - t.Fatalf("marshal event: %v", err) - } + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) - err = h.HandleMessage(context.Background(), evtData) - if err != nil { - t.Fatalf("HandleMessage: %v", err) - } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts(), "muted bob is skipped") +} + +func TestHandle_SkipsRestrictedBeforeWindow(t *testing.T) { + createdAt := time.Unix(0, 1700000000000*int64(time.Millisecond)) + afterWindow := int64(1700000000001) // joined after message ms + beforeWindow := int64(1699999999999) // joined before message ms + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob", HistorySharedSince: &afterWindow}, // joined after message → skip + {ID: "carol", 
Account: "carol", HistorySharedSince: &beforeWindow}, // joined before → include + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: createdAt, + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts()) +} - records := pub.getRecords() +func TestHandle_ThreadOnlyReply_SkipsNonFollowerNonMention(t *testing.T) { + parentCreatedAt := time.Unix(0, 1700000000000*int64(time.Millisecond)) + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + followers := &stubFollowers{out: map[string]map[string]struct{}{ + "parent-1": {"bob": {}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, followers, noopPresenceSnapshotter{}, noopVetoer{}, emit) - // Should notify bob and carol, but NOT alice (the sender) - if len(records) != 2 { - t.Fatalf("expected 2 notifications, got %d", len(records)) + msg := model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + ThreadParentMessageCreatedAt: &parentCreatedAt, + TShow: false, + Content: "thread reply", } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&msg))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts(), "only thread follower receives") +} - subjects := map[string]bool{} - for _, r := range records { - subjects[r.subject] = true +func TestHandle_ThreadReply_TShow_TreatedAsChannelMessage(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + emit := &recordingEmitter{} + h := 
newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) - var notif model.NotificationEvent - if err := json.Unmarshal(r.data, &notif); err != nil { - t.Fatalf("unmarshal notification: %v", err) - } - if notif.Type != "new_message" { - t.Errorf("notification type = %q, want %q", notif.Type, "new_message") - } - if notif.RoomID != "room-1" { - t.Errorf("notification roomID = %q, want %q", notif.RoomID, "room-1") - } - if notif.Message.ID != "m1" { - t.Errorf("notification message ID = %q, want %q", notif.Message.ID, "m1") - } - if notif.Timestamp <= 0 { - t.Errorf("expected Timestamp > 0 on NotificationEvent") - } + msg := model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + TShow: true, + Content: "shared with channel", } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&msg))) + assert.ElementsMatch(t, []string{"bob", "carol"}, emit.accounts()) +} + +func TestHandle_HookVeto_DropsAll(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, rejectHook{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + assert.Empty(t, emit.accounts()) +} - if !subjects["chat.user.account-bob.notification"] { - t.Error("missing notification for bob") +func TestHandle_LargeRoomNonMention_DropsAll(t *testing.T) { + roomMembers := make([]roomsubcache.Member, 600) + for i := range roomMembers { + roomMembers[i] = roomsubcache.Member{ID: "u", Account: "u" + string(rune(i))} } - if !subjects["chat.user.account-carol.notification"] { - t.Error("missing notification for carol") + roomMembers[0] = 
roomsubcache.Member{ID: "alice", Account: "alice"} + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, + Followers: &stubFollowers{}, + Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, + Emitter: emit, + LargeRoomThreshold: 500, + }) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", Content: "no mentions", + CreatedAt: time.Now(), + }))) + assert.Empty(t, emit.accounts(), "large room non-mention drops all") +} + +func TestHandle_LargeRoomMention_OnlyMentionedPushed(t *testing.T) { + roomMembers := []roomsubcache.Member{ + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, } - if subjects["chat.user.account-alice.notification"] { - t.Error("sender alice should NOT receive notification") + for i := 0; i < 600; i++ { + roomMembers = append(roomMembers, roomsubcache.Member{ID: "u" + string(rune(i)), Account: "u" + string(rune(i))}) } + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hey @bob check this", CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts()) } -func TestHandleMessage_NoMembers(t *testing.T) { - lookup := &stubMemberLookup{ - subs: map[string][]model.Subscription{}, +func TestHandle_LargeRoomAtAll_PushesAllNonSender(t *testing.T) { + roomMembers := []roomsubcache.Member{ + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + } + for i := 0; i < 500; i++ { + roomMembers = 
append(roomMembers, roomsubcache.Member{ID: "u", Account: "u" + string(rune(i))}) } - pub := &mockPublisher{} - h := NewHandler(lookup, pub) + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "@all heads up", CreatedAt: time.Now(), + }))) + assert.Contains(t, emit.accounts(), "bob") + assert.Contains(t, emit.accounts(), "carol") + assert.NotContains(t, emit.accounts(), "alice") +} + +func TestHandle_PresenceBusyDropsPush(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + presence := &stubPresence{out: map[string]model.Presence{ + "bob": {AggregatedStatus: "busy"}, + "carol": {AggregatedStatus: "online"}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, presence, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"carol"}, emit.accounts()) +} + +func TestHandle_TwoMemberChannel_RoutesAsChannel(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice", RoomType: model.RoomTypeChannel}, + {ID: "bob", Account: "bob", RoomType: model.RoomTypeChannel}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: 
"alice", + Content: "hi", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.Equal(t, "c", emit.emitted[0].Data.Type) +} + +func TestHandle_PushPayloadSenderFromMemberRecord(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice", RoomType: model.RoomTypeChannel}, + {ID: "bob", Account: "bob", RoomType: model.RoomTypeChannel}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hello", + CreatedAt: time.Unix(0, 1700000000000*int64(time.Millisecond)), + }))) + require.Len(t, emit.emitted, 1) + got := emit.emitted[0] + assert.Equal(t, "m1-b0", got.ID, "dedup-stable batch ID") + assert.Equal(t, []string{"bob"}, got.Accounts) + assert.Equal(t, "r1", got.RoomID) + require.NotNil(t, got.Data.Sender) + assert.Equal(t, "alice", got.Data.Sender.Account) + assert.Equal(t, "m1", got.Data.MessageID) + assert.NotEmpty(t, got.Data.PushTime) + assert.Greater(t, got.Timestamp, int64(0)) +} + +func TestHandle_InvalidJSON(t *testing.T) { + emit := &recordingEmitter{} + h := newTestHandler(&stubMembers{}, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + err := h.HandleMessage(context.Background(), []byte("not json")) + assert.Error(t, err) +} + +type errHook struct{} + +func (errHook) Allow(context.Context, *model.Message, roomsubcache.Member) (bool, error) { + return false, fmt.Errorf("hook backend unavailable") +} + +func TestHandle_HookError_FailOpen(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, errHook{}, 
emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts(), "hook error must fail-open") +} - evt := model.MessageEvent{ - SiteID: "site-a", - Message: model.Message{ - ID: "m1", - RoomID: "empty-room", - UserID: "alice", - Content: "hello?", +func TestHandle_ThreadOnlyReply_NilParentCreatedAt_Restricted(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, }, + }} + followers := &stubFollowers{out: map[string]map[string]struct{}{ + "parent-1": {"bob": {}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, followers, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + threshold := int64(1700000000000) + members.out["r1"][1].HistorySharedSince = &threshold + + msg := model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + ThreadParentMessageCreatedAt: nil, // legacy: no parent ts + TShow: false, } - evtData, _ := json.Marshal(evt) + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&msg))) + assert.Empty(t, emit.accounts(), "nil parent CreatedAt with HistorySharedSince must restrict bob") +} - err := h.HandleMessage(context.Background(), evtData) - if err != nil { - t.Fatalf("HandleMessage: %v", err) +type errFollowers struct{} + +func (errFollowers) Followers(context.Context, string) (map[string]struct{}, error) { + return nil, fmt.Errorf("mongo timeout") +} + +func TestHandle_ThreadFollowersError_FailOpenEmptySet(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, errFollowers{}, 
noopPresenceSnapshotter{}, noopVetoer{}, emit) + + msg := &model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "thread reply", + ThreadParentMessageID: "parent-1", + TShow: false, + CreatedAt: time.Now(), } + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(msg))) + assert.Empty(t, emit.accounts(), "non-mentioned non-followers are dropped when follower lookup fails") +} + +func TestNewHandler_DefaultLargeRoomThreshold(t *testing.T) { + h := NewHandler(HandlerDeps{ + Members: &stubMembers{}, + Followers: &stubFollowers{}, + Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, + Emitter: &recordingEmitter{}, + // LargeRoomThreshold + RecipientBatchSize zero → must default + }) + assert.Equal(t, 500, h.deps.LargeRoomThreshold) + assert.Equal(t, defaultRecipientBatchSize, h.deps.RecipientBatchSize) +} - records := pub.getRecords() - if len(records) != 0 { - t.Errorf("expected 0 notifications for empty room, got %d", len(records)) +// @here is no longer a push trigger (legacy FE doesn't render it). A large-room message +// containing ONLY @here must result in zero pushes — same as a non-mention large-room post. 
+func TestHandle_AtHere_LargeRoom_DropsAll(t *testing.T) { + roomMembers := []roomsubcache.Member{{ID: "alice", Account: "alice"}} + for i := 0; i < 600; i++ { + roomMembers = append(roomMembers, roomsubcache.Member{ID: "u", Account: "u" + string(rune(i))}) } + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "@here heads up", CreatedAt: time.Now(), + }))) + assert.Empty(t, emit.accounts(), "@here in large room must not push to anyone") } -func TestHandleMessage_SoleMember(t *testing.T) { - lookup := &stubMemberLookup{ - subs: map[string][]model.Subscription{ - "room-solo": { - {ID: "s1", User: model.SubscriptionUser{ID: "alice", Account: "account-alice"}, RoomID: "room-solo"}, - }, +// @here in a thread-only reply must NOT bypass the follower check — only followers (and +// explicit @account mentions) should be pushed. 
+func TestHandle_AtHere_ThreadOnlyReply_DoesNotBypassFollowers(t *testing.T) { + parentCreatedAt := time.Unix(0, 1700000000000*int64(time.Millisecond)) + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, }, + }} + followers := &stubFollowers{out: map[string]map[string]struct{}{ + "parent-1": {"bob": {}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, followers, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + ThreadParentMessageID: "parent-1", + ThreadParentMessageCreatedAt: &parentCreatedAt, + TShow: false, + Content: "@here in thread", + }))) + assert.ElementsMatch(t, []string{"bob"}, emit.accounts(), + "only the thread follower receives; @here must not promote carol") +} + +func TestHandle_BatchesRecipients(t *testing.T) { + // 250 members + sender → 250 candidates; with batch=100 expect 3 events of 100/100/50. 
+ roomMembers := []roomsubcache.Member{{ID: "alice", Account: "alice"}} + for i := 0; i < 250; i++ { + roomMembers = append(roomMembers, roomsubcache.Member{ID: fmt.Sprintf("u%03d", i), Account: fmt.Sprintf("u%03d", i)}) } - pub := &mockPublisher{} - h := NewHandler(lookup, pub) + members := &stubMembers{out: map[string][]roomsubcache.Member{"r1": roomMembers}} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, + Followers: &stubFollowers{}, + Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, + Emitter: emit, + LargeRoomThreshold: 1000, // keep below threshold so all non-sender candidates remain + RecipientBatchSize: 100, + }) - evt := model.MessageEvent{ - SiteID: "site-a", - Message: model.Message{ - ID: "m1", - RoomID: "room-solo", - UserID: "alice", - Content: "talking to myself", - }, + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hi", CreatedAt: time.Now(), + }))) + + require.Len(t, emit.emitted, 3, "250 recipients → ceil(250/100) = 3 batches") + assert.Len(t, emit.emitted[0].Accounts, 100) + assert.Len(t, emit.emitted[1].Accounts, 100) + assert.Len(t, emit.emitted[2].Accounts, 50) + assert.Equal(t, "m1-b0", emit.emitted[0].ID) + assert.Equal(t, "m1-b1", emit.emitted[1].ID) + assert.Equal(t, "m1-b2", emit.emitted[2].ID) + + // Same body, sender, room-level metadata replicated across batches. + for _, e := range emit.emitted { + assert.Equal(t, "hi", e.Body) + assert.Equal(t, "m1", e.Data.MessageID) + assert.Equal(t, "r1", e.RoomID) } - evtData, _ := json.Marshal(evt) - err := h.HandleMessage(context.Background(), evtData) - if err != nil { - t.Fatalf("HandleMessage: %v", err) + // Survivor union covers every non-sender member; no duplicates across batches. 
+ all := emit.accounts() + assert.Len(t, all, 250) + seen := map[string]bool{} + for _, a := range all { + assert.False(t, seen[a], "account %s emitted in multiple batches", a) + seen[a] = true } +} - records := pub.getRecords() - if len(records) != 0 { - t.Errorf("expected 0 notifications when sender is sole member, got %d", len(records)) +// Sub-batch-size survivor count must still produce exactly one event. +func TestHandle_SingleBatch_WhenSurvivorsBelowBatchSize(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + {ID: "carol", Account: "carol"}, + }, + }} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, Followers: &stubFollowers{}, + Presence: noopPresenceSnapshotter{}, Hook: noopVetoer{}, Emitter: emit, + LargeRoomThreshold: 500, RecipientBatchSize: 100, + }) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.ElementsMatch(t, []string{"bob", "carol"}, emit.emitted[0].Accounts) + assert.Equal(t, "m1-b0", emit.emitted[0].ID) +} + +// Emit failure must be returned so JetStream redelivers the canonical message. +// Logging-and-continuing would silently drop the push batch — push-stream dedup +// at {messageId}-b{N} protects against duplicates on redelivery. 
+type failingEmitter struct{ err error } + +func (f failingEmitter) Emit(context.Context, model.PushNotificationEvent) error { + return f.err +} + +func TestHandle_EmitFailure_ReturnsError(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := failingEmitter{err: fmt.Errorf("nats: full")} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + err := h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + })) + require.Error(t, err, "emit failure must propagate so JetStream redelivers") + assert.Contains(t, err.Error(), "emit push batches for message m1") +} + +// Title resolution matches the legacy rule: room.Name when present, else sender.Account. +func TestHandle_Title_UsesRoomName(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + rooms := &stubRoomMeta{out: map[string]roommetacache.Meta{ + "r1": {ID: "r1", Name: "general", Type: model.RoomTypeChannel}, + }} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, Followers: &stubFollowers{}, Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, Emitter: emit, RoomMeta: rooms, + LargeRoomThreshold: 500, RecipientBatchSize: 100, + }) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.Equal(t, "general", emit.emitted[0].Title) +} + +func TestHandle_Title_FallsBackToSenderWhenRoomNameEmpty(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice", RoomType: model.RoomTypeDM}, + 
{ID: "bob", Account: "bob", RoomType: model.RoomTypeDM}, + }, + }} + rooms := &stubRoomMeta{out: map[string]roommetacache.Meta{ + "r1": {ID: "r1", Name: "", Type: model.RoomTypeDM}, // DM rooms have no name + }} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, Followers: &stubFollowers{}, Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, Emitter: emit, RoomMeta: rooms, + LargeRoomThreshold: 500, RecipientBatchSize: 100, + }) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.Equal(t, "alice", emit.emitted[0].Title, "empty room name → sender account") +} + +func TestHandle_Title_RoomMetaErrorFallsBackToSender(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + rooms := &stubRoomMeta{err: errors.New("mongo timeout")} + emit := &recordingEmitter{} + h := NewHandler(HandlerDeps{ + Members: members, Followers: &stubFollowers{}, Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, Emitter: emit, RoomMeta: rooms, + LargeRoomThreshold: 500, RecipientBatchSize: 100, + }) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.Equal(t, "alice", emit.emitted[0].Title, "lookup error must not block delivery") +} + +func TestHandle_Title_NilRoomMetaFallsBackToSender(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, 
h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + assert.Equal(t, "alice", emit.emitted[0].Title, "no RoomMeta dep → immediate sender fallback") +} + +// Sender display name comes from the canonical message (gatekeeper composed it). +// Notification-worker just copies it through — no per-message lookup. +func TestHandle_Sender_DisplayNameFromCanonicalMessage(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + UserDisplayName: "Alice Wang 愛麗絲", + CreatedAt: time.Now(), + }))) + require.Len(t, emit.emitted, 1) + s := emit.emitted[0].Data.Sender + require.NotNil(t, s) + assert.Equal(t, "alice", s.Account) + assert.Equal(t, "Alice Wang 愛麗絲", s.DisplayName, "display name comes from canonical message verbatim") +} + +// Backward compatibility: pre-rollout canonical messages without UserDisplayName +// must still produce a valid push event. Fallback is UserAccount. 
+func TestHandle_Sender_EmptyDisplayNameFallsBackToAccount(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": { + {ID: "alice", Account: "alice"}, + {ID: "bob", Account: "bob"}, + }, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", CreatedAt: time.Now(), + // UserDisplayName intentionally empty — legacy in-flight message shape + }))) + require.Len(t, emit.emitted, 1) + s := emit.emitted[0].Data.Sender + require.NotNil(t, s) + assert.Equal(t, "alice", s.Account) + assert.Equal(t, "alice", s.DisplayName, "empty UserDisplayName → fall back to account") +} + +// Sys-message drives invalidation under Option C. Coupling note: works because +// room-worker guards add/remove to channels — relaxing that requires re-keeping the publish. 
+func TestHandle_InvalidatesCacheOnMemberChangeSysMessage(t *testing.T) { + for _, msgType := range []string{ + model.MessageTypeMembersAdded, + model.MessageTypeMemberLeft, + model.MessageTypeMemberRemoved, + } { + t.Run(msgType, func(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": {{ID: "alice", Account: "alice"}, {ID: "bob", Account: "bob"}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) + + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Type: msgType, CreatedAt: time.Now(), + }))) + + require.GreaterOrEqual(t, len(members.calls), 2) + assert.Equal(t, []string{"inval:r1", "get:r1"}, members.calls[:2], "Invalidate must happen before GetMembers to avoid stale read") + }) } } -func TestHandleMessage_InvalidJSON(t *testing.T) { - lookup := &stubMemberLookup{} - pub := &mockPublisher{} - h := NewHandler(lookup, pub) +func TestHandle_DoesNotInvalidateOnRegularMessage(t *testing.T) { + members := &stubMembers{out: map[string][]roomsubcache.Member{ + "r1": {{ID: "alice", Account: "alice"}, {ID: "bob", Account: "bob"}}, + }} + emit := &recordingEmitter{} + h := newTestHandler(members, &stubFollowers{}, noopPresenceSnapshotter{}, noopVetoer{}, emit) - err := h.HandleMessage(context.Background(), []byte("not json")) - if err == nil { - t.Error("expected error for invalid JSON, got nil") + require.NoError(t, h.HandleMessage(context.Background(), msgEvent(&model.Message{ + ID: "m1", RoomID: "r1", UserID: "alice", UserAccount: "alice", + Content: "hello", CreatedAt: time.Now(), + }))) + + for _, c := range members.calls { + assert.NotContains(t, c, "inval:", "regular messages must not invalidate cache") } } diff --git a/notification-worker/hook.go b/notification-worker/hook.go new file mode 100644 index 000000000..f1e736f67 --- /dev/null +++ 
b/notification-worker/hook.go @@ -0,0 +1,20 @@ +package main + +import ( + "context" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// Vetoer is the in-process suppress-only veto (Stage 2). Allow returns false to drop a recipient. +// Must not issue per-recipient I/O; batch-load data before the loop. Errors are fail-open. +type Vetoer interface { + Allow(ctx context.Context, msg *model.Message, member roomsubcache.Member) (bool, error) +} + +type noopVetoer struct{} + +func (noopVetoer) Allow(context.Context, *model.Message, roomsubcache.Member) (bool, error) { + return true, nil +} diff --git a/notification-worker/hook_test.go b/notification-worker/hook_test.go new file mode 100644 index 000000000..7700f5c16 --- /dev/null +++ b/notification-worker/hook_test.go @@ -0,0 +1,18 @@ +package main + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +func TestNoopHook_AlwaysAllows(t *testing.T) { + h := noopVetoer{} + allow, err := h.Allow(context.Background(), &model.Message{}, roomsubcache.Member{Account: "a"}) + assert.NoError(t, err) + assert.True(t, allow) +} diff --git a/notification-worker/integration_test.go b/notification-worker/integration_test.go index 39ead0bb7..17cf0e796 100644 --- a/notification-worker/integration_test.go +++ b/notification-worker/integration_test.go @@ -9,68 +9,171 @@ import ( "testing" "time" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/v2/mongo" "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/natsutil" + "github.com/hmchangw/chat/pkg/roomsubcache" "github.com/hmchangw/chat/pkg/subject" "github.com/hmchangw/chat/pkg/testutil" + "github.com/hmchangw/chat/pkg/valkeyutil" ) -func setupMongo(t *testing.T) *mongo.Database { - return testutil.MongoDB(t, 
"notification_worker_test") -} +func TestNotificationWorker_CacheBackedFanOut(t *testing.T) { + db := testutil.MongoDB(t, "notification_worker_test") + valkeyClient := testutil.SharedValkeyCluster(t) + t.Cleanup(func() { testutil.FlushValkey(t) }) + natsURL := testutil.NATS(t) -type recordingPublisher struct { - mu sync.Mutex - subjects []string -} + ctx := context.Background() + subCol := db.Collection("subscriptions") + threadRoomCol := db.Collection("thread_rooms") -func (p *recordingPublisher) Publish(_ context.Context, subj string, data []byte) error { - p.mu.Lock() - defer p.mu.Unlock() - p.subjects = append(p.subjects, subj) - return nil -} + seedSubscriptions(t, ctx, subCol) -func TestNotificationWorker_Integration(t *testing.T) { - db := setupMongo(t) - ctx := context.Background() + cache := roomsubcache.NewValkeyCache(valkeyutil.WrapClusterClient(valkeyClient)) + loader := &mongoMemberLoader{col: subCol} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) - // Seed subscriptions - db.Collection("subscriptions").InsertMany(ctx, []interface{}{ - model.Subscription{ID: "s1", User: model.SubscriptionUser{ID: "u1", Account: "alice"}, RoomID: "r1"}, - model.Subscription{ID: "s2", User: model.SubscriptionUser{ID: "u2", Account: "bob"}, RoomID: "r1"}, - model.Subscription{ID: "s3", User: model.SubscriptionUser{ID: "u3", Account: "carol"}, RoomID: "r1"}, - }) + nc, err := nats.Connect(natsURL) + require.NoError(t, err) + t.Cleanup(func() { _ = nc.Drain() }) + + pushSub := subscribePush(t, nc, "site-a") - memberLookup := &mongoMemberLookup{col: db.Collection("subscriptions")} - pub := &recordingPublisher{} - handler := NewHandler(memberLookup, pub) + emitter := newMobileEmitter(&directNATSAsyncPub{nc: nc}, "site-a", 0) + handler := NewHandler(HandlerDeps{ + Members: lookup, + Followers: newMongoThreadFollowers(threadRoomCol), + Presence: noopPresenceSnapshotter{}, + Hook: noopVetoer{}, + Emitter: emitter, + LargeRoomThreshold: 500, + }) evt := 
model.MessageEvent{ + SiteID: "site-a", Message: model.Message{ - ID: "m1", RoomID: "r1", UserID: "u1", Content: "hello", - CreatedAt: time.Now().UTC(), + ID: "m1", + RoomID: "r1", + UserID: "alice", + UserAccount: "alice", + Content: "hello", + CreatedAt: time.Now(), }, } data, _ := json.Marshal(evt) + require.NoError(t, handler.HandleMessage(ctx, data)) - if err := handler.HandleMessage(ctx, data); err != nil { - t.Fatalf("HandleMessage: %v", err) - } + got := pushSub.collect(t, 2*time.Second, 2) + assert.ElementsMatch(t, []string{"bob", "carol"}, got) +} - // Should notify u2 and u3 (not u1 who is the sender) - if len(pub.subjects) != 2 { - t.Fatalf("got %d notifications, want 2: %v", len(pub.subjects), pub.subjects) - } +func seedSubscriptions(t *testing.T, ctx context.Context, col *mongo.Collection) { + t.Helper() + _, err := col.InsertMany(ctx, []any{ + model.Subscription{ID: "s1", RoomID: "r1", User: model.SubscriptionUser{ID: "alice", Account: "alice"}}, + model.Subscription{ID: "s2", RoomID: "r1", User: model.SubscriptionUser{ID: "bob", Account: "bob"}}, + model.Subscription{ID: "s3", RoomID: "r1", User: model.SubscriptionUser{ID: "carol", Account: "carol"}}, + }) + require.NoError(t, err) +} - expected := map[string]bool{ - subject.Notification("bob"): true, - subject.Notification("carol"): true, - } - for _, s := range pub.subjects { - if !expected[s] { - t.Errorf("unexpected publish to %q", s) +type pushCollector struct { + mu sync.Mutex + gotAcct []string + got chan struct{} +} + +func subscribePush(t *testing.T, nc *nats.Conn, siteID string) *pushCollector { + t.Helper() + c := &pushCollector{got: make(chan struct{}, 256)} + sub, err := nc.Subscribe(subject.PushNotification(siteID), func(msg *nats.Msg) { + payload, err := natsutil.DecodePayload(msg) + if err != nil { + t.Logf("decode payload: %v", err) + return + } + var evt model.PushNotificationEvent + if err := json.Unmarshal(payload, &evt); err != nil { + t.Logf("decode push: %v", err) + return + 
} + c.mu.Lock() + c.gotAcct = append(c.gotAcct, evt.Accounts...) + c.mu.Unlock() + for range evt.Accounts { + c.got <- struct{}{} + } + }) + require.NoError(t, err) + t.Cleanup(func() { _ = sub.Unsubscribe() }) + return c +} + +func (c *pushCollector) collect(t *testing.T, timeout time.Duration, want int) []string { + t.Helper() + deadline := time.After(timeout) + for { + c.mu.Lock() + if len(c.gotAcct) >= want { + out := append([]string(nil), c.gotAcct...) + c.mu.Unlock() + return out + } + c.mu.Unlock() + select { + case <-c.got: + case <-deadline: + c.mu.Lock() + defer c.mu.Unlock() + t.Fatalf("collect timeout: got %v want %d", c.gotAcct, want) + return nil } } } + +// directNATSAsyncPub bypasses JetStream so the test can observe pushes without the PUSH_NOTIFICATIONS stream. +type directNATSAsyncPub struct{ nc *nats.Conn } + +func (d *directNATSAsyncPub) PublishMsg(_ context.Context, msg *nats.Msg) error { + return d.nc.PublishMsg(msg) +} + +func TestMongoThreadFollowers_Followers(t *testing.T) { + db := testutil.MongoDB(t, "notification_worker_test") + ctx := context.Background() + col := db.Collection("thread_rooms") + + _, err := col.InsertMany(ctx, []any{ + model.ThreadRoom{ID: "tr1", ParentMessageID: "parent-1", RoomID: "r1", SiteID: "site-a", ReplyAccounts: []string{"alice", "bob"}}, + model.ThreadRoom{ID: "tr2", ParentMessageID: "parent-2", RoomID: "r1", SiteID: "site-a", ReplyAccounts: []string{"carol"}}, + }) + require.NoError(t, err) + + tf := newMongoThreadFollowers(col) + + t.Run("returns replyAccounts for parent", func(t *testing.T) { + got, err := tf.Followers(ctx, "parent-1") + require.NoError(t, err) + assert.Len(t, got, 2) + assert.Contains(t, got, "alice") + assert.Contains(t, got, "bob") + assert.NotContains(t, got, "carol") + }) + + t.Run("empty parentMessageID returns empty set", func(t *testing.T) { + got, err := tf.Followers(ctx, "") + require.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("unknown parent returns empty set", func(t 
*testing.T) { + got, err := tf.Followers(ctx, "no-such-parent") + require.NoError(t, err) + assert.Empty(t, got) + }) +} diff --git a/notification-worker/main.go b/notification-worker/main.go index 752327328..d1cb0928c 100644 --- a/notification-worker/main.go +++ b/notification-worker/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "fmt" "log/slog" "os" @@ -10,50 +11,100 @@ import ( "github.com/caarlos0/env/v11" "github.com/nats-io/nats.go/jetstream" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" "github.com/Marz32onE/instrumentation-go/otel-nats/oteljetstream" - "github.com/Marz32onE/instrumentation-go/otel-nats/otelnats" "github.com/hmchangw/chat/pkg/model" "github.com/hmchangw/chat/pkg/mongoutil" "github.com/hmchangw/chat/pkg/natsutil" "github.com/hmchangw/chat/pkg/otelutil" + "github.com/hmchangw/chat/pkg/roommetacache" + "github.com/hmchangw/chat/pkg/roomsubcache" "github.com/hmchangw/chat/pkg/shutdown" "github.com/hmchangw/chat/pkg/stream" + "github.com/hmchangw/chat/pkg/subject" + "github.com/hmchangw/chat/pkg/valkeyutil" ) type config struct { - NatsURL string `env:"NATS_URL" envDefault:"nats://localhost:4222"` - NatsCredsFile string `env:"NATS_CREDS_FILE" envDefault:""` - SiteID string `env:"SITE_ID" envDefault:"default"` - MongoURI string `env:"MONGO_URI" envDefault:"mongodb://localhost:27017"` - MongoDB string `env:"MONGO_DB" envDefault:"chat"` - MongoUsername string `env:"MONGO_USERNAME" envDefault:""` - MongoPassword string `env:"MONGO_PASSWORD" envDefault:""` - MaxWorkers int `env:"MAX_WORKERS" envDefault:"100"` - Consumer stream.ConsumerSettings `envPrefix:"CONSUMER_"` - Bootstrap bootstrapConfig `envPrefix:"BOOTSTRAP_"` + NatsURL string `env:"NATS_URL" envDefault:"nats://localhost:4222"` + NatsCredsFile string `env:"NATS_CREDS_FILE" envDefault:""` + SiteID string `env:"SITE_ID" envDefault:"default"` + MongoURI string `env:"MONGO_URI" 
envDefault:"mongodb://localhost:27017"` + MongoDB string `env:"MONGO_DB" envDefault:"chat"` + MongoUsername string `env:"MONGO_USERNAME" envDefault:""` + MongoPassword string `env:"MONGO_PASSWORD" envDefault:""` + MaxWorkers int `env:"MAX_WORKERS" envDefault:"100"` + LargeRoomThreshold int `env:"LARGE_ROOM_THRESHOLD" envDefault:"500"` + PushRecipientBatchSize int `env:"PUSH_RECIPIENT_BATCH_SIZE" envDefault:"100"` + RoomMetaCacheSize int `env:"ROOM_META_CACHE_SIZE" envDefault:"10000"` + RoomMetaCacheTTL time.Duration `env:"ROOM_META_CACHE_TTL" envDefault:"2m"` + ValkeyAddrs []string `env:"VALKEY_ADDRS" envSeparator:","` + ValkeyPassword string `env:"VALKEY_PASSWORD" envDefault:""` + RoomSubCacheTTL time.Duration `env:"ROOMSUBCACHE_TTL" envDefault:"5m"` + PresenceBatchSize int `env:"PRESENCE_BATCH_SIZE" envDefault:"512"` + PresenceRPCTimeout time.Duration `env:"PRESENCE_RPC_TIMEOUT" envDefault:"2s"` + PresenceEnabled bool `env:"PRESENCE_RPC_ENABLED" envDefault:"false"` // false → noopPresenceSnapshotter; set true once presence service is available + NatsMaxPayloadBytes int `env:"NATS_MAX_PAYLOAD_BYTES" envDefault:"262144"` // must match broker max_payload; emitter rejects any gzipped batch exceeding this + Consumer stream.ConsumerSettings `envPrefix:"CONSUMER_"` + Bootstrap bootstrapConfig `envPrefix:"BOOTSTRAP_"` } -// mongoMemberLookup implements MemberLookup using MongoDB. 
-type mongoMemberLookup struct { +type mongoMemberLoader struct { col *mongo.Collection } -func (m *mongoMemberLookup) ListSubscriptions(ctx context.Context, roomID string) ([]model.Subscription, error) { - filter := map[string]string{"roomId": roomID} - cursor, err := m.col.Find(ctx, filter) +func (m *mongoMemberLoader) Load(ctx context.Context, roomID string) ([]roomsubcache.Member, error) { + projection := bson.M{ + "u._id": 1, + "u.account": 1, + "u.isBot": 1, + "roomType": 1, + "muted": 1, + "historySharedSince": 1, + } + cur, err := m.col.Find(ctx, bson.M{"roomId": roomID}, options.Find().SetProjection(projection)) if err != nil { return nil, fmt.Errorf("find subscriptions for room %s: %w", roomID, err) } - defer cursor.Close(ctx) + defer cur.Close(ctx) - var subs []model.Subscription - if err := cursor.All(ctx, &subs); err != nil { - return nil, fmt.Errorf("decode subscriptions for room %s: %w", roomID, err) + var out []roomsubcache.Member + for cur.Next(ctx) { + var doc struct { + User struct { + ID string `bson:"_id"` + Account string `bson:"account"` + IsBot bool `bson:"isBot"` + } `bson:"u"` + RoomType model.RoomType `bson:"roomType"` + Muted bool `bson:"muted"` + HistorySharedSince *time.Time `bson:"historySharedSince"` + } + if err := cur.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode subscription: %w", err) + } + var hssMs *int64 + if doc.HistorySharedSince != nil { + ms := doc.HistorySharedSince.UnixMilli() + hssMs = &ms + } + out = append(out, roomsubcache.Member{ + ID: doc.User.ID, + Account: doc.User.Account, + RoomType: doc.RoomType, + IsBot: doc.User.IsBot, + Muted: doc.Muted, + HistorySharedSince: hssMs, + }) } - return subs, nil + if err := cur.Err(); err != nil { + return nil, fmt.Errorf("iterate subscriptions: %w", err) + } + return out, nil } func main() { @@ -64,6 +115,10 @@ func main() { slog.Error("parse config", "error", err) os.Exit(1) } + if len(cfg.ValkeyAddrs) == 0 { + slog.Error("VALKEY_ADDRS required") + os.Exit(1) + 
} ctx := context.Background() @@ -78,8 +133,28 @@ func main() { slog.Error("mongo connect failed", "error", err) os.Exit(1) } - subCol := mongoClient.Database(cfg.MongoDB).Collection("subscriptions") - memberLookup := &mongoMemberLookup{col: subCol} + db := mongoClient.Database(cfg.MongoDB) + subCol := db.Collection("subscriptions") + threadRoomCol := db.Collection("thread_rooms") + roomsCol := db.Collection("rooms") + + roomMetaCache, err := roommetacache.New(cfg.RoomMetaCacheSize, cfg.RoomMetaCacheTTL, + func(ctx context.Context, roomID string) (roommetacache.Meta, error) { + return roommetacache.FetchFromMongo(ctx, roomsCol, roomID) + }) + if err != nil { + slog.Error("init room-meta cache failed", "error", err) + os.Exit(1) + } + + valkeyClient, err := valkeyutil.ConnectCluster(ctx, cfg.ValkeyAddrs, cfg.ValkeyPassword) + if err != nil { + slog.Error("valkey connect failed", "error", err) + os.Exit(1) + } + cache := roomsubcache.NewValkeyCache(valkeyClient) + loader := &mongoMemberLoader{col: subCol} + memberLookup := newCachedMemberLookup(cache, loader.Load, cfg.RoomSubCacheTTL) nc, err := natsutil.Connect(cfg.NatsURL, cfg.NatsCredsFile) if err != nil { @@ -87,27 +162,100 @@ func main() { os.Exit(1) } - js, err := oteljetstream.New(nc) + otelJS, err := oteljetstream.New(nc) if err != nil { slog.Error("jetstream init failed", "error", err) os.Exit(1) } - if err := bootstrapStreams(ctx, js, cfg.SiteID, cfg.Bootstrap.Enabled); err != nil { + if err := bootstrapStreams(ctx, otelJS, cfg.SiteID, cfg.Bootstrap.Enabled); err != nil { slog.Error("bootstrap streams failed", "error", err) os.Exit(1) } canonicalCfg := stream.MessagesCanonical(cfg.SiteID) - - cons, err := js.CreateOrUpdateConsumer(ctx, canonicalCfg.Name, buildConsumerConfig(cfg.Consumer)) + cons, err := otelJS.CreateOrUpdateConsumer(ctx, canonicalCfg.Name, buildConsumerConfig(cfg.Consumer)) if err != nil { slog.Error("create consumer failed", "error", err) os.Exit(1) } - publisher := &natsPublisher{nc: nc} 
- handler := NewHandler(memberLookup, publisher) + emitter := newMobileEmitter(&jsPublisher{js: otelJS}, cfg.SiteID, cfg.NatsMaxPayloadBytes) + + var presence PresenceSnapshotter = noopPresenceSnapshotter{} + if cfg.PresenceEnabled { + presence = newBulkPresenceSource( + &natsPresenceRequester{nc: nc.NatsConn()}, + cfg.SiteID, + cfg.PresenceBatchSize, + cfg.PresenceRPCTimeout, + ) + } + + handler := NewHandler(HandlerDeps{ + Members: memberLookup, + Followers: newMongoThreadFollowers(threadRoomCol), + Presence: presence, + Hook: noopVetoer{}, + Emitter: emitter, + RoomMeta: roomMetaCache, + LargeRoomThreshold: cfg.LargeRoomThreshold, + RecipientBatchSize: cfg.PushRecipientBatchSize, + }) + + // Bounded worker drains the channel so slow Valkey doesn't block NATS dispatch; + // drops are safe because TTLs reconcile staleness. + invalCtx, invalCancel := context.WithCancel(ctx) + invalCh := make(chan string, 256) + var invalWG sync.WaitGroup + invalWG.Add(1) + go func() { + defer invalWG.Done() + for roomID := range invalCh { + memberLookup.Invalidate(invalCtx, roomID) + } + }() + + // Mute is the only canonical member event still on this stream; add/remove invalidation rides on MESSAGES_CANONICAL sys-messages. + // DeliverNewPolicy: skip history on restart; roomsubcache TTL reconciles any boundary staleness. 
+ roomsCfg := stream.Rooms(cfg.SiteID) + invalCons, err := otelJS.CreateOrUpdateConsumer(ctx, roomsCfg.Name, jetstream.ConsumerConfig{ + Durable: "notification-worker-room-event-invalidate", + FilterSubject: subject.RoomCanonicalMemberEvent(cfg.SiteID, model.CanonicalMemberEventMuted), + AckPolicy: jetstream.AckExplicitPolicy, + DeliverPolicy: jetstream.DeliverNewPolicy, + }) + if err != nil { + slog.Error("create canonical member event consumer failed", "error", err) + os.Exit(1) + } + invalIter, err := invalCons.Messages(jetstream.PullMaxMessages(64)) + if err != nil { + slog.Error("canonical member event iterator failed", "error", err) + os.Exit(1) + } + go func() { + for { + _, msg, err := invalIter.Next() + if err != nil { + return + } + var evt model.CanonicalMemberEvent + if err := json.Unmarshal(msg.Data(), &evt); err != nil { + slog.Warn("canonical member event decode failed", "error", err) + _ = msg.Ack() + continue + } + if evt.RoomID != "" { + select { + case invalCh <- evt.RoomID: + default: + slog.Warn("invalidation queue full, dropping (TTL will reconcile)", "roomId", evt.RoomID) + } + } + _ = msg.Ack() + } + }() iter, err := cons.Messages(jetstream.PullMaxMessages(2 * cfg.MaxWorkers)) if err != nil { @@ -146,10 +294,16 @@ func main() { } }() - slog.Info("notification-worker started", "site", cfg.SiteID) + slog.Info("notification-worker started", + "site", cfg.SiteID, + "large_room_threshold", cfg.LargeRoomThreshold, + "push_recipient_batch_size", cfg.PushRecipientBatchSize, + "valkey_addrs", cfg.ValkeyAddrs, + "presence_enabled", cfg.PresenceEnabled, + ) shutdown.Wait(ctx, 25*time.Second, - func(ctx context.Context) error { + func(_ context.Context) error { iter.Stop() return nil }, @@ -163,26 +317,31 @@ func main() { return fmt.Errorf("worker drain timed out: %w", ctx.Err()) } }, + func(_ context.Context) error { + invalIter.Stop() + return nil + }, + func(stepCtx context.Context) error { + close(invalCh) // stop accepting work; worker drains the 
buffer + done := make(chan struct{}) + go func() { invalWG.Wait(); close(done) }() + select { + case <-done: + case <-stepCtx.Done(): + invalCancel() // unblock an in-flight Valkey DEL so the worker exits + <-done + } + invalCancel() // always release the context (idempotent) + return nil + }, func(ctx context.Context) error { return tracerShutdown(ctx) }, - func(ctx context.Context) error { return nc.Drain() }, + func(_ context.Context) error { return nc.Drain() }, func(ctx context.Context) error { mongoutil.Disconnect(ctx, mongoClient); return nil }, + func(_ context.Context) error { valkeyutil.Disconnect(valkeyClient); return nil }, ) } -// natsPublisher adapts *otelnats.Conn to the Publisher interface. -type natsPublisher struct { - nc *otelnats.Conn -} - -func (p *natsPublisher) Publish(ctx context.Context, subject string, data []byte) error { - if err := p.nc.PublishMsg(ctx, natsutil.NewMsg(ctx, subject, data)); err != nil { - return fmt.Errorf("publish to %q: %w", subject, err) - } - return nil -} - -// buildConsumerConfig returns the durable consumer config for -// notification-worker. Centralized so it is unit-testable without NATS. +// buildConsumerConfig returns the durable consumer config for notification-worker. func buildConsumerConfig(s stream.ConsumerSettings) jetstream.ConsumerConfig { cc := stream.DurableConsumerDefaults(s) cc.Durable = "notification-worker" diff --git a/notification-worker/members.go b/notification-worker/members.go new file mode 100644 index 000000000..9cdecdb85 --- /dev/null +++ b/notification-worker/members.go @@ -0,0 +1,71 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "golang.org/x/sync/singleflight" + + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +// memberLoader reads the canonical member list for a room; a function type so tests can inject a fake. 
+type memberLoader func(ctx context.Context, roomID string) ([]roomsubcache.Member, error) + +// cachedMemberLookup resolves members via Valkey → Mongo. Single-flight collapses +// concurrent in-pod misses on the same room to one Valkey GET (and one Mongo +// query on a cold miss). No in-process tier — keeps per-pod memory bounded +// against rooms with thousands of members. +type cachedMemberLookup struct { + cache roomsubcache.Cache + load memberLoader + ttl time.Duration + sf singleflight.Group +} + +func newCachedMemberLookup(cache roomsubcache.Cache, load memberLoader, ttl time.Duration) *cachedMemberLookup { + return &cachedMemberLookup{cache: cache, load: load, ttl: ttl} +} + +// GetMembers returns the member list, populating Valkey on a Mongo round-trip. +// Callers must not mutate the slice. +func (c *cachedMemberLookup) GetMembers(ctx context.Context, roomID string) ([]roomsubcache.Member, error) { + // Fast path: cache hits skip singleflight to avoid serializing concurrent + // readers behind one in-flight caller. + if got, err := c.cache.Get(ctx, roomID); err == nil { + return got, nil + } else if !errors.Is(err, valkeyutil.ErrCacheMiss) { + slog.Warn("roomsubcache get failed, falling back to mongo", "error", err, "roomId", roomID) + } + + // Miss path: singleflight collapses concurrent Mongo loads on the same room. + members, err, _ := c.sf.Do(roomID, func() (any, error) { + // Re-check inside the flight in case a sibling caller already populated. 
+ if got, err := c.cache.Get(ctx, roomID); err == nil { + return got, nil + } + loaded, lerr := c.load(ctx, roomID) + if lerr != nil { + return nil, fmt.Errorf("load members for room %s: %w", roomID, lerr) + } + if setErr := c.cache.Set(ctx, roomID, loaded, c.ttl); setErr != nil { + slog.Warn("roomsubcache set failed", "error", setErr, "roomId", roomID) + } + return loaded, nil + }) + if err != nil { + return nil, fmt.Errorf("get members for room %s: %w", roomID, err) + } + return members.([]roomsubcache.Member), nil +} + +// Invalidate drops the room from Valkey on membership change. +func (c *cachedMemberLookup) Invalidate(ctx context.Context, roomID string) { + if err := c.cache.Invalidate(ctx, roomID); err != nil { + slog.Warn("roomsubcache invalidate failed", "error", err, "roomId", roomID) + } +} diff --git a/notification-worker/members_test.go b/notification-worker/members_test.go new file mode 100644 index 000000000..72f6c44ff --- /dev/null +++ b/notification-worker/members_test.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/roomsubcache" + "github.com/hmchangw/chat/pkg/valkeyutil" +) + +type fakeCache struct { + mu sync.Mutex + data map[string][]roomsubcache.Member +} + +func newFakeCache() *fakeCache { return &fakeCache{data: map[string][]roomsubcache.Member{}} } + +func (f *fakeCache) Get(_ context.Context, roomID string) ([]roomsubcache.Member, error) { + f.mu.Lock() + defer f.mu.Unlock() + v, ok := f.data[roomID] + if !ok { + return nil, valkeyutil.ErrCacheMiss + } + return v, nil +} +func (f *fakeCache) Set(_ context.Context, roomID string, members []roomsubcache.Member, _ time.Duration) error { + f.mu.Lock() + defer f.mu.Unlock() + cp := make([]roomsubcache.Member, len(members)) + copy(cp, members) + f.data[roomID] = cp + return nil +} +func (f *fakeCache) Invalidate(_ 
context.Context, roomID string) error { + f.mu.Lock() + defer f.mu.Unlock() + delete(f.data, roomID) + return nil +} + +type fakeLoader struct { + calls atomic.Int32 + out []roomsubcache.Member + err error + delay time.Duration +} + +func (f *fakeLoader) Load(_ context.Context, _ string) ([]roomsubcache.Member, error) { + f.calls.Add(1) + if f.delay > 0 { + time.Sleep(f.delay) + } + return f.out, f.err +} + +func TestCachedMemberLookup_HitFromValkey(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + require.NoError(t, cache.Set(context.Background(), "r1", loader.out, time.Minute)) + + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, loader.out, got) + assert.Equal(t, int32(0), loader.calls.Load()) +} + +func TestCachedMemberLookup_MissThenPopulate(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) + + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, loader.out, got) + + _, err = lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, int32(1), loader.calls.Load()) +} + +func TestCachedMemberLookup_CacheErrorFallsThrough(t *testing.T) { + cache := &erroringCache{err: errors.New("valkey down")} + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) + + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + assert.Equal(t, loader.out, got) + assert.Equal(t, int32(1), loader.calls.Load()) +} + +func TestCachedMemberLookup_SingleFlightCollapsesMisses(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{ + out: 
[]roomsubcache.Member{{ID: "u1", Account: "alice"}}, + delay: 50 * time.Millisecond, + } + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := lookup.GetMembers(context.Background(), "r1") + assert.NoError(t, err) + }() + } + wg.Wait() + assert.Equal(t, int32(1), loader.calls.Load(), "single-flight collapses concurrent misses") +} + +func TestCachedMemberLookup_InvalidateClearsValkey(t *testing.T) { + cache := newFakeCache() + loader := &fakeLoader{out: []roomsubcache.Member{{ID: "u1", Account: "alice"}}} + lookup := newCachedMemberLookup(cache, loader.Load, time.Minute) + + _, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + lookup.Invalidate(context.Background(), "r1") + loader.out = []roomsubcache.Member{{ID: "u2", Account: "bob"}} + got, err := lookup.GetMembers(context.Background(), "r1") + require.NoError(t, err) + + assert.Equal(t, loader.out, got, "after Invalidate the next read must reload") +} + +type erroringCache struct{ err error } + +func (e *erroringCache) Get(context.Context, string) ([]roomsubcache.Member, error) { + return nil, e.err +} +func (e *erroringCache) Set(context.Context, string, []roomsubcache.Member, time.Duration) error { + return nil +} +func (e *erroringCache) Invalidate(context.Context, string) error { return nil } diff --git a/notification-worker/presence.go b/notification-worker/presence.go new file mode 100644 index 000000000..4f90f19c0 --- /dev/null +++ b/notification-worker/presence.go @@ -0,0 +1,142 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/nats-io/nats.go" + + "github.com/hmchangw/chat/pkg/errcode" + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/subject" +) + +// PresenceSnapshotter batches presence lookups for push-eligible accounts (Stage 4). 
+// Errors are swallowed; an absent account defaults to push. +type PresenceSnapshotter interface { + Snapshot(ctx context.Context, accounts []string) (map[string]model.Presence, error) +} + +// noopPresenceSnapshotter returns an empty map so all push-eligible recipients receive a push. +type noopPresenceSnapshotter struct{} + +func (noopPresenceSnapshotter) Snapshot(context.Context, []string) (map[string]model.Presence, error) { + return map[string]model.Presence{}, nil +} + +// presenceRequester is the narrow NATS surface bulkPresenceSource uses, injectable by tests. +type presenceRequester interface { + Request(ctx context.Context, subj string, data []byte, timeout time.Duration) (*nats.Msg, error) +} + +type bulkPresenceSource struct { + req presenceRequester + siteID string + batchSize int + timeout time.Duration +} + +func newBulkPresenceSource(req presenceRequester, siteID string, batchSize int, timeout time.Duration) *bulkPresenceSource { + if batchSize <= 0 { + batchSize = 512 + } + if timeout <= 0 { + timeout = 2 * time.Second + } + return &bulkPresenceSource{req: req, siteID: siteID, batchSize: batchSize, timeout: timeout} +} + +func (b *bulkPresenceSource) Snapshot(ctx context.Context, accounts []string) (map[string]model.Presence, error) { + if len(accounts) == 0 { + return map[string]model.Presence{}, nil + } + subj := subject.PresenceSnapshot(b.siteID) + chunks := chunkStrings(accounts, b.batchSize) + + var ( + mu sync.Mutex + out = make(map[string]model.Presence, len(accounts)) + wg sync.WaitGroup + ) + for _, ch := range chunks { + ch := ch + wg.Add(1) + go func() { + defer wg.Done() + data, err := json.Marshal(model.PresenceSnapshotRequest{Accounts: ch}) + if err != nil { + slog.Warn("presence marshal failed", "error", err) + return + } + msg, err := b.req.Request(ctx, subj, data, b.timeout) + if err != nil { + slog.Warn("presence rpc failed", "error", err, "chunk", len(ch)) + return + } + if errResp, ok := errcode.Parse(msg.Data); ok { + 
slog.Warn("presence rpc returned error response", + "error", errResp.Message, + "code", errResp.Code, + "chunk", len(ch)) + return + } + var reply model.PresenceSnapshotReply + if err := json.Unmarshal(msg.Data, &reply); err != nil { + slog.Warn("presence unmarshal failed", "error", err) + return + } + mu.Lock() + for k, v := range reply.Presences { + out[k] = v + } + mu.Unlock() + }() + } + wg.Wait() + return out, nil +} + +func chunkStrings(in []string, size int) [][]string { + if size <= 0 || len(in) <= size { + return [][]string{in} + } + out := make([][]string, 0, (len(in)+size-1)/size) + for i := 0; i < len(in); i += size { + end := i + size + if end > len(in) { + end = len(in) + } + out = append(out, in[i:end]) + } + return out +} + +// shouldPush returns true unless the account is explicitly DND; fail-open on missing/unknown status. +func shouldPush(p model.Presence) bool { + switch p.AggregatedStatus { + case "busy", "in-call": + return false + default: + return true + } +} + +type natsPresenceRequester struct { + nc *nats.Conn +} + +var _ presenceRequester = (*natsPresenceRequester)(nil) + +func (n *natsPresenceRequester) Request(ctx context.Context, subj string, data []byte, timeout time.Duration) (*nats.Msg, error) { + rctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + msg, err := n.nc.RequestWithContext(rctx, subj, data) + if err != nil { + return nil, fmt.Errorf("presence request: %w", err) + } + return msg, nil +} diff --git a/notification-worker/presence_test.go b/notification-worker/presence_test.go new file mode 100644 index 000000000..cba883f2a --- /dev/null +++ b/notification-worker/presence_test.go @@ -0,0 +1,135 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/errcode" + "github.com/hmchangw/chat/pkg/errcode/errnats" + 
"github.com/hmchangw/chat/pkg/model" +) + +func TestNoopPresence_EmptySnapshot(t *testing.T) { + p := noopPresenceSnapshotter{} + snap, err := p.Snapshot(context.Background(), []string{"alice", "bob"}) + require.NoError(t, err) + assert.Empty(t, snap) +} + +func TestShouldPush(t *testing.T) { + tests := []struct { + status string + want bool + }{ + {"online", true}, + {"offline", true}, + {"away", true}, + {"busy", false}, + {"in-call", false}, + {"", true}, // missing → fail-open + {"unknown", true}, // unknown → fail-open + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + assert.Equal(t, tt.want, shouldPush(model.Presence{AggregatedStatus: tt.status})) + }) + } +} + +type stubRequester struct { + mu sync.Mutex + calls int + gotReqs []model.PresenceSnapshotRequest + reply func(req model.PresenceSnapshotRequest) (model.PresenceSnapshotReply, error) + rawReply func(req model.PresenceSnapshotRequest) ([]byte, error) // when set, bypasses reply and returns raw bytes +} + +func (s *stubRequester) Request(_ context.Context, _ string, data []byte, _ time.Duration) (*nats.Msg, error) { + s.mu.Lock() + s.calls++ + s.mu.Unlock() + var req model.PresenceSnapshotRequest + if err := json.Unmarshal(data, &req); err != nil { + return nil, err + } + s.mu.Lock() + s.gotReqs = append(s.gotReqs, req) + s.mu.Unlock() + if s.rawReply != nil { + out, err := s.rawReply(req) + if err != nil { + return nil, err + } + return &nats.Msg{Data: out}, nil + } + reply, err := s.reply(req) + if err != nil { + return nil, err + } + out, err := json.Marshal(reply) + if err != nil { + return nil, err + } + return &nats.Msg{Data: out}, nil +} + +func TestBulkPresence_Chunks(t *testing.T) { + accounts := make([]string, 1500) + for i := range accounts { + accounts[i] = "u" + } + for i := range accounts { + accounts[i] = string(rune('a'+i%26)) + "-" + string(rune('a'+i/26%26)) + } + stub := &stubRequester{reply: func(req model.PresenceSnapshotRequest) 
(model.PresenceSnapshotReply, error) { + out := model.PresenceSnapshotReply{Presences: map[string]model.Presence{}} + for _, a := range req.Accounts { + out.Presences[a] = model.Presence{AggregatedStatus: "online"} + } + return out, nil + }} + + src := newBulkPresenceSource(stub, "site-a", 500, time.Second) + got, err := src.Snapshot(context.Background(), accounts) + require.NoError(t, err) + assert.Equal(t, 3, stub.calls, "expect ceil(1500/500) chunks") + assert.Len(t, got, len(uniqueStrings(accounts))) +} + +func TestBulkPresence_FailOpenOnError(t *testing.T) { + stub := &stubRequester{reply: func(model.PresenceSnapshotRequest) (model.PresenceSnapshotReply, error) { + return model.PresenceSnapshotReply{}, errors.New("nats: timeout") + }} + src := newBulkPresenceSource(stub, "site-a", 100, 50*time.Millisecond) + got, err := src.Snapshot(context.Background(), []string{"a", "b"}) + require.NoError(t, err) + assert.Empty(t, got) +} + +func TestBulkPresence_ErrorResponseLoggedAndFailOpen(t *testing.T) { + stub := &stubRequester{ + rawReply: func(_ model.PresenceSnapshotRequest) ([]byte, error) { + return errnats.MarshalQuiet(errcode.Internal("presence backend down")), nil + }, + } + src := newBulkPresenceSource(stub, "site-a", 100, 50*time.Millisecond) + got, err := src.Snapshot(context.Background(), []string{"alice", "bob"}) + require.NoError(t, err) // fail-open: error envelope is swallowed + assert.Empty(t, got) +} + +func uniqueStrings(in []string) map[string]struct{} { + out := map[string]struct{}{} + for _, s := range in { + out[s] = struct{}{} + } + return out +} diff --git a/notification-worker/routing.go b/notification-worker/routing.go new file mode 100644 index 000000000..8a6bce259 --- /dev/null +++ b/notification-worker/routing.go @@ -0,0 +1,25 @@ +package main + +import ( + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +// EligibleForPush is Stage 3 of the fan-out pipeline (pure CPU, no I/O). 
+// Bots are always excluded; DMs and @mentions bypass the large-room throttle. +func EligibleForPush(m *roomsubcache.Member, roomType model.RoomType, isLargeRoom, mentioned bool) bool { + if m.IsBot { + return false + } + if isDirect(roomType) { + return true + } + if mentioned { + return true + } + return !isLargeRoom +} + +func isDirect(t model.RoomType) bool { + return t == model.RoomTypeDM || t == model.RoomTypeBotDM +} diff --git a/notification-worker/routing_test.go b/notification-worker/routing_test.go new file mode 100644 index 000000000..bb5591e03 --- /dev/null +++ b/notification-worker/routing_test.go @@ -0,0 +1,37 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/roomsubcache" +) + +func TestEligibleForPush(t *testing.T) { + tests := []struct { + name string + member roomsubcache.Member + roomType model.RoomType + isLarge bool + mentioned bool + want bool + }{ + {name: "dm always", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeDM, want: true}, + {name: "botdm always", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeBotDM, want: true}, + {name: "small channel non-mention", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: false, mentioned: false, want: true}, + {name: "small channel mention", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: false, mentioned: true, want: true}, + {name: "large channel non-mention dropped", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: true, mentioned: false, want: false}, + {name: "large channel mention pushed", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeChannel, isLarge: true, mentioned: true, want: true}, + {name: "bot never", member: roomsubcache.Member{Account: "bot", IsBot: true}, roomType: model.RoomTypeDM, want: false}, + {name: "bot in 
mention dropped", member: roomsubcache.Member{Account: "bot", IsBot: true}, roomType: model.RoomTypeChannel, mentioned: true, want: false}, + {name: "discussion small non-mention", member: roomsubcache.Member{Account: "a"}, roomType: model.RoomTypeDiscussion, want: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := EligibleForPush(&tt.member, tt.roomType, tt.isLarge, tt.mentioned) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/notification-worker/threads.go b/notification-worker/threads.go new file mode 100644 index 000000000..62166b51f --- /dev/null +++ b/notification-worker/threads.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// ThreadFollowerLister returns the set of accounts following the thread rooted at parentMessageID. +// Backed by thread_rooms.replyAccounts (every replier + parent author seeded at creation) — matches +// the legacy notification rule. 
+type ThreadFollowerLister interface { + Followers(ctx context.Context, parentMessageID string) (map[string]struct{}, error) +} + +type mongoThreadFollowers struct { + col *mongo.Collection +} + +func newMongoThreadFollowers(col *mongo.Collection) *mongoThreadFollowers { + return &mongoThreadFollowers{col: col} +} + +func (m *mongoThreadFollowers) Followers(ctx context.Context, parentMessageID string) (map[string]struct{}, error) { + if parentMessageID == "" { + return map[string]struct{}{}, nil + } + var doc struct { + ReplyAccounts []string `bson:"replyAccounts"` + } + opts := options.FindOne().SetProjection(bson.M{"replyAccounts": 1, "_id": 0}) + err := m.col.FindOne(ctx, bson.M{"parentMessageId": parentMessageID}, opts).Decode(&doc) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return map[string]struct{}{}, nil + } + return nil, fmt.Errorf("find thread room by parent %s: %w", parentMessageID, err) + } + out := make(map[string]struct{}, len(doc.ReplyAccounts)) + for _, a := range doc.ReplyAccounts { + if a != "" { + out[a] = struct{}{} + } + } + return out, nil +} diff --git a/notification-worker/threads_test.go b/notification-worker/threads_test.go new file mode 100644 index 000000000..68fe11096 --- /dev/null +++ b/notification-worker/threads_test.go @@ -0,0 +1,41 @@ +package main + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type stubThreadLookup struct { + out []string + err error +} + +func (s *stubThreadLookup) Followers(_ context.Context, _ string) (map[string]struct{}, error) { + if s.err != nil { + return nil, s.err + } + set := make(map[string]struct{}, len(s.out)) + for _, a := range s.out { + set[a] = struct{}{} + } + return set, nil +} + +func TestThreadFollowers_Resolve(t *testing.T) { + s := &stubThreadLookup{out: []string{"alice", "bob"}} + got, err := s.Followers(context.Background(), "parent-1") + require.NoError(t, err) + assert.Contains(t, 
got, "alice") + assert.Contains(t, got, "bob") + assert.NotContains(t, got, "carol") +} + +func TestThreadFollowers_PropagatesError(t *testing.T) { + s := &stubThreadLookup{err: errors.New("mongo down")} + _, err := s.Followers(context.Background(), "parent-1") + assert.Error(t, err) +} diff --git a/pkg/mention/mention.go b/pkg/mention/mention.go index 3370e7fd1..db34143c0 100644 --- a/pkg/mention/mention.go +++ b/pkg/mention/mention.go @@ -20,8 +20,7 @@ type ParseResult struct { MentionAll bool // true if @all was mentioned (case-insensitive) } -// Parse extracts @mention tokens from content and returns the unique -// mentioned accounts along with whether @all was present. +// Parse extracts @mention tokens from content, returning unique accounts and whether @all appears. func Parse(content string) ParseResult { matches := mentionRe.FindAllStringSubmatch(content, -1) if len(matches) == 0 { @@ -52,7 +51,7 @@ type LookupFunc func(ctx context.Context, accounts []string) ([]model.User, erro // ResolveResult holds mention resolution output. 
type ResolveResult struct { Participants []model.Participant // enriched mentioned users + @all entry if present - MentionAll bool // true if @all or @here was mentioned + MentionAll bool // true if @all was mentioned (case-insensitive) Accounts []string // raw parsed accounts (for caller use outside resolution) } @@ -62,15 +61,20 @@ type ResolveResult struct { func Resolve(ctx context.Context, content string, lookupFn LookupFunc) (*ResolveResult, error) { parsed := Parse(content) if len(parsed.Accounts) == 0 && !parsed.MentionAll { - return &ResolveResult{Accounts: parsed.Accounts, MentionAll: parsed.MentionAll}, nil + return &ResolveResult{ + Accounts: parsed.Accounts, + MentionAll: parsed.MentionAll, + }, nil } users := map[string]model.User{} if len(parsed.Accounts) > 0 { fetched, err := lookupFn(ctx, parsed.Accounts) if err != nil { - return &ResolveResult{Accounts: parsed.Accounts, MentionAll: parsed.MentionAll}, - fmt.Errorf("find mentioned users: %w", err) + return &ResolveResult{ + Accounts: parsed.Accounts, + MentionAll: parsed.MentionAll, + }, fmt.Errorf("find mentioned users: %w", err) } users = make(map[string]model.User, len(fetched)) for i := range fetched { @@ -80,12 +84,8 @@ func Resolve(ctx context.Context, content string, lookupFn LookupFunc) (*Resolve return ResolveFromParsed(parsed, users), nil } -// ResolveFromParsed builds the ResolveResult from a pre-parsed input and a -// caller-supplied account→user map. Use this when the caller has already -// done the user lookup (e.g. broadcast-worker fetches sender + mentions in a -// single round-trip, then resolves without re-parsing or re-querying). -// Missing accounts in users are silently omitted from Participants — same -// semantics as Resolve, which omits unknown accounts returned by lookupFn. +// ResolveFromParsed builds a ResolveResult from pre-parsed input and a caller-supplied user map. +// Use when the caller has already done the lookup. Unknown accounts are silently omitted. 
func ResolveFromParsed(parsed ParseResult, users map[string]model.User) *ResolveResult { result := &ResolveResult{ MentionAll: parsed.MentionAll, diff --git a/pkg/mention/mention_test.go b/pkg/mention/mention_test.go index 40765da49..92c90fb95 100644 --- a/pkg/mention/mention_test.go +++ b/pkg/mention/mention_test.go @@ -44,10 +44,6 @@ func TestParse(t *testing.T) { // Email-style suffix no longer captured: only the leading @user matches. {name: "email-style suffix dropped", content: "ping @user@domain.com", accounts: []string{"user"}, mentionAll: false}, - - // @here is no longer a mentionAll alias — it parses as a regular account. - {name: "@here lowercase no longer mentionAll", content: "look @here please", accounts: []string{"here"}, mentionAll: false}, - {name: "@HERE uppercase no longer mentionAll", content: "look @HERE please", accounts: []string{"here"}, mentionAll: false}, } for _, tt := range tests { @@ -125,7 +121,7 @@ func TestResolve(t *testing.T) { }, }, { - name: "lookup error — partial result", + name: "lookup error — partial result returned", content: "hey @bob", lookupErr: errors.New("db error"), wantAccounts: []string{"bob"}, diff --git a/pkg/model/event.go b/pkg/model/event.go index aa2fa51d3..914438130 100644 --- a/pkg/model/event.go +++ b/pkg/model/event.go @@ -38,6 +38,18 @@ type SubscriptionUpdateEvent struct { Timestamp int64 `json:"timestamp" bson:"timestamp"` } +// CanonicalMemberEventMuted is the only event type currently published on this stream. +const CanonicalMemberEventMuted = "muted" + +// CanonicalMemberEvent is the room-scoped post-mutation event for roomsubcache invalidation (mute-only today). +type CanonicalMemberEvent struct { + Type string `json:"type"` + RoomID string `json:"roomId"` + Account string `json:"account"` + Muted bool `json:"muted"` // post-toggle state; false is a valid (unmuted) value, so no omitempty. 
+ Timestamp int64 `json:"timestamp"` +} + type UpdateRoleRequest struct { RoomID string `json:"roomId" bson:"roomId"` Account string `json:"account" bson:"account"` @@ -69,13 +81,6 @@ type InboxMemberEvent struct { Timestamp int64 `json:"timestamp" bson:"timestamp"` } -type NotificationEvent struct { - Type string `json:"type"` // "new_message" - RoomID string `json:"roomId"` - Message Message `json:"message"` - Timestamp int64 `json:"timestamp" bson:"timestamp"` -} - // OutboxEventType is the type tag on an OutboxEvent used to route it to the // correct handler on the destination site. type OutboxEventType = string @@ -140,12 +145,17 @@ type MemberAddEvent struct { } // Participant represents a user with display name info for client rendering. +// DisplayName is the render-ready composed name (see pkg/displayfmt.CombineWithFallback) +// and is the field push-service uses to render notifications; it is populated only +// where pre-composition is meaningful (push event senders), left empty in +// fan-out shapes that carry raw EngName/ChineseName (mentions, ClientMessage). type Participant struct { - UserID string `json:"userId,omitempty" bson:"userId,omitempty"` - Account string `json:"account" bson:"account"` - SiteID string `json:"siteId,omitempty" bson:"siteId,omitempty"` - ChineseName string `json:"chineseName" bson:"chineseName"` - EngName string `json:"engName" bson:"engName"` + UserID string `json:"userId,omitempty" bson:"userId,omitempty"` + Account string `json:"account" bson:"account"` + SiteID string `json:"siteId,omitempty" bson:"siteId,omitempty"` + ChineseName string `json:"chineseName" bson:"chineseName"` + EngName string `json:"engName" bson:"engName"` + DisplayName string `json:"displayName,omitempty" bson:"displayName,omitempty"` } // ClientMessage wraps Message with enriched sender info for client consumption. 
diff --git a/pkg/model/message.go b/pkg/model/message.go index 19e6bf77d..58c8bd49f 100644 --- a/pkg/model/message.go +++ b/pkg/model/message.go @@ -7,10 +7,18 @@ import ( ) type Message struct { - ID string `json:"id" bson:"_id"` - RoomID string `json:"roomId" bson:"roomId"` - UserID string `json:"userId" bson:"userId"` - UserAccount string `json:"userAccount" bson:"userAccount"` + ID string `json:"id" bson:"_id"` + RoomID string `json:"roomId" bson:"roomId"` + UserID string `json:"userId" bson:"userId"` + UserAccount string `json:"userAccount" bson:"userAccount"` + // UserDisplayName is the render-ready sender name, composed once at canonical-message + // write time by message-gatekeeper via pkg/displayfmt.CombineWithFallback(engName, + // chineseName, account) — the same helper used by room-worker/sysmsg.go and + // pkg/model/cassandra/reactions.go so display formatting stays uniform system-wide. + // Downstream consumers (notification-worker, future search-sync-worker) read this + // verbatim; omitempty keeps pre-rollout canonical messages decoding cleanly (consumers + // fall back to UserAccount when the field is empty). + UserDisplayName string `json:"userDisplayName,omitempty" bson:"userDisplayName,omitempty"` Content string `json:"content" bson:"content"` Attachments [][]byte `json:"attachments,omitempty" bson:"attachments,omitempty"` Card *cassandra.Card `json:"card,omitempty" bson:"card,omitempty"` @@ -55,3 +63,14 @@ type SendMessageRequest struct { ThreadParentMessageCreatedAt *int64 `json:"threadParentMessageCreatedAt,omitempty"` QuotedParentMessageID string `json:"quotedParentMessageId,omitempty"` } + +// SenderDisplayName returns the canonical render-ready name for the message's +// sender: UserDisplayName when populated (the message-gatekeeper-composed value +// described on the field), UserAccount otherwise. The fallback handles legacy +// in-flight canonical messages that predate UserDisplayName. 
+func (m *Message) SenderDisplayName() string { + if m.UserDisplayName != "" { + return m.UserDisplayName + } + return m.UserAccount +} diff --git a/pkg/model/model_test.go b/pkg/model/model_test.go index 2c42800ed..6f8fa4da1 100644 --- a/pkg/model/model_test.go +++ b/pkg/model/model_test.go @@ -940,23 +940,6 @@ func TestRoomKeyGetResponseJSON(t *testing.T) { roundTrip(t, &src, &dst) } -func TestNotificationEventJSON(t *testing.T) { - src := model.NotificationEvent{ - Type: "new_message", - RoomID: "room-1", - Message: model.Message{ - ID: "m1", RoomID: "room-1", UserID: "u1", UserAccount: "alice", - Content: "hello", CreatedAt: time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), - }, - Timestamp: 1735689600000, - } - data, err := json.Marshal(&src) - require.NoError(t, err) - var dst model.NotificationEvent - require.NoError(t, json.Unmarshal(data, &dst)) - assert.Equal(t, src, dst) -} - func TestUpdateRoleRequestJSON(t *testing.T) { src := model.UpdateRoleRequest{RoomID: "r1", Account: "bob", NewRole: model.RoleOwner, Timestamp: 1735689600000} roundTrip(t, &src, &model.UpdateRoleRequest{}) @@ -2749,3 +2732,64 @@ func TestMessageAndOutboxAndAsyncOpConstants(t *testing.T) { assert.Equal(t, model.RoomEventType("room_renamed"), model.RoomEventRoomRenamed) assert.Equal(t, model.RoomEventType("room_restricted"), model.RoomEventRoomRestricted) } + +func TestPushNotificationEvent_RoundTrip(t *testing.T) { + in := model.PushNotificationEvent{ + ID: "m1-b0", + Accounts: []string{"alice", "bob"}, + Title: "general", + Body: "hello", + RoomID: "r1", + Data: model.PushNotificationData{ + RoomID: "r1", + MessageID: "m1", + Type: "c", + Sender: &model.Participant{Account: "bob", ChineseName: "張三", EngName: "Bob"}, + PushTime: "2026-05-27T00:00:00Z", + }, + Timestamp: 1700000000000, + } + data, err := json.Marshal(in) + require.NoError(t, err) + var out model.PushNotificationEvent + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} + +func 
TestMessage_SenderDisplayName(t *testing.T) { + tests := []struct { + name string + msg model.Message + want string + }{ + { + name: "uses UserDisplayName when populated", + msg: model.Message{UserAccount: "alice", UserDisplayName: "Alice Wang 愛麗絲"}, + want: "Alice Wang 愛麗絲", + }, + { + name: "falls back to UserAccount on legacy in-flight message", + msg: model.Message{UserAccount: "alice", UserDisplayName: ""}, + want: "alice", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tc.msg.SenderDisplayName()) + }) + } +} + +func TestPresenceSnapshotReply_RoundTrip(t *testing.T) { + in := model.PresenceSnapshotReply{ + Presences: map[string]model.Presence{ + "alice": {AggregatedStatus: "online"}, + "bob": {AggregatedStatus: "busy"}, + }, + } + data, err := json.Marshal(in) + require.NoError(t, err) + var out model.PresenceSnapshotReply + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} diff --git a/pkg/model/presence.go b/pkg/model/presence.go new file mode 100644 index 000000000..9b24b4b9c --- /dev/null +++ b/pkg/model/presence.go @@ -0,0 +1,16 @@ +package model + +// PresenceSnapshotRequest is the bulk presence RPC request payload. +type PresenceSnapshotRequest struct { + Accounts []string `json:"accounts" bson:"accounts"` +} + +// PresenceSnapshotReply is the bulk presence RPC reply; absent accounts are treated fail-open. +type PresenceSnapshotReply struct { + Presences map[string]Presence `json:"presences" bson:"presences"` +} + +// Presence is an account's aggregated status. Only "busy" and "in-call" suppress push. +type Presence struct { + AggregatedStatus string `json:"aggregatedStatus" bson:"aggregatedStatus"` +} diff --git a/pkg/model/push.go b/pkg/model/push.go new file mode 100644 index 000000000..7069dbfda --- /dev/null +++ b/pkg/model/push.go @@ -0,0 +1,29 @@ +package model + +// PushNotificationEvent is the batched push envelope published to PUSH_NOTIFICATIONS_{siteID}. 
+// Accounts carries up to PUSH_RECIPIENT_BATCH_SIZE recipients sharing one canonical message; +// ID is "{messageID}-b{batchIndex}" and doubles as the Nats-Msg-Id dedup key. +type PushNotificationEvent struct { + ID string `json:"id" bson:"id"` + Accounts []string `json:"accounts" bson:"accounts"` + Title string `json:"title" bson:"title"` + Body string `json:"body" bson:"body"` + Data PushNotificationData `json:"data" bson:"data"` + RoomID string `json:"roomId" bson:"roomId"` + Timestamp int64 `json:"timestamp" bson:"timestamp"` +} + +// PushNotificationData is the push payload; short legacy tag names (rid/tmid/prid) are spelled +// out to camelCase, and chineseName/engName are collapsed into *Participant Sender. +type PushNotificationData struct { + RoomID string `json:"roomId" bson:"roomId"` + MessageID string `json:"messageId" bson:"messageId"` + Type string `json:"type" bson:"type"` + Sender *Participant `json:"sender,omitempty" bson:"sender,omitempty"` + ThreadMessageID string `json:"threadMessageId,omitempty" bson:"threadMessageId,omitempty"` + FileName string `json:"fileName,omitempty" bson:"fileName,omitempty"` + FileType string `json:"fileType,omitempty" bson:"fileType,omitempty"` + ParentRoomID string `json:"parentRoomId,omitempty" bson:"parentRoomId,omitempty"` + PushTime string `json:"pushTime" bson:"pushTime"` + AlsoSendToChannel bool `json:"alsoSendToChannel,omitempty" bson:"alsoSendToChannel,omitempty"` +} diff --git a/pkg/natsutil/gzip.go b/pkg/natsutil/gzip.go new file mode 100644 index 000000000..480fcd8d8 --- /dev/null +++ b/pkg/natsutil/gzip.go @@ -0,0 +1,126 @@ +package natsutil + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "sync" + + "github.com/nats-io/nats.go" +) + +// HeaderContentEncoding and HeaderContentType mirror the HTTP header names so operators +// inspecting NATS payloads can apply familiar conventions. 
+const ( + HeaderContentEncoding = "Content-Encoding" + HeaderContentType = "Content-Type" + + ContentEncodingGzip = "gzip" + + // MaxDecodedPayloadSize is the default decompressed-size cap used by + // DecodePayload. Aligned with the typical operator-tuned NATS max_payload + // (256 KiB) — tighter than the upstream 1 MiB default. Realistic push events + // decompress to ≤ ~25 KB given the 20 KiB body cap enforced by + // message-gatekeeper, so 256 KiB leaves ~10× headroom for legitimate growth + // while keeping the gzip-bomb amplification ceiling reasonable. + // + // Callers who want a different cap (e.g. a service whose operator pinned + // max_payload to 1 MiB, or one that needs a tighter cap on small events) + // use DecodePayloadWithLimit(msg, maxBytes). + MaxDecodedPayloadSize = 256 << 10 // 256 KiB +) + +// gzipWriterPool amortises gzip.Writer allocations across publishers; the writer holds +// a ~64 KB internal buffer that would otherwise churn the GC under sustained publish load. +var gzipWriterPool = sync.Pool{ + New: func() any { return gzip.NewWriter(nil) }, +} + +// GzipPayload returns a gzip-compressed copy of payload. Allocates a fresh slice +// so the caller may reuse the input buffer without aliasing the output. +func GzipPayload(payload []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(len(payload) / 2) + gz, _ := gzipWriterPool.Get().(*gzip.Writer) + gz.Reset(&buf) + // Reset to io.Discard before returning to the pool so the writer does not + // retain a reference to the buffer (which may be large on big payloads). 
+ defer func() { + gz.Reset(io.Discard) + gzipWriterPool.Put(gz) + }() + if _, err := gz.Write(payload); err != nil { + return nil, fmt.Errorf("gzip write: %w", err) + } + if err := gz.Close(); err != nil { + return nil, fmt.Errorf("gzip close: %w", err) + } + return buf.Bytes(), nil +} + +// NewGzipMsg builds a *nats.Msg with payload gzipped and Content-Encoding/Content-Type +// headers set so a consumer using DecodePayload can transparently decompress. +// contentType may be empty; the helper sets "application/json" by default for payload-encoded events. +func NewGzipMsg(subject string, payload []byte, contentType string) (*nats.Msg, error) { + encoded, err := GzipPayload(payload) + if err != nil { + return nil, err + } + if contentType == "" { + contentType = "application/json" + } + msg := &nats.Msg{ + Subject: subject, + Header: nats.Header{}, + Data: encoded, + } + msg.Header.Set(HeaderContentEncoding, ContentEncodingGzip) + msg.Header.Set(HeaderContentType, contentType) + return msg, nil +} + +// DecodePayload decodes using the default MaxDecodedPayloadSize cap. For a +// configurable cap (e.g. wired from a service's env var) use DecodePayloadWithLimit. +func DecodePayload(msg *nats.Msg) ([]byte, error) { + return DecodePayloadWithLimit(msg, MaxDecodedPayloadSize) +} + +// DecodePayloadWithLimit returns msg.Data verbatim when uncompressed, or the +// gunzipped bytes when Content-Encoding is "gzip". maxBytes caps the post-gunzip +// size so a gzip bomb can't blow up the consumer; the wire-side NATS max_payload +// is independent (typically 256 KiB - 1 MiB depending on operator config) and +// must be configured at the server level. Unknown encodings produce an error so +// consumers fail loudly rather than silently mis-parsing. A maxBytes of zero or +// negative falls back to MaxDecodedPayloadSize. 
+func DecodePayloadWithLimit(msg *nats.Msg, maxBytes int) ([]byte, error) { + if maxBytes <= 0 { + maxBytes = MaxDecodedPayloadSize + } + enc := "" + if msg.Header != nil { + enc = msg.Header.Get(HeaderContentEncoding) + } + switch enc { + case "", "identity": + return msg.Data, nil + case ContentEncodingGzip: + r, err := gzip.NewReader(bytes.NewReader(msg.Data)) + if err != nil { + return nil, fmt.Errorf("gzip reader: %w", err) + } + defer r.Close() + // Read up to maxBytes+1 so we can detect overflow without allocating + // beyond the cap. Bounds gzip-bomb amplification (~1000× on pathological inputs). + out, err := io.ReadAll(io.LimitReader(r, int64(maxBytes)+1)) + if err != nil { + return nil, fmt.Errorf("gzip read: %w", err) + } + if len(out) > maxBytes { + return nil, fmt.Errorf("gzip payload exceeds %d bytes", maxBytes) + } + return out, nil + default: + return nil, fmt.Errorf("unsupported Content-Encoding %q", enc) + } +} diff --git a/pkg/natsutil/gzip_test.go b/pkg/natsutil/gzip_test.go new file mode 100644 index 000000000..33ecddd4f --- /dev/null +++ b/pkg/natsutil/gzip_test.go @@ -0,0 +1,138 @@ +package natsutil + +import ( + "bytes" + "compress/gzip" + "strings" + "testing" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGzipPayload_RoundTrip(t *testing.T) { + in := []byte(strings.Repeat(`{"k":"v"}`, 100)) + encoded, err := GzipPayload(in) + require.NoError(t, err) + assert.NotEqual(t, in, encoded, "encoded payload differs from input") + assert.Less(t, len(encoded), len(in), "highly repetitive JSON should shrink under gzip") + + r, err := gzip.NewReader(bytes.NewReader(encoded)) + require.NoError(t, err) + defer r.Close() + var buf bytes.Buffer + _, err = buf.ReadFrom(r) + require.NoError(t, err) + assert.Equal(t, in, buf.Bytes()) +} + +func TestGzipPayload_EmptyInput(t *testing.T) { + encoded, err := GzipPayload(nil) + require.NoError(t, err) + // gzip framing means even 
empty input produces a non-empty (header+trailer) output. + assert.NotEmpty(t, encoded) + r, err := gzip.NewReader(bytes.NewReader(encoded)) + require.NoError(t, err) + defer r.Close() + var buf bytes.Buffer + _, err = buf.ReadFrom(r) + require.NoError(t, err) + assert.Empty(t, buf.Bytes()) +} + +func TestNewGzipMsg_SetsHeadersAndCompresses(t *testing.T) { + in := []byte(`{"hello":"world"}`) + msg, err := NewGzipMsg("foo.bar", in, "") + require.NoError(t, err) + + assert.Equal(t, "foo.bar", msg.Subject) + assert.Equal(t, ContentEncodingGzip, msg.Header.Get(HeaderContentEncoding)) + assert.Equal(t, "application/json", msg.Header.Get(HeaderContentType), "default content type") + + decoded, err := DecodePayload(msg) + require.NoError(t, err) + assert.Equal(t, in, decoded) +} + +func TestNewGzipMsg_CustomContentType(t *testing.T) { + msg, err := NewGzipMsg("foo.bar", []byte("hi"), "text/plain") + require.NoError(t, err) + assert.Equal(t, "text/plain", msg.Header.Get(HeaderContentType)) +} + +func TestDecodePayload_NoEncoding_PassesThrough(t *testing.T) { + msg := &nats.Msg{Data: []byte("raw bytes"), Header: nats.Header{}} + out, err := DecodePayload(msg) + require.NoError(t, err) + assert.Equal(t, []byte("raw bytes"), out) +} + +func TestDecodePayload_IdentityEncoding_PassesThrough(t *testing.T) { + msg := &nats.Msg{Data: []byte("raw"), Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, "identity") + out, err := DecodePayload(msg) + require.NoError(t, err) + assert.Equal(t, []byte("raw"), out) +} + +func TestDecodePayload_NilHeader(t *testing.T) { + msg := &nats.Msg{Data: []byte("raw"), Header: nil} + out, err := DecodePayload(msg) + require.NoError(t, err) + assert.Equal(t, []byte("raw"), out) +} + +func TestDecodePayload_UnsupportedEncoding(t *testing.T) { + msg := &nats.Msg{Data: []byte("x"), Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, "br") + _, err := DecodePayload(msg) + require.Error(t, err) + assert.Contains(t, err.Error(), 
"unsupported Content-Encoding") +} + +func TestDecodePayload_GzipCorrupt(t *testing.T) { + msg := &nats.Msg{Data: []byte("not gzip"), Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, ContentEncodingGzip) + _, err := DecodePayload(msg) + require.Error(t, err) +} + +func TestDecodePayload_GzipExceedsMaxDecodedSize(t *testing.T) { + // Build a payload that decompresses to MaxDecodedPayloadSize+1 bytes. gzip on a + // constant byte runs at ~1000× compression so the wire bytes stay tiny — exactly + // the gzip-bomb shape we want the cap to reject. + oversized := bytes.Repeat([]byte{'a'}, MaxDecodedPayloadSize+1) + encoded, err := GzipPayload(oversized) + require.NoError(t, err) + assert.Less(t, len(encoded), MaxDecodedPayloadSize/100, + "sanity: highly repetitive input should compress small enough to fit in a single NATS msg") + + msg := &nats.Msg{Data: encoded, Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, ContentEncodingGzip) + _, err = DecodePayload(msg) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds") +} + +func TestDecodePayload_GzipAtMaxDecodedSize(t *testing.T) { + // Boundary check: a payload that exactly hits the cap must still succeed. 
+ atLimit := bytes.Repeat([]byte{'b'}, MaxDecodedPayloadSize) + encoded, err := GzipPayload(atLimit) + require.NoError(t, err) + + msg := &nats.Msg{Data: encoded, Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, ContentEncodingGzip) + out, err := DecodePayload(msg) + require.NoError(t, err) + assert.Len(t, out, MaxDecodedPayloadSize) +} + +func TestDecodePayload_GzipTruncated(t *testing.T) { + full, err := GzipPayload([]byte("hello world")) + require.NoError(t, err) + msg := &nats.Msg{Data: full[:len(full)-3], Header: nats.Header{}} + msg.Header.Set(HeaderContentEncoding, ContentEncodingGzip) + _, err = DecodePayload(msg) + require.Error(t, err) +} diff --git a/pkg/roomsubcache/roomsubcache.go b/pkg/roomsubcache/roomsubcache.go index c1808875a..d01fb51c9 100644 --- a/pkg/roomsubcache/roomsubcache.go +++ b/pkg/roomsubcache/roomsubcache.go @@ -2,12 +2,9 @@ // fan-out workers (e.g. notification-worker) can avoid a Mongo round-trip // for every published message. // -// The cache stores only the fields a fan-out path actually needs — -// {ID, Account} per member — not the full model.Subscription document. -// Entries are written with a caller-supplied TTL and are not actively -// invalidated; staleness is bounded by the TTL. An Invalidate method is -// provided so a future room-membership event listener can evict eagerly -// without changing this package. +// The cache stores the fan-out path's per-member input set — see Member. +// Entries are written with a caller-supplied TTL and may be eagerly +// invalidated via Invalidate; staleness is otherwise bounded by the TTL. package roomsubcache import ( @@ -17,6 +14,7 @@ import ( "fmt" "time" + "github.com/hmchangw/chat/pkg/model" "github.com/hmchangw/chat/pkg/valkeyutil" ) @@ -26,11 +24,15 @@ import ( // the reader. Configurable per-instance via WithMaxValueBytes. 
const DefaultMaxValueBytes = 16 * 1024 * 1024 -// Member is the projection of model.Subscription that fan-out callers need: -// the user's stable ID (for sender-skip checks) and account (for routing). +// Member is the model.Subscription projection needed by the fan-out path. +// Extra fields use omitempty so a plain member's JSON stays {id, account}. type Member struct { - ID string `json:"id"` - Account string `json:"account"` + ID string `json:"id"` + Account string `json:"account"` + RoomType model.RoomType `json:"roomType,omitempty"` + IsBot bool `json:"isBot,omitempty"` + Muted bool `json:"muted,omitempty"` + HistorySharedSince *int64 `json:"historySharedSince,omitempty"` } // Cache stores and retrieves a room's member list. diff --git a/pkg/roomsubcache/roomsubcache_test.go b/pkg/roomsubcache/roomsubcache_test.go index 821cd6676..47facbe68 100644 --- a/pkg/roomsubcache/roomsubcache_test.go +++ b/pkg/roomsubcache/roomsubcache_test.go @@ -2,6 +2,7 @@ package roomsubcache_test import ( "context" + "encoding/json" "errors" "strings" "testing" @@ -10,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/hmchangw/chat/pkg/model" "github.com/hmchangw/chat/pkg/roomsubcache" "github.com/hmchangw/chat/pkg/valkeyutil" ) @@ -239,3 +241,47 @@ func keysOf(m map[string]string) []string { } return out } + +func TestMember_JSONRoundTrip_NewFields(t *testing.T) { + hss := int64(1700000000000) + in := roomsubcache.Member{ + ID: "u1", + Account: "alice", + RoomType: model.RoomTypeChannel, + IsBot: true, + Muted: true, + HistorySharedSince: &hss, + } + data, err := json.Marshal(in) + require.NoError(t, err) + + var out roomsubcache.Member + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, in, out) +} + +func TestMember_RoomType_RoundTrip(t *testing.T) { + for _, rt := range []model.RoomType{ + model.RoomTypeChannel, + model.RoomTypeDM, + model.RoomTypeBotDM, + model.RoomTypeDiscussion, + } { + m := 
roomsubcache.Member{ID: "u1", Account: "alice", RoomType: rt} + data, err := json.Marshal(m) + require.NoError(t, err) + var out roomsubcache.Member + require.NoError(t, json.Unmarshal(data, &out)) + assert.Equal(t, rt, out.RoomType, "RoomType %q should round-trip", rt) + } +} + +func TestMember_OmitemptyOnZeroValues(t *testing.T) { + in := roomsubcache.Member{ID: "u1", Account: "alice"} + data, err := json.Marshal(in) + require.NoError(t, err) + got := string(data) + + // Only id + account on the wire; no zero-valued booleans / strings / pointers. + assert.JSONEq(t, `{"id":"u1","account":"alice"}`, got) +} diff --git a/pkg/stream/stream.go b/pkg/stream/stream.go index cddd7ed09..1b9229fa1 100644 --- a/pkg/stream/stream.go +++ b/pkg/stream/stream.go @@ -40,6 +40,15 @@ func Outbox(siteID string) Config { } } +// PushNotifications returns the PUSH_NOTIFICATIONS_{siteID} stream config. +// Owned by ops in production; notification-worker bootstraps it in dev only. +func PushNotifications(siteID string) Config { + return Config{ + Name: fmt.Sprintf("PUSH_NOTIFICATIONS_%s", siteID), + Subjects: []string{subject.PushNotificationFilter(siteID)}, + } +} + // Inbox returns the canonical config for the `INBOX_{siteID}` stream that // carries subscription lifecycle events (member_added, member_removed) // plus any other aggregated events federated in from other sites. 
diff --git a/pkg/stream/stream_test.go b/pkg/stream/stream_test.go index b52d71e1d..b54d86249 100644 --- a/pkg/stream/stream_test.go +++ b/pkg/stream/stream_test.go @@ -25,6 +25,7 @@ func TestStreamConfigs(t *testing.T) { {"MessagesCanonical", stream.MessagesCanonical(siteID), "MESSAGES_CANONICAL_site-a", "chat.msg.canonical.site-a.>"}, {"Rooms", stream.Rooms(siteID), "ROOMS_site-a", "chat.room.canonical.site-a.>"}, {"Outbox", stream.Outbox(siteID), "OUTBOX_site-a", "outbox.site-a.>"}, + {"PushNotifications", stream.PushNotifications(siteID), "PUSH_NOTIFICATIONS_site-a", "chat.server.notification.push.site-a.>"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/subject/subject.go b/pkg/subject/subject.go index 01854344a..f59d6e850 100644 --- a/pkg/subject/subject.go +++ b/pkg/subject/subject.go @@ -101,6 +101,11 @@ func RoomCanonical(siteID, operation string) string { return fmt.Sprintf("chat.room.canonical.%s.%s", siteID, operation) } +// RoomCanonicalMemberEvent returns the post-mutation member-event subject (mute-only today). +func RoomCanonicalMemberEvent(siteID, eventType string) string { + return fmt.Sprintf("chat.room.canonical.%s.event.member.%s", siteID, eventType) +} + func SubscriptionUpdate(account string) string { return fmt.Sprintf("chat.user.%s.event.subscription.update", account) } @@ -802,3 +807,40 @@ func UserRoomWildCard(siteID string) string { func UserAppsWildCard(siteID string) string { return fmt.Sprintf("chat.user.*.request.user.%s.apps.>", siteID) } + +// PushNotification is the per-recipient mobile-push subject. Lives under chat.server.* so +// client JWTs cannot subscribe. The stream filter covers the .send leaf and future siblings. +func PushNotification(siteID string) string { + return fmt.Sprintf("chat.server.notification.push.%s.send", siteID) +} + +// PushNotificationFilter is the stream-binding wildcard covering .send and any future siblings. 
+func PushNotificationFilter(siteID string) string { + return fmt.Sprintf("chat.server.notification.push.%s.>", siteID) +} + +// PresenceSnapshot is the bulk presence RPC subject (request/reply). +func PresenceSnapshot(siteID string) string { + return fmt.Sprintf("chat.presence.%s.request.snapshot", siteID) +} + +// SubscriptionUpdateWildcard matches every subscription.update fanout event. +func SubscriptionUpdateWildcard() string { + return "chat.user.*.event.subscription.update" +} + +// ParseSubscriptionUpdateAccount extracts the account from a subscription.update subject; ok=false on malformed input. +func ParseSubscriptionUpdateAccount(s string) (account string, ok bool) { + parts := strings.Split(s, ".") + if len(parts) != 6 { + return "", false + } + if parts[0] != "chat" || parts[1] != "user" || parts[3] != "event" || + parts[4] != "subscription" || parts[5] != "update" { + return "", false + } + if !isValidAccountToken(parts[2]) { + return "", false + } + return parts[2], true +} diff --git a/pkg/subject/subject_test.go b/pkg/subject/subject_test.go index 0e4a6401e..868df0075 100644 --- a/pkg/subject/subject_test.go +++ b/pkg/subject/subject_test.go @@ -147,6 +147,12 @@ func TestSubjectBuilders(t *testing.T) { } }) + t.Run("PushNotificationFilter", func(t *testing.T) { + assert.Equal(t, + "chat.server.notification.push.site-a.>", + subject.PushNotificationFilter("site-a")) + }) + t.Run("InboxMemberEventSubjects", func(t *testing.T) { got := subject.InboxMemberEventSubjects("site-a") want := []string{ @@ -698,3 +704,33 @@ func TestUserServicePatternBuilders(t *testing.T) { }) } } + +func TestPushNotification(t *testing.T) { + assert.Equal(t, + "chat.server.notification.push.site-a.send", + subject.PushNotification("site-a")) +} + +func TestPresenceSnapshot(t *testing.T) { + assert.Equal(t, + "chat.presence.site-a.request.snapshot", + subject.PresenceSnapshot("site-a")) +} + +func TestSubscriptionUpdateWildcard(t *testing.T) { + assert.Equal(t, + 
"chat.user.*.event.subscription.update", + subject.SubscriptionUpdateWildcard()) +} + +func TestParseSubscriptionUpdateAccount(t *testing.T) { + acct, ok := subject.ParseSubscriptionUpdateAccount("chat.user.alice.event.subscription.update") + assert.True(t, ok) + assert.Equal(t, "alice", acct) + + _, ok = subject.ParseSubscriptionUpdateAccount("chat.user.alice.event.room.update") + assert.False(t, ok) + + _, ok = subject.ParseSubscriptionUpdateAccount("chat.user.*.event.subscription.update") + assert.False(t, ok) // wildcard token rejected +} diff --git a/pkg/userstore/cache.go b/pkg/userstore/cache.go new file mode 100644 index 000000000..b085d846e --- /dev/null +++ b/pkg/userstore/cache.go @@ -0,0 +1,165 @@ +package userstore + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "time" + + lru "github.com/hashicorp/golang-lru/v2/expirable" + "golang.org/x/sync/singleflight" + + "github.com/hmchangw/chat/pkg/model" +) + +// Cache is an in-process LRU+TTL cache fronting a UserStore. Shared by +// message-gatekeeper (sender display-name resolution), broadcast-worker (mention +// enrichment + sender lookup), and message-worker (mention resolution + sender +// lookup) so all three pay the same Mongo cost once per warm entry. +// +// Both lookups are cached. Every populate writes the user under both the by-ID +// and by-account prefixes so a hit on either path satisfies the other; the +// two LRUs share value pointers so a single User lives once in memory. +// Singleflight collapses concurrent FindUserByID misses for the same id. +// +// Pod-local in-memory is fine here: entries are tiny (~500 B/user), per-pod +// working set caps at a few MB for 10K warm users, writes are rare (display-name +// changes are admin events). Valkey overhead (network hop, serialization, +// error handling) buys nothing at this size. 
+type Cache struct { + byID *lru.LRU[string, *model.User] + byAccount *lru.LRU[string, *model.User] + store UserStore + sf singleflight.Group + + hits atomic.Uint64 + misses atomic.Uint64 + loadErrs atomic.Uint64 +} + +// Stats is a snapshot of the cache's counters. +type Stats struct { + Hits, Misses, LoadErrors uint64 + SizeByID, SizeByAccount int +} + +// NewCache returns a Cache fronting the given UserStore. size applies to each +// of the by-ID and by-account LRUs independently; ttl applies to both. +func NewCache(store UserStore, size int, ttl time.Duration) (*Cache, error) { + if store == nil { + return nil, fmt.Errorf("userstore: store must not be nil") + } + if size <= 0 { + return nil, fmt.Errorf("userstore: cache size must be positive, got %d", size) + } + if ttl <= 0 { + return nil, fmt.Errorf("userstore: cache ttl must be positive, got %v", ttl) + } + return &Cache{ + byID: lru.NewLRU[string, *model.User](size, nil, ttl), + byAccount: lru.NewLRU[string, *model.User](size, nil, ttl), + store: store, + }, nil +} + +// FindUserByID serves from the by-ID LRU when hot, falls through to the store +// on miss. ErrUserNotFound propagates unwrapped; missing entries are NOT +// negatively cached. +func (c *Cache) FindUserByID(ctx context.Context, id string) (*model.User, error) { + if v, ok := c.byID.Get(id); ok { + c.hits.Add(1) + return v, nil + } + c.misses.Add(1) + v, err, _ := c.sf.Do(id, func() (interface{}, error) { + if cached, ok := c.byID.Get(id); ok { + return cached, nil + } + u, err := c.store.FindUserByID(ctx, id) + if err != nil { + return nil, err + } + c.populate(u) + return u, nil + }) + if err != nil { + c.loadErrs.Add(1) + if errors.Is(err, ErrUserNotFound) { + return nil, err + } + return nil, fmt.Errorf("find cached user %q: %w", id, err) + } + return v.(*model.User), nil +} + +// FindUsersByAccounts serves cache hits from the by-account LRU and forwards +// the missing set to the store in one batched call. 
Input duplicates are +// deduped; result order is not guaranteed to match input. A store error +// returns partial hits + a wrapped error. +func (c *Cache) FindUsersByAccounts(ctx context.Context, accounts []string) ([]model.User, error) { + if len(accounts) == 0 { + return nil, nil + } + seen := make(map[string]struct{}, len(accounts)) + hits := make([]model.User, 0, len(accounts)) + missing := make([]string, 0, len(accounts)) + for _, a := range accounts { + if _, dup := seen[a]; dup { + continue + } + seen[a] = struct{}{} + if u, ok := c.byAccount.Get(a); ok { + c.hits.Add(1) + hits = append(hits, *u) + continue + } + c.misses.Add(1) + missing = append(missing, a) + } + if len(missing) == 0 { + return hits, nil + } + fresh, err := c.store.FindUsersByAccounts(ctx, missing) + if err != nil { + return hits, fmt.Errorf("cached find users by accounts: %w", err) + } + for i := range fresh { + c.populate(&fresh[i]) + } + return append(hits, fresh...), nil +} + +// populate writes the user under both prefixes so a hit on either path +// satisfies the other. The same pointer is stored in both LRUs. +func (c *Cache) populate(u *model.User) { + if u == nil { + return + } + c.byID.Add(u.ID, u) + if u.Account != "" { + c.byAccount.Add(u.Account, u) + } +} + +// Stats returns a snapshot of cache counters. +func (c *Cache) Stats() Stats { + return Stats{ + Hits: c.hits.Load(), + Misses: c.misses.Load(), + LoadErrors: c.loadErrs.Load(), + SizeByID: c.byID.Len(), + SizeByAccount: c.byAccount.Len(), + } +} + +// Invalidate drops any cached entries for the given user. Empty userID or +// account skips that prefix. 
+func (c *Cache) Invalidate(userID, account string) { + if userID != "" { + c.byID.Remove(userID) + } + if account != "" { + c.byAccount.Remove(account) + } +} diff --git a/pkg/userstore/cache_test.go b/pkg/userstore/cache_test.go new file mode 100644 index 000000000..993db403c --- /dev/null +++ b/pkg/userstore/cache_test.go @@ -0,0 +1,211 @@ +package userstore_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/hmchangw/chat/pkg/model" + "github.com/hmchangw/chat/pkg/userstore" +) + +type fakeStore struct { + mu sync.Mutex + usersByID map[string]*model.User + usersByAccount map[string]*model.User + byIDCalls int + byAccountCalls int + err error +} + +func newFakeStore(users ...model.User) *fakeStore { + s := &fakeStore{ + usersByID: make(map[string]*model.User, len(users)), + usersByAccount: make(map[string]*model.User, len(users)), + } + for i := range users { + u := users[i] + s.usersByID[u.ID] = &u + s.usersByAccount[u.Account] = &u + } + return s +} + +func (f *fakeStore) FindUserByID(_ context.Context, id string) (*model.User, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.byIDCalls++ + if f.err != nil { + return nil, f.err + } + if u, ok := f.usersByID[id]; ok { + return u, nil + } + return nil, userstore.ErrUserNotFound +} + +func (f *fakeStore) FindUsersByAccounts(_ context.Context, accounts []string) ([]model.User, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.byAccountCalls++ + if f.err != nil { + return nil, f.err + } + out := make([]model.User, 0, len(accounts)) + for _, a := range accounts { + if u, ok := f.usersByAccount[a]; ok { + out = append(out, *u) + } + } + return out, nil +} + +func TestNewCache_RejectsInvalidArgs(t *testing.T) { + _, err := userstore.NewCache(nil, 10, time.Minute) + require.Error(t, err) + + _, err = userstore.NewCache(newFakeStore(), 0, time.Minute) + require.Error(t, err) + + _, err = 
userstore.NewCache(newFakeStore(), 10, 0) + require.Error(t, err) +} + +func TestCache_FindUserByID_MissThenHit(t *testing.T) { + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, err := userstore.NewCache(store, 10, time.Minute) + require.NoError(t, err) + + u, err := cache.FindUserByID(context.Background(), "u1") + require.NoError(t, err) + assert.Equal(t, "alice", u.Account) + assert.Equal(t, 1, store.byIDCalls) + + _, err = cache.FindUserByID(context.Background(), "u1") + require.NoError(t, err) + assert.Equal(t, 1, store.byIDCalls, "second lookup must hit cache") +} + +func TestCache_FindUserByID_NotFoundIsUnwrapped(t *testing.T) { + cache, _ := userstore.NewCache(newFakeStore(), 10, time.Minute) + _, err := cache.FindUserByID(context.Background(), "ghost") + require.Error(t, err) + assert.ErrorIs(t, err, userstore.ErrUserNotFound) +} + +func TestCache_FindUserByID_StoreErrorWrapped(t *testing.T) { + store := &fakeStore{err: errors.New("mongo down"), usersByID: map[string]*model.User{}, usersByAccount: map[string]*model.User{}} + cache, _ := userstore.NewCache(store, 10, time.Minute) + _, err := cache.FindUserByID(context.Background(), "u1") + require.Error(t, err) + assert.NotErrorIs(t, err, userstore.ErrUserNotFound) + assert.Contains(t, err.Error(), "find cached user") +} + +func TestCache_FindUsersByAccounts_BatchPartialHit(t *testing.T) { + ctx := context.Background() + store := newFakeStore( + model.User{ID: "u1", Account: "alice"}, + model.User{ID: "u2", Account: "bob"}, + model.User{ID: "u3", Account: "carol"}, + ) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + got, err := cache.FindUsersByAccounts(ctx, []string{"alice", "bob"}) + require.NoError(t, err) + assert.Len(t, got, 2) + assert.Equal(t, 1, store.byAccountCalls) + + got, err = cache.FindUsersByAccounts(ctx, []string{"alice", "carol"}) + require.NoError(t, err) + assert.Len(t, got, 2) + assert.Equal(t, 2, store.byAccountCalls) + stats := cache.Stats() + 
assert.GreaterOrEqual(t, stats.Hits, uint64(1)) +} + +func TestCache_FindUsersByAccounts_CrossPopulatesByID(t *testing.T) { + ctx := context.Background() + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + _, err := cache.FindUsersByAccounts(ctx, []string{"alice"}) + require.NoError(t, err) + + _, err = cache.FindUserByID(ctx, "u1") + require.NoError(t, err) + assert.Equal(t, 0, store.byIDCalls, "FindUserByID must hit cache via cross-populated key") +} + +func TestCache_FindUserByID_CrossPopulatesByAccount(t *testing.T) { + ctx := context.Background() + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + _, err := cache.FindUserByID(ctx, "u1") + require.NoError(t, err) + + got, err := cache.FindUsersByAccounts(ctx, []string{"alice"}) + require.NoError(t, err) + assert.Len(t, got, 1) + assert.Equal(t, 0, store.byAccountCalls, "FindUsersByAccounts must hit cache via cross-populated key") +} + +func TestCache_FindUsersByAccounts_DedupesInput(t *testing.T) { + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + got, err := cache.FindUsersByAccounts(context.Background(), []string{"alice", "alice", "alice"}) + require.NoError(t, err) + assert.Len(t, got, 1) +} + +func TestCache_FindUsersByAccounts_EmptyInput(t *testing.T) { + cache, _ := userstore.NewCache(newFakeStore(), 10, time.Minute) + got, err := cache.FindUsersByAccounts(context.Background(), nil) + require.NoError(t, err) + assert.Nil(t, got) +} + +func TestCache_FindUsersByAccounts_StoreErrorReturnsPartialHits(t *testing.T) { + ctx := context.Background() + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + _, err := cache.FindUsersByAccounts(ctx, []string{"alice"}) + require.NoError(t, err) + + store.mu.Lock() + store.err = 
errors.New("mongo down") + store.mu.Unlock() + got, err := cache.FindUsersByAccounts(ctx, []string{"alice", "ghost"}) + require.Error(t, err) + assert.Len(t, got, 1, "alice hit must be returned alongside the error") + assert.Equal(t, "alice", got[0].Account) +} + +func TestCache_Invalidate(t *testing.T) { + ctx := context.Background() + store := newFakeStore(model.User{ID: "u1", Account: "alice"}) + cache, _ := userstore.NewCache(store, 10, time.Minute) + + _, err := cache.FindUserByID(ctx, "u1") + require.NoError(t, err) + + cache.Invalidate("u1", "alice") + + _, err = cache.FindUserByID(ctx, "u1") + require.NoError(t, err) + assert.Equal(t, 2, store.byIDCalls, "post-invalidate lookup must re-hit store") + + _, err = cache.FindUsersByAccounts(ctx, []string{"alice"}) + require.NoError(t, err) + // Note: u1 was repopulated by the FindUserByID above (cross-populates the account prefix), + // so this call should be a cache hit and not increment byAccountCalls. + assert.Equal(t, 0, store.byAccountCalls) +} diff --git a/room-service/handler.go b/room-service/handler.go index 0134bba34..b3714f0a6 100644 --- a/room-service/handler.go +++ b/room-service/handler.go @@ -1874,6 +1874,21 @@ func (h *Handler) handleMuteToggle(ctx context.Context, subj string, _ []byte) ( // Non-fatal — the DB write is the source of truth; clients will reconcile on next refetch. } + // Canonical room-stream event consumed by notification-worker for cache invalidation. + // One event per mutation, room-scoped (not per-user). Non-fatal: TTL reconciles on miss. 
+ canonEvt := model.CanonicalMemberEvent{ + Type: model.CanonicalMemberEventMuted, + RoomID: sub.RoomID, + Account: account, + Muted: sub.Muted, + Timestamp: now.UnixMilli(), + } + if canonData, err := json.Marshal(canonEvt); err == nil { + if err := h.publishToStream(ctx, subject.RoomCanonicalMemberEvent(h.siteID, model.CanonicalMemberEventMuted), canonData, ""); err != nil { + slog.Error("canonical member event publish failed", "error", err, "type", "muted", "roomID", sub.RoomID, "account", account) + } + } + userSiteID, err := h.store.GetUserSiteID(ctx, account) if err != nil { return nil, fmt.Errorf("get user siteId: %w", err) diff --git a/room-service/handler_test.go b/room-service/handler_test.go index cf5e847b9..e058b9956 100644 --- a/room-service/handler_test.go +++ b/room-service/handler_test.go @@ -3861,11 +3861,14 @@ func TestHandler_MuteToggle_Success(t *testing.T) { var coreSubjects []string var coreBodies [][]byte + var streamSubjects []string + var streamBodies [][]byte h := &Handler{ store: store, siteID: "site-a", - publishToStream: func(_ context.Context, _ string, _ []byte, _ string) error { - t.Fatal("publishToStream must not be called for same-site mute toggle") + publishToStream: func(_ context.Context, subj string, data []byte, _ string) error { + streamSubjects = append(streamSubjects, subj) + streamBodies = append(streamBodies, data) return nil }, publishCore: func(_ context.Context, subj string, data []byte) error { @@ -3892,6 +3895,16 @@ func TestHandler_MuteToggle_Success(t *testing.T) { assert.Equal(t, "mute_toggled", evt.Action) assert.True(t, evt.Subscription.Muted) assert.Equal(t, "alice", evt.Subscription.User.Account) + + // Canonical room-stream event for notification-worker cache invalidation. 
+ require.Len(t, streamSubjects, 1) + assert.Equal(t, subject.RoomCanonicalMemberEvent("site-a", model.CanonicalMemberEventMuted), streamSubjects[0]) + var canon model.CanonicalMemberEvent + require.NoError(t, json.Unmarshal(streamBodies[0], &canon)) + assert.Equal(t, model.CanonicalMemberEventMuted, canon.Type) + assert.Equal(t, "r1", canon.RoomID) + assert.Equal(t, "alice", canon.Account) + assert.True(t, canon.Muted) } func TestHandler_MuteToggle_CrossSitePublishesOutbox(t *testing.T) { @@ -4006,11 +4019,10 @@ func TestHandler_MuteToggle_GetUserSiteIDError(t *testing.T) { h := &Handler{ store: store, siteID: "site-a", - publishToStream: func(_ context.Context, _ string, _ []byte, _ string) error { - t.Fatal("publishToStream must not be called when GetUserSiteID fails") - return nil - }, - publishCore: func(_ context.Context, _ string, _ []byte) error { return nil }, + // Canonical member event publish happens before GetUserSiteID and is + // independent of the outbox path — it represents the successful DB mutation. 
+ publishToStream: func(_ context.Context, _ string, _ []byte, _ string) error { return nil }, + publishCore: func(_ context.Context, _ string, _ []byte) error { return nil }, } subj := subject.MuteToggle("alice", "r1", "site-a") @@ -4492,10 +4504,7 @@ func TestHandler_MuteToggle_CorePublishFailureIsNonFatal(t *testing.T) { publishCore: func(_ context.Context, _ string, _ []byte) error { return fmt.Errorf("core nats down") }, - publishToStream: func(_ context.Context, _ string, _ []byte, _ string) error { - t.Fatal("publishToStream must not be called for same-site mute toggle") - return nil - }, + publishToStream: func(_ context.Context, _ string, _ []byte, _ string) error { return nil }, } subj := subject.MuteToggle("alice", "r1", "site-a") diff --git a/room-worker/handler.go b/room-worker/handler.go index a3846934e..4b1be0767 100644 --- a/room-worker/handler.go +++ b/room-worker/handler.go @@ -73,6 +73,13 @@ func NewHandler(store SubscriptionStore, siteID string, publish PublishFunc, key } } +// publishSubscriptionUpdate fans out the per-user subscription.update event for the FE; best-effort. +func (h *Handler) publishSubscriptionUpdate(ctx context.Context, account string, subEvtData []byte) { + if err := h.publish(ctx, subject.SubscriptionUpdate(account), subEvtData, ""); err != nil { + slog.Error("subscription update publish failed", "error", err, "account", account) + } +} + // SetKeyFanoutWorkers overrides the bounded-worker pool size used by // fanOutKey. Values <= 0 are ignored so partial-deployment misconfig can't // disable the cap. main wires this from KEY_FANOUT_WORKERS at startup. 
@@ -438,9 +445,7 @@ func (h *Handler) processRemoveIndividual(ctx context.Context, req *model.Remove Timestamp: now.UnixMilli(), } subEvtData, _ := json.Marshal(subEvt) - if err := h.publish(ctx, subject.SubscriptionUpdate(req.Account), subEvtData, ""); err != nil { - slog.ErrorContext(ctx, "subscription update publish failed", "error", err, "account", req.Account) - } + h.publishSubscriptionUpdate(ctx, req.Account, subEvtData) // Member change event evtType := model.MessageTypeMemberLeft @@ -649,9 +654,7 @@ func (h *Handler) processRemoveOrg(ctx context.Context, req *model.RemoveMemberR Timestamp: now.UnixMilli(), } subEvtData, _ := json.Marshal(subEvt) - if err := h.publish(ctx, subject.SubscriptionUpdate(m.Account), subEvtData, ""); err != nil { - slog.ErrorContext(ctx, "subscription update publish failed", "error", err, "account", m.Account) - } + h.publishSubscriptionUpdate(ctx, m.Account, subEvtData) } // Member change event with all removed accounts @@ -1011,9 +1014,7 @@ func (h *Handler) processAddMembers(ctx context.Context, data []byte) (err error Timestamp: now.UnixMilli(), } subEvtData, _ := json.Marshal(subEvt) - if err := h.publish(ctx, subject.SubscriptionUpdate(sub.User.Account), subEvtData, ""); err != nil { - slog.ErrorContext(ctx, "subscription update publish failed", "error", err, "account", sub.User.Account) - } + h.publishSubscriptionUpdate(ctx, sub.User.Account, subEvtData) } // Fan out the room key only to newly-subscribed accounts. 
Accounts in @@ -1445,9 +1446,7 @@ func (h *Handler) finishCreateRoom(ctx context.Context, req *model.CreateRoomReq slog.ErrorContext(ctx, "marshal subscription.update failed", "error", err, "account", sub.User.Account) continue } - if err := h.publish(ctx, subject.SubscriptionUpdate(sub.User.Account), data, ""); err != nil { - slog.ErrorContext(ctx, "publish subscription.update failed", "error", err, "account", sub.User.Account) - } + h.publishSubscriptionUpdate(ctx, sub.User.Account, data) } // Task 36: channel-only sys-messages @@ -1817,10 +1816,7 @@ func (h *Handler) publishSubscriptionUpdates(ctx context.Context, subs []*model. "error", err, "account", sub.User.Account, "request_id", requestID) continue } - if err := h.publish(ctx, subject.SubscriptionUpdate(sub.User.Account), data, ""); err != nil { - slog.ErrorContext(ctx, "sync DM: publish subscription.update failed", - "error", err, "account", sub.User.Account, "request_id", requestID) - } + h.publishSubscriptionUpdate(ctx, sub.User.Account, data) } } diff --git a/room-worker/handler_test.go b/room-worker/handler_test.go index 029c2c3ef..0822f2155 100644 --- a/room-worker/handler_test.go +++ b/room-worker/handler_test.go @@ -1196,7 +1196,7 @@ func TestHandler_ProcessRemoveMember_OwnerRemovesOrg(t *testing.T) { err := h.processRemoveMember(context.Background(), data) require.NoError(t, err) - // Expect: 2 sub updates (carol, dave) + 1 member event + 1 local INBOX + 1 sys msg = 5 publishes + // Expect: 2 sub updates + 1 member event + 1 local INBOX + 1 sys msg = 5 publishes assert.Len(t, published, 5, "expected 5 publishes: 2 sub updates, member event, local INBOX, sys msg") subjSet := make(map[string]bool) @@ -2863,7 +2863,7 @@ func TestHandleSyncCreateDM_SelfDM(t *testing.T) { // Reply returns the in-memory sub directly (no read-back round-trip). assert.Equal(t, *captured[0], reply.Subscription) - // One subscription.update; no outbox (same-site by definition). 
+ // subscription.update only — same-site self-DM; no outbox and no canonical event (Option C). require.Len(t, capture.captured, 1) assert.Equal(t, subject.SubscriptionUpdate("alice"), capture.captured[0].subject) } diff --git a/room-worker/integration_test.go b/room-worker/integration_test.go index a946807f3..47defa8ce 100644 --- a/room-worker/integration_test.go +++ b/room-worker/integration_test.go @@ -1131,7 +1131,7 @@ func TestSyncCreateDM_SelfDM_PersistsSingleFavoritedSub(t *testing.T) { } cap.mu.Unlock() assert.Equal(t, 1, subjects[subject.SubscriptionUpdate("alice")]) - assert.Equal(t, 1, total, "only the subscription.update; no outbox") + assert.Equal(t, 1, total, "subscription.update only; no outbox (same-site) and no canonical member event (Option C)") } func TestSyncCreateDM_BotDM_CrossSiteOutbox(t *testing.T) {