beadhub 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. beadhub/__init__.py +12 -0
  2. beadhub/api.py +260 -0
  3. beadhub/auth.py +101 -0
  4. beadhub/aweb_context.py +65 -0
  5. beadhub/aweb_introspection.py +70 -0
  6. beadhub/beads_sync.py +514 -0
  7. beadhub/cli.py +330 -0
  8. beadhub/config.py +65 -0
  9. beadhub/db.py +129 -0
  10. beadhub/defaults/invariants/01-tracking-bdh-only.md +11 -0
  11. beadhub/defaults/invariants/02-communication-mail-first.md +36 -0
  12. beadhub/defaults/invariants/03-communication-chat.md +60 -0
  13. beadhub/defaults/invariants/04-identity-no-impersonation.md +17 -0
  14. beadhub/defaults/invariants/05-collaborate.md +12 -0
  15. beadhub/defaults/roles/backend.md +55 -0
  16. beadhub/defaults/roles/coordinator.md +44 -0
  17. beadhub/defaults/roles/frontend.md +77 -0
  18. beadhub/defaults/roles/implementer.md +73 -0
  19. beadhub/defaults/roles/reviewer.md +56 -0
  20. beadhub/defaults/roles/startup-expert.md +93 -0
  21. beadhub/defaults.py +262 -0
  22. beadhub/events.py +704 -0
  23. beadhub/internal_auth.py +121 -0
  24. beadhub/jsonl.py +68 -0
  25. beadhub/logging.py +62 -0
  26. beadhub/migrations/beads/001_initial.sql +70 -0
  27. beadhub/migrations/beads/002_search_indexes.sql +20 -0
  28. beadhub/migrations/server/001_initial.sql +279 -0
  29. beadhub/names.py +33 -0
  30. beadhub/notifications.py +275 -0
  31. beadhub/pagination.py +125 -0
  32. beadhub/presence.py +495 -0
  33. beadhub/rate_limit.py +152 -0
  34. beadhub/redis_client.py +11 -0
  35. beadhub/roles.py +35 -0
  36. beadhub/routes/__init__.py +1 -0
  37. beadhub/routes/agents.py +303 -0
  38. beadhub/routes/bdh.py +655 -0
  39. beadhub/routes/beads.py +778 -0
  40. beadhub/routes/claims.py +141 -0
  41. beadhub/routes/escalations.py +471 -0
  42. beadhub/routes/init.py +348 -0
  43. beadhub/routes/mcp.py +338 -0
  44. beadhub/routes/policies.py +833 -0
  45. beadhub/routes/repos.py +538 -0
  46. beadhub/routes/status.py +568 -0
  47. beadhub/routes/subscriptions.py +362 -0
  48. beadhub/routes/workspaces.py +1642 -0
  49. beadhub/workspace_config.py +202 -0
  50. beadhub-0.1.0.dist-info/METADATA +254 -0
  51. beadhub-0.1.0.dist-info/RECORD +54 -0
  52. beadhub-0.1.0.dist-info/WHEEL +4 -0
  53. beadhub-0.1.0.dist-info/entry_points.txt +2 -0
  54. beadhub-0.1.0.dist-info/licenses/LICENSE +21 -0
