devcopilot 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. api/__init__.py +17 -0
  2. api/admin_config.py +1303 -0
  3. api/admin_routes.py +287 -0
  4. api/admin_static/admin.css +459 -0
  5. api/admin_static/admin.js +497 -0
  6. api/admin_static/index.html +77 -0
  7. api/admin_urls.py +34 -0
  8. api/app.py +194 -0
  9. api/command_utils.py +164 -0
  10. api/dependencies.py +144 -0
  11. api/detection.py +152 -0
  12. api/gateway_model_ids.py +54 -0
  13. api/model_catalog.py +133 -0
  14. api/model_router.py +125 -0
  15. api/models/__init__.py +45 -0
  16. api/models/anthropic.py +234 -0
  17. api/models/openai_responses.py +28 -0
  18. api/models/responses.py +60 -0
  19. api/optimization_handlers.py +154 -0
  20. api/request_pipeline.py +424 -0
  21. api/routes.py +156 -0
  22. api/runtime.py +334 -0
  23. api/validation_log.py +48 -0
  24. api/web_server_tools.py +22 -0
  25. api/web_tools/__init__.py +17 -0
  26. api/web_tools/constants.py +15 -0
  27. api/web_tools/egress.py +99 -0
  28. api/web_tools/outbound.py +278 -0
  29. api/web_tools/parsers.py +104 -0
  30. api/web_tools/request.py +87 -0
  31. api/web_tools/streaming.py +206 -0
  32. cli/__init__.py +5 -0
  33. cli/claude_env.py +12 -0
  34. cli/entrypoints.py +166 -0
  35. cli/env.example +209 -0
  36. cli/launchers/__init__.py +1 -0
  37. cli/launchers/claude.py +84 -0
  38. cli/launchers/codex.py +204 -0
  39. cli/launchers/codex_model_catalog.py +186 -0
  40. cli/launchers/common.py +93 -0
  41. cli/managed/__init__.py +6 -0
  42. cli/managed/claude.py +215 -0
  43. cli/managed/manager.py +157 -0
  44. cli/managed/session.py +260 -0
  45. cli/process_registry.py +78 -0
  46. config/__init__.py +5 -0
  47. config/constants.py +13 -0
  48. config/logging_config.py +159 -0
  49. config/nim.py +118 -0
  50. config/paths.py +91 -0
  51. config/provider_catalog.py +259 -0
  52. config/provider_ids.py +7 -0
  53. config/settings.py +538 -0
  54. core/__init__.py +1 -0
  55. core/anthropic/__init__.py +46 -0
  56. core/anthropic/content.py +31 -0
  57. core/anthropic/conversion.py +587 -0
  58. core/anthropic/emitted_sse_tracker.py +346 -0
  59. core/anthropic/errors.py +70 -0
  60. core/anthropic/native_messages_request.py +280 -0
  61. core/anthropic/native_sse_block_policy.py +313 -0
  62. core/anthropic/provider_stream_error.py +34 -0
  63. core/anthropic/server_tool_sse.py +14 -0
  64. core/anthropic/sse.py +440 -0
  65. core/anthropic/stream_contracts.py +205 -0
  66. core/anthropic/stream_recovery.py +346 -0
  67. core/anthropic/stream_recovery_session.py +133 -0
  68. core/anthropic/thinking.py +140 -0
  69. core/anthropic/tokens.py +117 -0
  70. core/anthropic/tools.py +212 -0
  71. core/anthropic/utils.py +9 -0
  72. core/openai_responses/__init__.py +5 -0
  73. core/openai_responses/adapter.py +31 -0
  74. core/openai_responses/anthropic_sse.py +59 -0
  75. core/openai_responses/errors.py +22 -0
  76. core/openai_responses/events.py +19 -0
  77. core/openai_responses/ids.py +21 -0
  78. core/openai_responses/input.py +258 -0
  79. core/openai_responses/items.py +37 -0
  80. core/openai_responses/reasoning.py +52 -0
  81. core/openai_responses/stream.py +25 -0
  82. core/openai_responses/stream_state.py +654 -0
  83. core/openai_responses/tools.py +374 -0
  84. core/openai_responses/usage.py +37 -0
  85. core/rate_limit.py +60 -0
  86. core/trace.py +216 -0
  87. devcopilot-0.2.0.dist-info/METADATA +687 -0
  88. devcopilot-0.2.0.dist-info/RECORD +189 -0
  89. devcopilot-0.2.0.dist-info/WHEEL +4 -0
  90. devcopilot-0.2.0.dist-info/entry_points.txt +6 -0
  91. devcopilot-0.2.0.dist-info/licenses/LICENSE +21 -0
  92. messaging/__init__.py +26 -0
  93. messaging/cli_event_constants.py +67 -0
  94. messaging/command_context.py +66 -0
  95. messaging/command_dispatcher.py +37 -0
  96. messaging/commands.py +275 -0
  97. messaging/event_parser.py +181 -0
  98. messaging/limiter.py +300 -0
  99. messaging/models.py +36 -0
  100. messaging/node_event_pipeline.py +127 -0
  101. messaging/node_runner.py +342 -0
  102. messaging/platforms/__init__.py +15 -0
  103. messaging/platforms/base.py +228 -0
  104. messaging/platforms/discord.py +567 -0
  105. messaging/platforms/factory.py +103 -0
  106. messaging/platforms/outbox.py +144 -0
  107. messaging/platforms/telegram.py +688 -0
  108. messaging/platforms/voice_flow.py +295 -0
  109. messaging/rendering/__init__.py +3 -0
  110. messaging/rendering/discord_markdown.py +318 -0
  111. messaging/rendering/markdown_tables.py +49 -0
  112. messaging/rendering/profiles.py +55 -0
  113. messaging/rendering/telegram_markdown.py +327 -0
  114. messaging/safe_diagnostics.py +17 -0
  115. messaging/session.py +334 -0
  116. messaging/transcript.py +581 -0
  117. messaging/transcription.py +164 -0
  118. messaging/trees/__init__.py +15 -0
  119. messaging/trees/data.py +482 -0
  120. messaging/trees/manager.py +433 -0
  121. messaging/trees/processor.py +179 -0
  122. messaging/trees/repository.py +177 -0
  123. messaging/turn_intake.py +235 -0
  124. messaging/ui_updates.py +101 -0
  125. messaging/voice.py +76 -0
  126. messaging/workflow.py +200 -0
  127. providers/__init__.py +31 -0
  128. providers/base.py +152 -0
  129. providers/cerebras/__init__.py +7 -0
  130. providers/cerebras/client.py +31 -0
  131. providers/cerebras/request.py +55 -0
  132. providers/codestral/__init__.py +7 -0
  133. providers/codestral/client.py +34 -0
  134. providers/deepseek/__init__.py +11 -0
  135. providers/deepseek/client.py +51 -0
  136. providers/deepseek/request.py +475 -0
  137. providers/defaults.py +41 -0
  138. providers/error_mapping.py +309 -0
  139. providers/exceptions.py +113 -0
  140. providers/fireworks/__init__.py +5 -0
  141. providers/fireworks/client.py +45 -0
  142. providers/fireworks/request.py +48 -0
  143. providers/gemini/__init__.py +7 -0
  144. providers/gemini/client.py +49 -0
  145. providers/gemini/request.py +199 -0
  146. providers/groq/__init__.py +7 -0
  147. providers/groq/client.py +31 -0
  148. providers/groq/request.py +83 -0
  149. providers/kimi/__init__.py +10 -0
  150. providers/kimi/client.py +53 -0
  151. providers/kimi/request.py +42 -0
  152. providers/llamacpp/__init__.py +3 -0
  153. providers/llamacpp/client.py +16 -0
  154. providers/lmstudio/__init__.py +5 -0
  155. providers/lmstudio/client.py +16 -0
  156. providers/mistral/__init__.py +7 -0
  157. providers/mistral/client.py +31 -0
  158. providers/mistral/request.py +37 -0
  159. providers/model_listing.py +133 -0
  160. providers/nvidia_nim/__init__.py +7 -0
  161. providers/nvidia_nim/client.py +91 -0
  162. providers/nvidia_nim/request.py +430 -0
  163. providers/nvidia_nim/voice.py +95 -0
  164. providers/ollama/__init__.py +7 -0
  165. providers/ollama/client.py +39 -0
  166. providers/open_router/__init__.py +7 -0
  167. providers/open_router/client.py +124 -0
  168. providers/open_router/request.py +42 -0
  169. providers/opencode/__init__.py +11 -0
  170. providers/opencode/client.py +31 -0
  171. providers/opencode/request.py +35 -0
  172. providers/rate_limit.py +300 -0
  173. providers/registry.py +527 -0
  174. providers/transports/__init__.py +1 -0
  175. providers/transports/anthropic_messages/__init__.py +5 -0
  176. providers/transports/anthropic_messages/http.py +118 -0
  177. providers/transports/anthropic_messages/recovery.py +206 -0
  178. providers/transports/anthropic_messages/stream.py +295 -0
  179. providers/transports/anthropic_messages/transport.py +236 -0
  180. providers/transports/openai_chat/__init__.py +5 -0
  181. providers/transports/openai_chat/recovery.py +217 -0
  182. providers/transports/openai_chat/stream.py +384 -0
  183. providers/transports/openai_chat/tool_calls.py +293 -0
  184. providers/transports/openai_chat/transport.py +156 -0
  185. providers/wafer/__init__.py +10 -0
  186. providers/wafer/client.py +50 -0
  187. providers/zai/__init__.py +10 -0
  188. providers/zai/client.py +46 -0
  189. providers/zai/request.py +42 -0
@@ -0,0 +1,164 @@
1
+ """Voice note transcription for messaging platforms.
2
+
3
+ Supports:
4
+ - Local Whisper (cpu/cuda): Hugging Face transformers pipeline
5
+ - NVIDIA NIM: NVIDIA NIM Whisper/Parakeet
6
+ """
7
+
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from loguru import logger
12
+
13
+ from providers.nvidia_nim.voice import (
14
+ transcribe_audio_file as transcribe_nvidia_nim_audio,
15
+ )
16
+
17
# Largest audio upload accepted for transcription: 25 MiB.
MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024

# Convenience aliases for local Whisper checkpoints. Each alias expands to the
# matching "openai/whisper-<alias>" repository on the Hugging Face Hub; any
# other value is treated as a full model ID by _resolve_model_id().
_MODEL_MAP: dict[str, str] = {
    alias: f"openai/whisper-{alias}"
    for alias in (
        "tiny",
        "base",
        "small",
        "medium",
        "large-v2",
        "large-v3",
        "large-v3-turbo",
    )
}

# Lazily populated pipelines, keyed by (model_id, requested device, HF-token cache key).
_pipeline_cache: dict[tuple[str, str, str], Any] = {}


def _resolve_model_id(whisper_model: str) -> str:
    """Expand a short Whisper alias; pass anything unrecognised through unchanged."""
    try:
        return _MODEL_MAP[whisper_model]
    except KeyError:
        return whisper_model