beadhub/events.py ADDED
@@ -0,0 +1,704 @@
1
+ """Event publishing and streaming via Redis pub/sub.
2
+
3
+ This module provides the infrastructure for real-time event streaming:
4
+ - Event types for messages, escalations, and beads
5
+ - EventBus for publishing events to Redis pub/sub channels
6
+ - Helpers for SSE streaming
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import json
13
+ import logging
14
+ import time
15
+ from dataclasses import asdict, dataclass, field
16
+ from datetime import datetime, timedelta, timezone
17
+ from enum import Enum
18
+ from typing import Any, AsyncIterator, Awaitable, Callable, Optional
19
+
20
+ from redis.asyncio import Redis
21
+ from redis.asyncio.client import PubSub
22
+ from redis.exceptions import ResponseError
23
+
24
# Module-scoped logger named after this module; every diagnostic in this file goes through it.
logger = logging.getLogger(__name__)
25
+
26
+
27
class EventCategory(str, Enum):
    """Top-level event families that clients can filter on.

    Each value is the prefix of an event ``type`` string (the part before the
    first ``.``), e.g. ``"message"`` for ``"message.delivered"``.
    """

    RESERVATION = "reservation"  # reservation.acquired / .released / .renewed
    MESSAGE = "message"  # message.delivered / .acknowledged
    ESCALATION = "escalation"  # escalation.created / .responded
    BEAD = "bead"  # bead.status_changed
34
+
35
+
36
@dataclass
class Event:
    """Common fields shared by every workspace-scoped event.

    Concrete subclasses pin ``type`` to a dotted ``"<category>.<action>"``
    string; the base class leaves it empty.
    """

    workspace_id: str
    type: str = ""
    # ISO-8601 creation time in UTC, captured when the instance is built.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    project_slug: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return all dataclass fields as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialize the event's fields to a JSON string."""
        payload = self.to_dict()
        return json.dumps(payload)

    @property
    def category(self) -> EventCategory:
        """Return the category named by the text before the first '.' in ``type``.

        Raises ValueError if that prefix is not a known EventCategory value
        (e.g. for the base class's empty ``type``).
        """
        prefix, _sep, _rest = self.type.partition(".")
        return EventCategory(prefix)
55
+
56
+
57
@dataclass
class ReservationAcquiredEvent(Event):
    """Published after reservations on one or more paths have been acquired."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="reservation.acquired")
    paths: list[str] = field(default_factory=list)
    alias: str = ""
    ttl_seconds: int = 0
    bead_id: str | None = None
    reason: str | None = None
    exclusive: bool = True
68
+
69
+
70
@dataclass
class ReservationReleasedEvent(Event):
    """Published after reservations on ``paths`` have been released."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="reservation.released")
    paths: list[str] = field(default_factory=list)
    alias: str = ""
77
+
78
+
79
@dataclass
class ReservationRenewedEvent(Event):
    """Published after the TTL of existing reservations has been extended."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="reservation.renewed")
    paths: list[str] = field(default_factory=list)
    alias: str = ""
    ttl_seconds: int = 0
87
+
88
+
89
@dataclass
class MessageDeliveredEvent(Event):
    """Published when a message lands in a workspace's inbox."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="message.delivered")
    message_id: str = ""
    from_workspace: str = ""
    from_alias: str = ""
    subject: str = ""
    priority: str = "normal"
99
+
100
+
101
@dataclass
class MessageAcknowledgedEvent(Event):
    """Published when a message identified by ``message_id`` is acknowledged."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="message.acknowledged")
    message_id: str = ""
107
+
108
+
109
@dataclass
class EscalationCreatedEvent(Event):
    """Published when a new escalation is opened."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="escalation.created")
    escalation_id: str = ""
    alias: str = ""
    subject: str = ""
117
+
118
+
119
@dataclass
class EscalationRespondedEvent(Event):
    """Published when an escalation gets a response (carried in ``response``)."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="escalation.responded")
    escalation_id: str = ""
    response: str = ""
126
+
127
+
128
@dataclass
class BeadStatusChangedEvent(Event):
    """Published when a bead moves from ``old_status`` to ``new_status``."""

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="bead.status_changed")
    project_id: str = ""
    bead_id: str = ""
    repo: str = ""
    old_status: str = ""
    new_status: str = ""
138
+
139
+
140
+ # =============================================================================
141
+ # Chat Events (for real-time agent chat sessions)
142
+ # =============================================================================
143
+
144
+
145
@dataclass
class ChatEvent:
    """Common fields for events that belong to a chat session.

    Unlike ``Event``, these are keyed by ``session_id`` rather than by workspace.
    """

    session_id: str
    type: str = ""
    # ISO-8601 creation time in UTC, captured when the instance is built.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self) -> dict[str, Any]:
        """Return all dataclass fields as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialize the chat event's fields to a JSON string."""
        payload = self.to_dict()
        return json.dumps(payload)