38
+
39
+
40
def _token_fingerprint(hf_token: str) -> str:
    """Return a non-reversible cache key for an HF token ("" when no token is set).

    Keeps the raw secret out of the long-lived pipeline cache while still giving
    distinct cache entries for distinct tokens.
    """
    if not hf_token:
        return ""
    import hashlib

    return hashlib.sha256(hf_token.encode("utf-8")).hexdigest()


def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any:
    """Return a cached transformers Whisper ASR pipeline, loading it on first use.

    Args:
        model_id: Full Hugging Face model ID.
        device: "cpu" or "cuda". If "cuda" is requested but CUDA is unavailable,
            the model is loaded on CPU (a warning is logged).
        hf_token: Optional Hugging Face token for gated/private models.

    Returns:
        The transformers "automatic-speech-recognition" pipeline.

    Raises:
        ValueError: If ``device`` is not "cpu" or "cuda".
        ImportError: If torch/transformers (the voice_local extra) are missing.
    """
    if device not in ("cpu", "cuda"):
        raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")

    cache_key = (model_id, device, _token_fingerprint(hf_token))
    if cache_key in _pipeline_cache:
        return _pipeline_cache[cache_key]

    # Only the imports are guarded: ImportErrors raised later (e.g. by
    # from_pretrained for an optional backend) carry more precise messages.
    try:
        import torch
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
    except ImportError as e:
        raise ImportError(
            "Local Whisper requires the voice_local extra. Install with: uv sync --extra voice_local"
        ) from e

    hf_auth_token = hf_token or None

    use_cuda = device == "cuda" and torch.cuda.is_available()
    if device == "cuda" and not use_cuda:
        logger.warning(
            "whisper_device='cuda' requested but CUDA is unavailable; "
            "loading Whisper on CPU"
        )
    pipe_device = "cuda:0" if use_cuda else "cpu"
    # Half precision on GPU; CPU kernels need float32.
    model_dtype = torch.float16 if use_cuda else torch.float32

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        dtype=model_dtype,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",
        token=hf_auth_token,
    )
    model = model.to(pipe_device)
    processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=pipe_device,
    )
    _pipeline_cache[cache_key] = pipe
    logger.debug(f"Loaded Whisper pipeline: model={model_id} device={pipe_device}")
    return pipe
84
+
85
+
86
def transcribe_audio(
    file_path: Path,
    mime_type: str,
    *,
    whisper_model: str = "base",
    whisper_device: str = "cpu",
    hf_token: str = "",
    nvidia_nim_api_key: str = "",
) -> str:
    """
    Transcribe an audio file to text.

    Backends (selected by ``whisper_device``):
    - "cpu" / "cuda": local Whisper via transformers (requires voice_local extra)
    - "nvidia_nim": NVIDIA NIM Whisper API (requires voice extra)

    Args:
        file_path: Path to the audio file (OGG, MP3, MP4, WAV, M4A supported).
            A ``str`` path is also accepted.
        mime_type: MIME type of the audio (e.g. "audio/ogg"). Not used by this
            function; kept for interface compatibility with callers.
        whisper_model: Short name (e.g. "base", "large-v3") or full model ID for
            local Whisper, or the NVIDIA NIM model name.
        whisper_device: "cpu" | "cuda" | "nvidia_nim".
        hf_token: Optional Hugging Face token used for local model downloads.
        nvidia_nim_api_key: API key used when ``whisper_device="nvidia_nim"``.

    Returns:
        Transcribed text. The local backend returns "(no speech detected)" when
        the transcript is empty.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist or is not a regular file.
        ValueError: If the file exceeds MAX_AUDIO_SIZE_BYTES, or (local backend)
            ``whisper_device`` is not "cpu"/"cuda".
        ImportError: If voice_local extra not installed (for local Whisper).
    """
    path = Path(file_path)
    # is_file() (not exists()) so directories are rejected up front instead of
    # failing later inside the audio decoder.
    if not path.is_file():
        raise FileNotFoundError(f"Audio file not found: {path}")

    size = path.stat().st_size
    if size > MAX_AUDIO_SIZE_BYTES:
        raise ValueError(
            f"Audio file too large ({size} bytes). Max {MAX_AUDIO_SIZE_BYTES} bytes."
        )

    if whisper_device == "nvidia_nim":
        return transcribe_nvidia_nim_audio(
            path, whisper_model, api_key=nvidia_nim_api_key
        )
    return _transcribe_local(path, whisper_model, whisper_device, hf_token=hf_token)
133
+
134
+
135
# Input sample rate (Hz) that Whisper models are trained on.
_WHISPER_SAMPLE_RATE = 16_000
137
+
138
+
139
def _load_audio(file_path: Path) -> dict[str, Any]:
    """Decode an audio file into the waveform dict accepted by the ASR pipeline.

    The audio is down-mixed to mono and resampled to ``_WHISPER_SAMPLE_RATE``
    with librosa. NOTE(review): librosa may fall back to audioread for
    formats libsndfile cannot read (e.g. M4A/MP4), and that path can require
    ffmpeg — confirm before relying on "no ffmpeg required".

    Returns:
        ``{"array": <waveform>, "sampling_rate": <int>}``.

    Raises:
        ImportError: If librosa is not installed (presumably provided by the
            voice_local extra), with the same install hint as _get_pipeline.
    """
    try:
        import librosa
    except ImportError as e:
        raise ImportError(
            "Local Whisper requires the voice_local extra. Install with: uv sync --extra voice_local"
        ) from e

    waveform, sr = librosa.load(str(file_path), sr=_WHISPER_SAMPLE_RATE, mono=True)
    return {"array": waveform, "sampling_rate": sr}
145
+
146
+
147
def _transcribe_local(
    file_path: Path,
    whisper_model: str,
    whisper_device: str,
    *,
    hf_token: str = "",
) -> str:
    """Run the local transformers Whisper pipeline on one audio file.

    The model name is expanded with ``_resolve_model_id``, the pipeline is
    obtained from ``_get_pipeline``, the audio is decoded with ``_load_audio``,
    and English transcription is requested.

    Returns:
        The stripped transcript, or "(no speech detected)" if it is empty.
    """
    resolved_model = _resolve_model_id(whisper_model)
    asr = _get_pipeline(resolved_model, whisper_device, hf_token=hf_token)
    waveform = _load_audio(file_path)
    output = asr(waveform, generate_kwargs={"language": "en", "task": "transcribe"})

    # Accept either a string or a list of strings under "text".
    transcript = output.get("text", "") or ""
    if isinstance(transcript, list):
        transcript = " ".join(transcript)
    transcript = transcript.strip()

    logger.debug(f"Local transcription: {len(transcript)} chars")
    return transcript if transcript else "(no speech detected)"