161
+
162
+
163
@dataclass
class ChatMessageEvent(ChatEvent):
    """Published for every message sent into a chat session.

    Sessions persist on their own; there is no separate join/leave state.
    """

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="message")
    message_id: str = ""
    from_agent: str = ""
    body: str = ""
    sender_leaving: bool = False  # set when the sender is leaving the conversation
    hang_on: bool = False  # set when the sender asks for more time
    extends_wait_seconds: int = 0  # seconds of extra wait requested alongside hang_on
177
+
178
+
179
@dataclass
class ChatReadReceiptEvent(ChatEvent):
    """Published when a participant has read messages in a chat session.

    Lets the original sender know its message was seen so its CLI can extend
    the wait timeout.
    """

    # Fixed discriminator; excluded from __init__ so callers cannot override it.
    type: str = field(init=False, default="read_receipt")
    reader: str = ""  # workspace_id of whoever read the messages
    reader_alias: str = ""  # alias of that reader
    up_to_message_id: str = ""  # messages up to and including this ID were read
    extends_wait_seconds: int = 0  # how much to extend the sender's wait
192
+
193
+
194
+ def _channel_name(workspace_id: str) -> str:
195
+ """Generate Redis channel name for a workspace."""
196
+ return f"events:{workspace_id}"
197
+
198
+
199
+ def _chat_channel_name(session_id: str) -> str:
200
+ """Generate Redis channel name for a chat session."""
201
+ return f"chat:{session_id}"
202
+
203
+
204
async def publish_event(redis: Redis, event: Event) -> int:
    """Broadcast ``event`` on the pub/sub channel of its workspace.

    Args:
        redis: Redis client
        event: Event to publish

    Returns:
        The receiver count reported by Redis PUBLISH.
    """
    target = _channel_name(event.workspace_id)
    receivers = await redis.publish(target, event.to_json())
    # Lazy %-formatting: the string is only built when DEBUG is enabled.
    logger.debug("Published %s to %s, %s subscribers", event.type, target, receivers)
    return receivers
219
+
220
+
221
async def stream_events(
    redis: Redis,
    workspace_id: str,
    event_types: Optional[set[str]] = None,
    keepalive_seconds: int = 30,
) -> AsyncIterator[str]:
    """Stream one workspace's events as SSE-formatted strings.

    This is the single-workspace form of ``stream_events_multi``.

    Args:
        redis: Redis client
        workspace_id: Workspace to stream events for
        event_types: Optional set of event categories to keep (e.g. {'message', 'bead'});
            None streams everything.
        keepalive_seconds: Seconds between keepalive comments

    Yields:
        SSE-formatted event strings (e.g., "data: {...}\\n\\n")
    """
    inner = stream_events_multi(
        redis,
        [workspace_id],
        event_types=event_types,
        keepalive_seconds=keepalive_seconds,
    )
    async for chunk in inner:
        yield chunk
241
+
242
+
243
async def stream_events_multi(
    redis: Redis,
    workspace_ids: list[str],
    event_types: Optional[set[str]] = None,
    keepalive_seconds: int = 30,
    check_disconnected: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncIterator[str]:
    """Stream events for multiple workspaces as SSE-formatted strings.

    Args:
        redis: Redis client
        workspace_ids: List of workspace IDs to stream events for
        event_types: Optional set of event categories to filter (e.g., {'message', 'bead'})
            If None, all events are streamed.
        keepalive_seconds: Seconds between keepalive comments
        check_disconnected: Optional async callback to check if client has disconnected.
            When provided and returns True, the stream ends cleanly.

    Yields:
        SSE-formatted event strings (e.g., "data: {...}\\n\\n"), plus
        ": keepalive\\n\\n" comments when no event was sent for
        ``keepalive_seconds``.

    Payloads that are not valid JSON, or are JSON but not an object, are logged
    and skipped instead of terminating the stream.
    """
    channels = [_channel_name(ws_id) for ws_id in workspace_ids]

    # Empty workspace list: send keepalives for a limited time.
    # This handles new projects with no workspaces yet while preventing
    # resource leaks if disconnect detection fails.
    if not channels:
        max_duration_seconds = 5 * 60  # 5 minutes
        # Send at least one keepalive even when the interval exceeds the cap
        # (previously 0 iterations -> immediate close), and never divide by zero.
        max_keepalives = max(1, max_duration_seconds // max(1, keepalive_seconds))

        for _ in range(max_keepalives):
            if check_disconnected and await check_disconnected():
                logger.debug("Client disconnected (empty workspace list)")
                return
            await asyncio.sleep(keepalive_seconds)
            yield ": keepalive\n\n"

        logger.debug("Empty workspace stream reached max duration, closing")
        return

    pubsub: PubSub = redis.pubsub()

    try:
        await pubsub.subscribe(*channels)
        logger.debug("Subscribed to %d channels", len(channels))

        loop = asyncio.get_running_loop()
        last_keepalive = loop.time()

        while True:
            # get_message() returns None after its 1s poll timeout; wait_for is
            # an outer bound in case the read itself stalls.
            try:
                message = await asyncio.wait_for(
                    pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                    timeout=keepalive_seconds,
                )
            except asyncio.TimeoutError:
                message = None

            if check_disconnected and await check_disconnected():
                logger.debug("Client disconnected, ending stream for %d channels", len(channels))
                return

            current_time = loop.time()

            if message is not None and message["type"] == "message":
                data = message["data"]
                if isinstance(data, bytes):
                    data = data.decode("utf-8")

                try:
                    event_data = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in event: %s", data)
                    continue

                # Valid JSON that is not an object (list, string, number) would
                # otherwise crash the stream on .get().
                if not isinstance(event_data, dict):
                    logger.warning("Ignoring non-object event payload: %s", data)
                    continue

                event_type = event_data.get("type", "")
                event_category = event_type.split(".")[0] if isinstance(event_type, str) else ""

                if event_types is None or event_category in event_types:
                    yield f"data: {data}\n\n"
                    last_keepalive = current_time

            if current_time - last_keepalive >= keepalive_seconds:
                yield ": keepalive\n\n"
                last_keepalive = current_time

    except asyncio.CancelledError:
        logger.debug("Stream cancelled for %d channels", len(channels))
        raise
    finally:
        try:
            await pubsub.unsubscribe(*channels)
        except Exception:
            logger.warning(
                "Failed to unsubscribe from %d channels", len(channels), exc_info=True
            )
        finally:
            # Always release the pub/sub connection, even when unsubscribe failed
            # (e.g. because the Redis connection already dropped).
            await pubsub.aclose()
        logger.debug("Unsubscribed from %d channels", len(channels))
341
+
342
+
343
async def publish_chat_event(redis: Redis, event: ChatEvent) -> int:
    """Broadcast ``event`` on the pub/sub channel of its chat session.

    Args:
        redis: Redis client
        event: Chat event to publish

    Returns:
        The receiver count reported by Redis PUBLISH.
    """
    target = _chat_channel_name(event.session_id)
    receivers = await redis.publish(target, event.to_json())
    # Lazy %-formatting: the string is only built when DEBUG is enabled.
    logger.debug("Published chat %s to %s, %s subscribers", event.type, target, receivers)
    return receivers
358
+
359
+
360
+ def _chat_waiting_key(session_id: str) -> str:
361
+ """Redis key for tracking workspaces waiting in a chat session."""
362
+ return f"chat:waiting:{session_id}"
363
+
364
+
365
+ def _chat_deadline_key(session_id: str) -> str:
366
+ """Redis key for tracking workspace deadlines in a chat session.
367
+
368
+ Stores deadline timestamps so we can show time remaining to other participants.
369
+ """
370
+ return f"chat:deadline:{session_id}"
371
+
372
+
373
+ async def _ensure_waiting_key_zset(redis: Redis, waiting_key: str, keep_ttl: bool = True) -> None:
374
+ """Ensure the chat waiting key uses a ZSET (member -> last_seen timestamp).
375
+
376
+ If the key exists with an unexpected Redis type, ZSET operations will raise
377
+ WRONGTYPE. This function replaces the key with a ZSET (best-effort).
378
+
379
+ This function is best-effort and safe to call on hot paths.
380
+ """
381
+ try:
382
+ key_type = await redis.type(waiting_key)
383
+ except Exception:
384
+ logger.warning(
385
+ "Failed to read Redis key type for waiting_key=%s", waiting_key, exc_info=True
386
+ )
387
+ return
388
+
389
+ # redis-py returns bytes for TYPE.
390
+ if isinstance(key_type, bytes):
391
+ key_type = key_type.decode("utf-8", errors="ignore")
392
+
393
+ if key_type in ("none", ""):
394
+ return
395
+ if key_type == "zset":
396
+ return
397
+
398
+ ttl = None
399
+ if keep_ttl:
400
+ try:
401
+ ttl = await redis.ttl(waiting_key)
402
+ except Exception:
403
+ logger.warning(
404
+ "Failed to read Redis TTL for waiting_key=%s", waiting_key, exc_info=True
405
+ )
406
+ ttl = None
407
+
408
+ if key_type == "set":
409
+ try:
410
+ members = await redis.smembers(waiting_key)
411
+ except Exception:
412
+ logger.warning(
413
+ "Failed to read Redis set members for waiting_key=%s", waiting_key, exc_info=True
414
+ )
415
+ members = set()
416
+
417
+ try:
418
+ await redis.delete(waiting_key)
419
+ except Exception:
420
+ logger.warning(
421
+ "Failed to delete Redis key waiting_key=%s for type migration",
422
+ waiting_key,
423
+ exc_info=True,
424
+ )
425
+ return
426
+
427
+ mapping: dict[str, float] = {}
428
+ now = time.time()
429
+ for m in members:
430
+ if isinstance(m, bytes):
431
+ m = m.decode("utf-8", errors="ignore")
432
+ mapping[str(m)] = now
433
+
434
+ if mapping:
435
+ try:
436
+ await redis.zadd(waiting_key, mapping)
437
+ except Exception:
438
+ logger.warning(
439
+ "Failed to write Redis zset waiting_key=%s for type migration",
440
+ waiting_key,
441
+ exc_info=True,
442
+ )
443
+ return
444
+
445
+ if ttl is not None and ttl > 0:
446
+ try:
447
+ await redis.expire(waiting_key, ttl)
448
+ except Exception:
449
+ logger.warning(
450
+ "Failed to restore Redis TTL for waiting_key=%s after type migration",
451
+ waiting_key,
452
+ exc_info=True,
453
+ )
454
+ return
455
+
456
+ # Unknown type; safest is to delete it to avoid crashing endpoints.
457
+ try:
458
+ await redis.delete(waiting_key)
459
+ except Exception:
460
+ logger.warning(
461
+ "Failed to delete Redis key waiting_key=%s with unknown type",
462
+ waiting_key,
463
+ exc_info=True,
464
+ )
465
+
466
+
467
async def is_workspace_waiting(
    redis: Redis,
    session_id: str,
    workspace_id: str,
    max_age_seconds: int = 90,
) -> bool:
    """Report whether ``workspace_id`` is actively waiting (SSE-connected) in a session.

    Each waiting workspace has a "last seen" score in a Redis sorted set. An
    entry older than ``max_age_seconds`` is treated as stale: it is removed,
    together with its stored deadline, and the workspace counts as not waiting.
    Per-member scores keep a stale entry from lingering just because other
    workspaces keep refreshing the key's TTL.
    """
    key = _chat_waiting_key(session_id)
    await _ensure_waiting_key_zset(redis, key)

    try:
        last_seen = await redis.zscore(key, workspace_id)
    except ResponseError as exc:
        if "WRONGTYPE" not in str(exc):
            raise
        # The key's type changed after the check above (presumably a concurrent
        # writer); repair it and retry the read once.
        await _ensure_waiting_key_zset(redis, key)
        last_seen = await redis.zscore(key, workspace_id)

    if last_seen is None:
        return False

    cutoff = time.time() - max_age_seconds
    if float(last_seen) < cutoff:
        await redis.zrem(key, workspace_id)
        # Best-effort: clean up stale deadline too.
        try:
            await redis.hdel(_chat_deadline_key(session_id), workspace_id)
        except Exception:
            logger.warning(
                "Failed to delete stale chat deadline session_id=%s workspace_id=%s",
                session_id,
                workspace_id,
                exc_info=True,
            )
        return False

    return True
506
+
507
+
508
async def get_workspace_deadline(
    redis: Redis,
    session_id: str,
    workspace_id: str,
) -> Optional[datetime]:
    """Look up the stored wait deadline of a workspace in a chat session.

    Args:
        redis: Redis client
        session_id: Chat session ID
        workspace_id: Workspace ID to check

    Returns:
        The deadline parsed with ``datetime.fromisoformat``, or None when no
        deadline is stored. Redis and parse errors are logged and also yield None.
    """
    key = _chat_deadline_key(session_id)
    try:
        raw = await redis.hget(key, workspace_id)
        if not raw:
            return None
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        return datetime.fromisoformat(text)
    except Exception:
        logger.warning(
            "Failed to get workspace deadline session_id=%s workspace_id=%s",
            session_id,
            workspace_id,
            exc_info=True,
        )
    return None
541
+
542
+
543
async def extend_workspace_deadline(
    redis: Redis,
    session_id: str,
    workspace_id: str,
    extends_seconds: int,
) -> Optional[datetime]:
    """Push a workspace's stored chat wait deadline ``extends_seconds`` later.

    Only the Redis metadata behind time-remaining displays is changed. The
    client still has to extend its own local timeout to stay connected.
    A naive stored deadline is treated as UTC.

    Returns:
        The new deadline, or None if ``extends_seconds`` is not positive, no
        deadline is stored, or the Redis update fails.
    """
    if extends_seconds <= 0:
        return None

    current = await get_workspace_deadline(redis, session_id, workspace_id)
    if current is None:
        return None

    deadline_key = _chat_deadline_key(session_id)
    if current.tzinfo is None:
        current = current.replace(tzinfo=timezone.utc)
    extended = current + timedelta(seconds=extends_seconds)

    try:
        await redis.hset(deadline_key, workspace_id, extended.isoformat())
        # Tie the deadline hash's lifetime to the waiting zset's remaining TTL.
        remaining_ttl = await redis.ttl(_chat_waiting_key(session_id))
        if (remaining_ttl or 0) > 0:
            await redis.expire(deadline_key, remaining_ttl)
    except Exception:
        logger.warning(
            "Failed to extend workspace deadline session_id=%s workspace_id=%s",
            session_id,
            workspace_id,
            exc_info=True,
        )
        return None
    else:
        return extended
585
+
586
+
587
async def stream_chat_events(
    redis: Redis,
    session_id: str,
    workspace_id: str,
    keepalive_seconds: int = 30,
    check_disconnected: Optional[Callable[[], Awaitable[bool]]] = None,
    deadline: Optional[datetime] = None,
) -> AsyncIterator[str]:
    """Stream chat events for a session as SSE-formatted strings.

    Tracks workspace connection state in Redis to enable "waiting" detection.
    On connect the workspace is added to the session's waiting ZSET with the
    current time as its score. While connected, that score is refreshed about
    every ``keepalive_seconds``, independently of message traffic. When the
    stream ends for any reason, the entry and any stored deadline are removed.

    Args:
        redis: Redis client
        session_id: Chat session to stream events for
        workspace_id: Workspace ID of the connecting client
        keepalive_seconds: Seconds between keepalive comments and heartbeat refreshes
        check_disconnected: Optional async callback to check if client has disconnected.
            When provided and returns True, the stream ends cleanly.
        deadline: Optional deadline timestamp for this workspace's wait.
            If provided, stored in Redis so other participants can see time remaining.

    Yields:
        SSE-formatted event strings with event type (e.g., "event: message\\ndata: {...}\\n\\n"),
        plus ": keepalive\\n\\n" comments when idle. Payloads that are not JSON
        objects are logged and skipped.
    """
    channel = _chat_channel_name(session_id)
    waiting_key = _chat_waiting_key(session_id)
    deadline_key = _chat_deadline_key(session_id)
    pubsub: PubSub = redis.pubsub()

    # TTL for waiting set - 3x keepalive to handle missed refreshes
    waiting_ttl = keepalive_seconds * 3

    try:
        await _ensure_waiting_key_zset(redis, waiting_key)

        # Track that this workspace is waiting for a reply.
        # Use a per-workspace heartbeat timestamp so stale entries decay even
        # if other workspaces keep the key alive.
        await redis.zadd(waiting_key, {workspace_id: time.time()})
        await redis.expire(waiting_key, waiting_ttl)

        if deadline is not None:
            await redis.hset(deadline_key, workspace_id, deadline.isoformat())
            await redis.expire(deadline_key, waiting_ttl)

        logger.debug("Workspace %s now waiting in session %s", workspace_id, session_id)

        await pubsub.subscribe(channel)
        logger.debug("Subscribed to chat channel %s", channel)

        loop = asyncio.get_running_loop()
        last_keepalive = last_heartbeat = loop.time()

        while True:
            try:
                message = await asyncio.wait_for(
                    pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                    timeout=keepalive_seconds,
                )
            except asyncio.TimeoutError:
                message = None

            if check_disconnected and await check_disconnected():
                logger.debug("Client disconnected, ending chat stream for %s", channel)
                return

            current_time = loop.time()

            if message is not None and message["type"] == "message":
                data = message["data"]
                if isinstance(data, bytes):
                    data = data.decode("utf-8")

                try:
                    event_data = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in chat event: %s", data)
                else:
                    if isinstance(event_data, dict):
                        event_type = event_data.get("type", "message")
                        # SSE format with event type for client-side event handling
                        yield f"event: {event_type}\ndata: {data}\n\n"
                        last_keepalive = current_time
                    else:
                        logger.warning("Ignoring non-object chat event: %s", data)

            if current_time - last_keepalive >= keepalive_seconds:
                yield ": keepalive\n\n"
                last_keepalive = current_time

            # Heartbeat on its own clock. Delivered messages reset last_keepalive,
            # so tying the refresh to keepalive emission let the entry go stale
            # (and be evicted by is_workspace_waiting) during busy conversations.
            if current_time - last_heartbeat >= keepalive_seconds:
                last_heartbeat = current_time
                await _ensure_waiting_key_zset(redis, waiting_key)
                await redis.zadd(waiting_key, {workspace_id: time.time()})
                await redis.expire(waiting_key, waiting_ttl)
                # Also refresh deadline TTL (key may exist even if this stream didn't set it).
                try:
                    await redis.expire(deadline_key, waiting_ttl)
                except Exception:
                    logger.warning(
                        "Failed to refresh chat deadline TTL session_id=%s workspace_id=%s",
                        session_id,
                        workspace_id,
                        exc_info=True,
                    )

    except asyncio.CancelledError:
        logger.debug("Chat stream cancelled for %s", channel)
        raise
    finally:
        try:
            # Remove from waiting set and drop this workspace's deadline. This is
            # best-effort so a Redis error cannot keep the pub/sub connection open.
            try:
                await redis.zrem(waiting_key, workspace_id)
                await redis.hdel(deadline_key, workspace_id)
                logger.debug(
                    "Workspace %s no longer waiting in session %s", workspace_id, session_id
                )
            except Exception:
                logger.warning(
                    "Failed to clear chat waiting state session_id=%s workspace_id=%s",
                    session_id,
                    workspace_id,
                    exc_info=True,
                )
        finally:
            try:
                await pubsub.unsubscribe(channel)
            except Exception:
                logger.warning(
                    "Failed to unsubscribe from chat channel %s", channel, exc_info=True
                )
            finally:
                await pubsub.aclose()
            logger.debug("Unsubscribed from chat channel %s", channel)