@@ -0,0 +1,15 @@
1
+ """Message tree data structures and queue management."""
2
+
3
+ from .data import MessageNode, MessageState, MessageTree
4
+ from .manager import TreeQueueManager
5
+ from .processor import TreeQueueProcessor
6
+ from .repository import TreeRepository
7
+
8
# Public API of the ``messaging.trees`` package (names imported above).
__all__ = [
    "MessageNode", "MessageState", "MessageTree",
    "TreeQueueManager", "TreeQueueProcessor", "TreeRepository",
]
@@ -0,0 +1,482 @@
1
+ """Tree data structures for message queue.
2
+
3
+ Contains MessageState, MessageNode, and MessageTree classes.
4
+ """
5
+
6
+ import asyncio
7
+ from collections import deque
8
+ from contextlib import asynccontextmanager
9
+ from dataclasses import dataclass, field
10
+ from datetime import UTC, datetime
11
+ from enum import Enum
12
+ from typing import Any
13
+
14
+ from loguru import logger
15
+
16
+ from ..models import IncomingMessage
17
+
18
+
19
class _SnapshotQueue:
    """FIFO queue of IDs with snapshot and arbitrary-removal helpers.

    Backed by a deque (ordering) plus a per-item occurrence count (O(1)
    membership). Counting, rather than a plain set, keeps the index correct
    when the same ID is queued more than once.
    """

    def __init__(self) -> None:
        self._deque: deque[str] = deque()
        # item -> number of times it currently appears in ``_deque``
        self._counts: dict[str, int] = {}

    def _push(self, item: str) -> None:
        """Append ``item`` to the tail and bump its occurrence count."""
        self._deque.append(item)
        self._counts[item] = self._counts.get(item, 0) + 1

    async def put(self, item: str) -> None:
        """Append ``item``; async only for ``asyncio.Queue`` API parity (never blocks)."""
        self._push(item)

    def put_nowait(self, item: str) -> None:
        """Append ``item`` to the tail of the queue."""
        self._push(item)

    def get_nowait(self) -> str:
        """Pop the head item.

        Raises:
            asyncio.QueueEmpty: If the queue is empty.
        """
        if not self._deque:
            raise asyncio.QueueEmpty()
        item = self._deque.popleft()
        remaining = self._counts[item] - 1
        if remaining:
            self._counts[item] = remaining
        else:
            del self._counts[item]
        return item

    def qsize(self) -> int:
        """Number of queued entries (duplicates counted individually)."""
        return len(self._deque)

    def get_snapshot(self) -> list[str]:
        """Return current queue contents in FIFO order (read-only copy)."""
        return list(self._deque)

    def remove_if_present(self, item: str) -> bool:
        """Remove every occurrence of ``item`` (O(1) membership check).

        Returns:
            True if at least one occurrence was removed, False otherwise.
        """
        if item not in self._counts:
            return False
        del self._counts[item]
        self._deque = deque(x for x in self._deque if x != item)
        return True
55
+
56
+
57
class MessageState(Enum):
    """Processing state of a message node."""

    # Waiting in the tree's queue.
    PENDING = "pending"
    # Currently being handled by Claude.
    IN_PROGRESS = "in_progress"
    # Handling finished without error.
    COMPLETED = "completed"
    # Handling failed (or was cancelled).
    ERROR = "error"
64
+
65
+
66
@dataclass
class MessageNode:
    """
    A node in the message tree.

    Each node represents a single message and tracks:
    - Its relationship to parent/children
    - Its processing state
    - Claude session information
    """

    node_id: str  # Unique ID (typically message_id)
    incoming: IncomingMessage  # The original message
    status_message_id: str  # Bot's status message ID
    state: MessageState = MessageState.PENDING
    parent_id: str | None = None  # Parent node ID (None for root)
    session_id: str | None = None  # Claude session ID (forked from parent)
    children_ids: list[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    completed_at: datetime | None = None
    error_message: str | None = None
    context: Any = None  # Runtime-only extra context; not written by to_dict()

    def set_context(self, context: Any) -> None:
        """Attach runtime context to this node (not serialized)."""
        self.context = context

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        ``context`` is not included. ``children_ids`` is copied so the returned
        dict does not alias this node's live list.
        """
        incoming = self.incoming
        return {
            "node_id": self.node_id,
            "incoming": {
                "text": incoming.text,
                "chat_id": incoming.chat_id,
                "user_id": incoming.user_id,
                "message_id": incoming.message_id,
                "platform": incoming.platform,
                "reply_to_message_id": incoming.reply_to_message_id,
                "message_thread_id": incoming.message_thread_id,
                "username": incoming.username,
            },
            "status_message_id": self.status_message_id,
            "state": self.state.value,
            "parent_id": self.parent_id,
            "session_id": self.session_id,
            "children_ids": list(self.children_ids),
            "created_at": self.created_at.isoformat(),
            "completed_at": self.completed_at.isoformat()
            if self.completed_at is not None
            else None,
            "error_message": self.error_message,
        }

    # The return annotation is quoted: an unquoted name of the class being defined
    # is evaluated while the class body runs and raises NameError on Python < 3.14.
    @classmethod
    def from_dict(cls, data: dict) -> "MessageNode":
        """Create a node from a dictionary produced by ``to_dict``.

        Raises:
            KeyError: If a required key (node_id, incoming, status_message_id,
                state, created_at, or a required incoming field) is missing.
            ValueError: If ``state`` or a timestamp cannot be parsed.
        """
        incoming_data = data["incoming"]
        incoming = IncomingMessage(
            text=incoming_data["text"],
            chat_id=incoming_data["chat_id"],
            user_id=incoming_data["user_id"],
            message_id=incoming_data["message_id"],
            platform=incoming_data["platform"],
            reply_to_message_id=incoming_data.get("reply_to_message_id"),
            message_thread_id=incoming_data.get("message_thread_id"),
            username=incoming_data.get("username"),
        )
        completed_raw = data.get("completed_at")
        return cls(
            node_id=data["node_id"],
            incoming=incoming,
            status_message_id=data["status_message_id"],
            state=MessageState(data["state"]),
            parent_id=data.get("parent_id"),
            session_id=data.get("session_id"),
            # Copy so the node never shares a mutable list with the input dict;
            # a stored null is treated as "no children".
            children_ids=list(data.get("children_ids") or []),
            created_at=datetime.fromisoformat(data["created_at"]),
            completed_at=datetime.fromisoformat(completed_raw)
            if completed_raw
            else None,
            error_message=data.get("error_message"),
        )
146
+
147
+
148
class MessageTree:
    """
    A tree of message nodes with a per-tree processing queue.

    Provides:
    - O(1) node lookup by node ID and by status message ID (hash maps)
    - A FIFO queue of node IDs waiting to be processed
    - Coroutine-level mutual exclusion via ``asyncio.Lock``. This serializes
      tasks on one event loop; it is NOT a thread lock. Methods documented as
      "caller must hold lock" do not acquire it themselves, and the lock is not
      re-entrant, so the self-locking methods (``add_node``, ``update_state``,
      ``enqueue``, ``get_queue_snapshot``) must not be awaited inside
      ``with_lock()``.
    """

    def __init__(self, root_node: MessageNode):
        """
        Initialize tree with a root node.

        Args:
            root_node: The root message node
        """
        self.root_id = root_node.node_id
        self._nodes: dict[str, MessageNode] = {root_node.node_id: root_node}
        # status message ID -> node ID (backs find_node_by_status_message)
        self._status_to_node: dict[str, str] = {
            root_node.status_message_id: root_node.node_id
        }
        self._queue: _SnapshotQueue = _SnapshotQueue()
        self._lock = asyncio.Lock()
        self._is_processing = False
        self._current_node_id: str | None = None
        self._current_task: asyncio.Task | None = None

        logger.debug(f"Created MessageTree with root {self.root_id}")

    def set_current_task(self, task: asyncio.Task | None) -> None:
        """Set the current processing task. Caller must hold lock."""
        self._current_task = task

    @property
    def is_processing(self) -> bool:
        """Check if tree is currently processing a message."""
        return self._is_processing

    async def add_node(
        self,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
        parent_id: str,
    ) -> MessageNode:
        """
        Add a child node to the tree (acquires the tree lock).

        Args:
            node_id: Unique ID for the new node
            incoming: The incoming message
            status_message_id: Bot's status message ID
            parent_id: Parent node ID

        Returns:
            The created MessageNode

        Raises:
            ValueError: If ``parent_id`` is not in this tree.
        """
        async with self._lock:
            if parent_id not in self._nodes:
                raise ValueError(f"Parent node {parent_id} not found in tree")

            node = MessageNode(
                node_id=node_id,
                incoming=incoming,
                status_message_id=status_message_id,
                parent_id=parent_id,
                state=MessageState.PENDING,
            )

            self._nodes[node_id] = node
            self._status_to_node[status_message_id] = node_id
            self._nodes[parent_id].children_ids.append(node_id)

            logger.debug(f"Added node {node_id} as child of {parent_id}")
            return node

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node by ID (O(1) lookup)."""
        return self._nodes.get(node_id)

    def get_root(self) -> MessageNode:
        """Get the root node."""
        return self._nodes[self.root_id]

    def get_children(self, node_id: str) -> list[MessageNode]:
        """Get all child nodes of a given node (unknown child IDs are skipped)."""
        node = self._nodes.get(node_id)
        if not node:
            return []
        return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]

    def get_parent(self, node_id: str) -> MessageNode | None:
        """Get the parent node, or None for the root / unknown nodes."""
        node = self._nodes.get(node_id)
        if not node or not node.parent_id:
            return None
        return self._nodes.get(node.parent_id)

    def get_parent_session_id(self, node_id: str) -> str | None:
        """
        Get the parent's session ID for forking.

        Returns None for root nodes.
        """
        parent = self.get_parent(node_id)
        return parent.session_id if parent else None

    async def update_state(
        self,
        node_id: str,
        state: MessageState,
        session_id: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update a node's state (acquires the tree lock).

        ``session_id`` / ``error_message`` are only written when truthy;
        ``completed_at`` is stamped for COMPLETED and ERROR.
        """
        async with self._lock:
            node = self._nodes.get(node_id)
            if not node:
                logger.warning(f"Node {node_id} not found for state update")
                return

            node.state = state
            if session_id:
                node.session_id = session_id
            if error_message:
                node.error_message = error_message
            if state in (MessageState.COMPLETED, MessageState.ERROR):
                node.completed_at = datetime.now(UTC)

            logger.debug(f"Node {node_id} state -> {state.value}")

    async def enqueue(self, node_id: str) -> int:
        """
        Add a node to the processing queue (acquires the tree lock).

        Returns:
            Queue position (1-indexed)
        """
        async with self._lock:
            await self._queue.put(node_id)
            position = self._queue.qsize()
            logger.debug(f"Enqueued node {node_id}, position {position}")
            return position

    async def dequeue(self) -> str | None:
        """
        Get the next node ID from the queue (does not acquire the tree lock).

        Returns None if queue is empty.
        """
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

    async def get_queue_snapshot(self) -> list[str]:
        """
        Get a snapshot of the current queue order (acquires the tree lock).

        Returns:
            List of node IDs in FIFO order.
        """
        async with self._lock:
            return self._queue.get_snapshot()

    def get_queue_size(self) -> int:
        """Get number of messages waiting in queue."""
        return self._queue.qsize()

    def remove_from_queue(self, node_id: str) -> bool:
        """
        Remove node_id from the internal queue if present.

        Caller must hold the tree lock (e.g. via with_lock).
        Returns True if node was removed, False if not in queue.
        """
        return self._queue.remove_if_present(node_id)

    @asynccontextmanager
    async def with_lock(self):
        """Async context manager for tree lock. Use when multiple operations need atomicity."""
        async with self._lock:
            yield

    def set_processing_state(self, node_id: str | None, is_processing: bool) -> None:
        """Set processing state. Caller must hold lock for consistency with queue operations."""
        self._is_processing = is_processing
        self._current_node_id = node_id if is_processing else None

    def clear_current_node(self) -> None:
        """Clear the currently processing node ID. Caller must hold lock."""
        self._current_node_id = None

    def is_current_node(self, node_id: str) -> bool:
        """Check if node_id is the currently processing node."""
        return self._current_node_id == node_id

    def put_queue_unlocked(self, node_id: str) -> None:
        """Add node to queue. Caller must hold lock (e.g. via with_lock)."""
        self._queue.put_nowait(node_id)

    def cancel_current_task(self) -> bool:
        """Request cancellation of the running task. Returns True if one was cancelled."""
        if self._current_task and not self._current_task.done():
            self._current_task.cancel()
            return True
        return False

    def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
        """Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
        node.state = MessageState.ERROR
        node.error_message = error_message
        node.completed_at = datetime.now(UTC)

    def drain_queue_and_mark_cancelled(
        self, error_message: str = "Cancelled by user"
    ) -> list[MessageNode]:
        """
        Drain the queue, mark each node as ERROR, and return affected nodes.

        Queued IDs that no longer exist in the tree are dropped silently.
        Does not acquire lock; caller must ensure no concurrent queue access.
        """
        nodes: list[MessageNode] = []
        while True:
            try:
                node_id = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            node = self._nodes.get(node_id)
            if node:
                self.set_node_error_sync(node, error_message)
                nodes.append(node)
        return nodes

    def reset_processing_state(self) -> None:
        """Reset processing flags after cancel/cleanup."""
        self._is_processing = False
        self._current_node_id = None

    @property
    def current_node_id(self) -> str | None:
        """Get the ID of the node currently being processed."""
        return self._current_node_id

    def to_dict(self) -> dict:
        """Serialize tree structure (queue and processing state are not included)."""
        return {
            "root_id": self.root_id,
            "nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
        }

    def _add_node_from_dict(self, node: MessageNode) -> None:
        """Register a deserialized node into the tree's internal indices."""
        self._nodes[node.node_id] = node
        self._status_to_node[node.status_message_id] = node.node_id

    # The return annotation is quoted: an unquoted name of the class being defined
    # is evaluated while the class body runs and raises NameError on Python < 3.14.
    @classmethod
    def from_dict(cls, data: dict) -> "MessageTree":
        """Deserialize a tree produced by ``to_dict`` (with an empty queue)."""
        root_id = data["root_id"]
        nodes_data = data["nodes"]

        # Create root node first
        root_node = MessageNode.from_dict(nodes_data[root_id])
        tree = cls(root_node)

        # Add remaining nodes and build status->node index
        for node_id, node_data in nodes_data.items():
            if node_id != root_id:
                node = MessageNode.from_dict(node_data)
                tree._add_node_from_dict(node)

        return tree

    def all_nodes(self) -> list[MessageNode]:
        """Get all nodes in the tree."""
        return list(self._nodes.values())

    def has_node(self, node_id: str) -> bool:
        """Check if a node exists in this tree."""
        return node_id in self._nodes

    def find_node_by_status_message(self, status_msg_id: str) -> MessageNode | None:
        """Find the node that has this status message ID (O(1) lookup)."""
        node_id = self._status_to_node.get(status_msg_id)
        return self._nodes.get(node_id) if node_id else None

    def get_descendants(self, node_id: str) -> list[str]:
        """
        Get node_id and all descendant IDs (subtree), depth-first.

        Returns:
            List of node IDs including the given node.
        """
        if node_id not in self._nodes:
            return []
        result: list[str] = []
        stack = [node_id]
        while stack:
            nid = stack.pop()
            result.append(nid)
            node = self._nodes.get(nid)
            if node:
                stack.extend(node.children_ids)
        return result

    def remove_branch(self, branch_root_id: str) -> list[MessageNode]:
        """
        Remove a subtree (branch_root and all descendants) from the tree.

        Updates parent's children_ids. Queued IDs of removed nodes are NOT
        removed from the processing queue. Caller must hold lock for
        consistency; does not acquire lock internally.

        Returns:
            List of removed nodes.
        """
        if branch_root_id not in self._nodes:
            return []

        parent = self.get_parent(branch_root_id)
        removed: list[MessageNode] = []
        for nid in self.get_descendants(branch_root_id):
            node = self._nodes.pop(nid, None)
            if node is None:
                continue
            removed.append(node)
            # Drop the status mapping only if it still points at this node: a
            # later node may have re-used the status message ID. An unconditional
            # `del` could raise KeyError midway and leave a half-removed branch.
            if self._status_to_node.get(node.status_message_id) == nid:
                del self._status_to_node[node.status_message_id]

        if parent and branch_root_id in parent.children_ids:
            parent.children_ids = [
                c for c in parent.children_ids if c != branch_root_id
            ]

        logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
        return removed