devcopilot 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/__init__.py +17 -0
- api/admin_config.py +1303 -0
- api/admin_routes.py +287 -0
- api/admin_static/admin.css +459 -0
- api/admin_static/admin.js +497 -0
- api/admin_static/index.html +77 -0
- api/admin_urls.py +34 -0
- api/app.py +194 -0
- api/command_utils.py +164 -0
- api/dependencies.py +144 -0
- api/detection.py +152 -0
- api/gateway_model_ids.py +54 -0
- api/model_catalog.py +133 -0
- api/model_router.py +125 -0
- api/models/__init__.py +45 -0
- api/models/anthropic.py +234 -0
- api/models/openai_responses.py +28 -0
- api/models/responses.py +60 -0
- api/optimization_handlers.py +154 -0
- api/request_pipeline.py +424 -0
- api/routes.py +156 -0
- api/runtime.py +334 -0
- api/validation_log.py +48 -0
- api/web_server_tools.py +22 -0
- api/web_tools/__init__.py +17 -0
- api/web_tools/constants.py +15 -0
- api/web_tools/egress.py +99 -0
- api/web_tools/outbound.py +278 -0
- api/web_tools/parsers.py +104 -0
- api/web_tools/request.py +87 -0
- api/web_tools/streaming.py +206 -0
- cli/__init__.py +5 -0
- cli/claude_env.py +12 -0
- cli/entrypoints.py +166 -0
- cli/env.example +209 -0
- cli/launchers/__init__.py +1 -0
- cli/launchers/claude.py +84 -0
- cli/launchers/codex.py +204 -0
- cli/launchers/codex_model_catalog.py +186 -0
- cli/launchers/common.py +93 -0
- cli/managed/__init__.py +6 -0
- cli/managed/claude.py +215 -0
- cli/managed/manager.py +157 -0
- cli/managed/session.py +260 -0
- cli/process_registry.py +78 -0
- config/__init__.py +5 -0
- config/constants.py +13 -0
- config/logging_config.py +159 -0
- config/nim.py +118 -0
- config/paths.py +91 -0
- config/provider_catalog.py +259 -0
- config/provider_ids.py +7 -0
- config/settings.py +538 -0
- core/__init__.py +1 -0
- core/anthropic/__init__.py +46 -0
- core/anthropic/content.py +31 -0
- core/anthropic/conversion.py +587 -0
- core/anthropic/emitted_sse_tracker.py +346 -0
- core/anthropic/errors.py +70 -0
- core/anthropic/native_messages_request.py +280 -0
- core/anthropic/native_sse_block_policy.py +313 -0
- core/anthropic/provider_stream_error.py +34 -0
- core/anthropic/server_tool_sse.py +14 -0
- core/anthropic/sse.py +440 -0
- core/anthropic/stream_contracts.py +205 -0
- core/anthropic/stream_recovery.py +346 -0
- core/anthropic/stream_recovery_session.py +133 -0
- core/anthropic/thinking.py +140 -0
- core/anthropic/tokens.py +117 -0
- core/anthropic/tools.py +212 -0
- core/anthropic/utils.py +9 -0
- core/openai_responses/__init__.py +5 -0
- core/openai_responses/adapter.py +31 -0
- core/openai_responses/anthropic_sse.py +59 -0
- core/openai_responses/errors.py +22 -0
- core/openai_responses/events.py +19 -0
- core/openai_responses/ids.py +21 -0
- core/openai_responses/input.py +258 -0
- core/openai_responses/items.py +37 -0
- core/openai_responses/reasoning.py +52 -0
- core/openai_responses/stream.py +25 -0
- core/openai_responses/stream_state.py +654 -0
- core/openai_responses/tools.py +374 -0
- core/openai_responses/usage.py +37 -0
- core/rate_limit.py +60 -0
- core/trace.py +216 -0
- devcopilot-0.2.0.dist-info/METADATA +687 -0
- devcopilot-0.2.0.dist-info/RECORD +189 -0
- devcopilot-0.2.0.dist-info/WHEEL +4 -0
- devcopilot-0.2.0.dist-info/entry_points.txt +6 -0
- devcopilot-0.2.0.dist-info/licenses/LICENSE +21 -0
- messaging/__init__.py +26 -0
- messaging/cli_event_constants.py +67 -0
- messaging/command_context.py +66 -0
- messaging/command_dispatcher.py +37 -0
- messaging/commands.py +275 -0
- messaging/event_parser.py +181 -0
- messaging/limiter.py +300 -0
- messaging/models.py +36 -0
- messaging/node_event_pipeline.py +127 -0
- messaging/node_runner.py +342 -0
- messaging/platforms/__init__.py +15 -0
- messaging/platforms/base.py +228 -0
- messaging/platforms/discord.py +567 -0
- messaging/platforms/factory.py +103 -0
- messaging/platforms/outbox.py +144 -0
- messaging/platforms/telegram.py +688 -0
- messaging/platforms/voice_flow.py +295 -0
- messaging/rendering/__init__.py +3 -0
- messaging/rendering/discord_markdown.py +318 -0
- messaging/rendering/markdown_tables.py +49 -0
- messaging/rendering/profiles.py +55 -0
- messaging/rendering/telegram_markdown.py +327 -0
- messaging/safe_diagnostics.py +17 -0
- messaging/session.py +334 -0
- messaging/transcript.py +581 -0
- messaging/transcription.py +164 -0
- messaging/trees/__init__.py +15 -0
- messaging/trees/data.py +482 -0
- messaging/trees/manager.py +433 -0
- messaging/trees/processor.py +179 -0
- messaging/trees/repository.py +177 -0
- messaging/turn_intake.py +235 -0
- messaging/ui_updates.py +101 -0
- messaging/voice.py +76 -0
- messaging/workflow.py +200 -0
- providers/__init__.py +31 -0
- providers/base.py +152 -0
- providers/cerebras/__init__.py +7 -0
- providers/cerebras/client.py +31 -0
- providers/cerebras/request.py +55 -0
- providers/codestral/__init__.py +7 -0
- providers/codestral/client.py +34 -0
- providers/deepseek/__init__.py +11 -0
- providers/deepseek/client.py +51 -0
- providers/deepseek/request.py +475 -0
- providers/defaults.py +41 -0
- providers/error_mapping.py +309 -0
- providers/exceptions.py +113 -0
- providers/fireworks/__init__.py +5 -0
- providers/fireworks/client.py +45 -0
- providers/fireworks/request.py +48 -0
- providers/gemini/__init__.py +7 -0
- providers/gemini/client.py +49 -0
- providers/gemini/request.py +199 -0
- providers/groq/__init__.py +7 -0
- providers/groq/client.py +31 -0
- providers/groq/request.py +83 -0
- providers/kimi/__init__.py +10 -0
- providers/kimi/client.py +53 -0
- providers/kimi/request.py +42 -0
- providers/llamacpp/__init__.py +3 -0
- providers/llamacpp/client.py +16 -0
- providers/lmstudio/__init__.py +5 -0
- providers/lmstudio/client.py +16 -0
- providers/mistral/__init__.py +7 -0
- providers/mistral/client.py +31 -0
- providers/mistral/request.py +37 -0
- providers/model_listing.py +133 -0
- providers/nvidia_nim/__init__.py +7 -0
- providers/nvidia_nim/client.py +91 -0
- providers/nvidia_nim/request.py +430 -0
- providers/nvidia_nim/voice.py +95 -0
- providers/ollama/__init__.py +7 -0
- providers/ollama/client.py +39 -0
- providers/open_router/__init__.py +7 -0
- providers/open_router/client.py +124 -0
- providers/open_router/request.py +42 -0
- providers/opencode/__init__.py +11 -0
- providers/opencode/client.py +31 -0
- providers/opencode/request.py +35 -0
- providers/rate_limit.py +300 -0
- providers/registry.py +527 -0
- providers/transports/__init__.py +1 -0
- providers/transports/anthropic_messages/__init__.py +5 -0
- providers/transports/anthropic_messages/http.py +118 -0
- providers/transports/anthropic_messages/recovery.py +206 -0
- providers/transports/anthropic_messages/stream.py +295 -0
- providers/transports/anthropic_messages/transport.py +236 -0
- providers/transports/openai_chat/__init__.py +5 -0
- providers/transports/openai_chat/recovery.py +217 -0
- providers/transports/openai_chat/stream.py +384 -0
- providers/transports/openai_chat/tool_calls.py +293 -0
- providers/transports/openai_chat/transport.py +156 -0
- providers/wafer/__init__.py +10 -0
- providers/wafer/client.py +50 -0
- providers/zai/__init__.py +10 -0
- providers/zai/client.py +46 -0
- providers/zai/request.py +42 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Voice note transcription for messaging platforms.
|
|
2
|
+
|
|
3
|
+
Supports:
|
|
4
|
+
- Local Whisper (cpu/cuda): Hugging Face transformers pipeline
|
|
5
|
+
- NVIDIA NIM: NVIDIA NIM Whisper/Parakeet
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
from providers.nvidia_nim.voice import (
|
|
14
|
+
transcribe_audio_file as transcribe_nvidia_nim_audio,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Max file size in bytes (25 MB)
|
|
18
|
+
MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024
|
|
19
|
+
|
|
20
|
+
# Short model names -> full Hugging Face model IDs (for local Whisper)
|
|
21
|
+
_MODEL_MAP: dict[str, str] = {
|
|
22
|
+
"tiny": "openai/whisper-tiny",
|
|
23
|
+
"base": "openai/whisper-base",
|
|
24
|
+
"small": "openai/whisper-small",
|
|
25
|
+
"medium": "openai/whisper-medium",
|
|
26
|
+
"large-v2": "openai/whisper-large-v2",
|
|
27
|
+
"large-v3": "openai/whisper-large-v3",
|
|
28
|
+
"large-v3-turbo": "openai/whisper-large-v3-turbo",
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
# Lazy-loaded pipelines: (model_id, device, hf_token_fingerprint) -> pipeline
|
|
32
|
+
_pipeline_cache: dict[tuple[str, str, str], Any] = {}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _resolve_model_id(whisper_model: str) -> str:
|
|
36
|
+
"""Resolve short name to full Hugging Face model ID."""
|
|
37
|
+
return _MODEL_MAP.get(whisper_model, whisper_model)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any:
|
|
41
|
+
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
|
|
42
|
+
global _pipeline_cache
|
|
43
|
+
if device not in ("cpu", "cuda"):
|
|
44
|
+
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
|
|
45
|
+
resolved_token = hf_token or ""
|
|
46
|
+
cache_key = (model_id, device, resolved_token)
|
|
47
|
+
if cache_key not in _pipeline_cache:
|
|
48
|
+
try:
|
|
49
|
+
import torch
|
|
50
|
+
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
|
51
|
+
|
|
52
|
+
hf_auth_token = resolved_token or None
|
|
53
|
+
|
|
54
|
+
use_cuda = device == "cuda" and torch.cuda.is_available()
|
|
55
|
+
pipe_device = "cuda:0" if use_cuda else "cpu"
|
|
56
|
+
model_dtype = torch.float16 if use_cuda else torch.float32
|
|
57
|
+
|
|
58
|
+
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
|
59
|
+
model_id,
|
|
60
|
+
dtype=model_dtype,
|
|
61
|
+
low_cpu_mem_usage=True,
|
|
62
|
+
attn_implementation="sdpa",
|
|
63
|
+
token=hf_auth_token,
|
|
64
|
+
)
|
|
65
|
+
model = model.to(pipe_device)
|
|
66
|
+
processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token)
|
|
67
|
+
|
|
68
|
+
pipe = pipeline(
|
|
69
|
+
"automatic-speech-recognition",
|
|
70
|
+
model=model,
|
|
71
|
+
tokenizer=processor.tokenizer,
|
|
72
|
+
feature_extractor=processor.feature_extractor,
|
|
73
|
+
device=pipe_device,
|
|
74
|
+
)
|
|
75
|
+
_pipeline_cache[cache_key] = pipe
|
|
76
|
+
logger.debug(
|
|
77
|
+
f"Loaded Whisper pipeline: model={model_id} device={pipe_device}"
|
|
78
|
+
)
|
|
79
|
+
except ImportError as e:
|
|
80
|
+
raise ImportError(
|
|
81
|
+
"Local Whisper requires the voice_local extra. Install with: uv sync --extra voice_local"
|
|
82
|
+
) from e
|
|
83
|
+
return _pipeline_cache[cache_key]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def transcribe_audio(
|
|
87
|
+
file_path: Path,
|
|
88
|
+
mime_type: str,
|
|
89
|
+
*,
|
|
90
|
+
whisper_model: str = "base",
|
|
91
|
+
whisper_device: str = "cpu",
|
|
92
|
+
hf_token: str = "",
|
|
93
|
+
nvidia_nim_api_key: str = "",
|
|
94
|
+
) -> str:
|
|
95
|
+
"""
|
|
96
|
+
Transcribe audio file to text.
|
|
97
|
+
|
|
98
|
+
Supports:
|
|
99
|
+
- whisper_device="cpu"/"cuda": local Whisper (requires voice_local extra)
|
|
100
|
+
- whisper_device="nvidia_nim": NVIDIA NIM Whisper API (requires voice extra)
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported)
|
|
104
|
+
mime_type: MIME type of the audio (e.g. "audio/ogg")
|
|
105
|
+
whisper_model: Model ID or short name (local) or NVIDIA NIM model
|
|
106
|
+
whisper_device: "cpu" | "cuda" | "nvidia_nim"
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
Transcribed text
|
|
110
|
+
|
|
111
|
+
Raises:
|
|
112
|
+
FileNotFoundError: If file does not exist
|
|
113
|
+
ValueError: If file too large
|
|
114
|
+
ImportError: If voice_local extra not installed (for local Whisper)
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
if not file_path.exists():
|
|
118
|
+
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
|
119
|
+
|
|
120
|
+
size = file_path.stat().st_size
|
|
121
|
+
if size > MAX_AUDIO_SIZE_BYTES:
|
|
122
|
+
raise ValueError(
|
|
123
|
+
f"Audio file too large ({size} bytes). Max {MAX_AUDIO_SIZE_BYTES} bytes."
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
if whisper_device == "nvidia_nim":
|
|
127
|
+
return transcribe_nvidia_nim_audio(
|
|
128
|
+
file_path, whisper_model, api_key=nvidia_nim_api_key
|
|
129
|
+
)
|
|
130
|
+
return _transcribe_local(
|
|
131
|
+
file_path, whisper_model, whisper_device, hf_token=hf_token
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Whisper expects 16 kHz sample rate
|
|
136
|
+
_WHISPER_SAMPLE_RATE = 16000
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _load_audio(file_path: Path) -> dict[str, Any]:
|
|
140
|
+
"""Load audio file to waveform dict. No ffmpeg required."""
|
|
141
|
+
import librosa
|
|
142
|
+
|
|
143
|
+
waveform, sr = librosa.load(str(file_path), sr=_WHISPER_SAMPLE_RATE, mono=True)
|
|
144
|
+
return {"array": waveform, "sampling_rate": sr}
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _transcribe_local(
|
|
148
|
+
file_path: Path,
|
|
149
|
+
whisper_model: str,
|
|
150
|
+
whisper_device: str,
|
|
151
|
+
*,
|
|
152
|
+
hf_token: str = "",
|
|
153
|
+
) -> str:
|
|
154
|
+
"""Transcribe using transformers Whisper pipeline."""
|
|
155
|
+
model_id = _resolve_model_id(whisper_model)
|
|
156
|
+
pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token)
|
|
157
|
+
audio = _load_audio(file_path)
|
|
158
|
+
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
|
|
159
|
+
text = result.get("text", "") or ""
|
|
160
|
+
if isinstance(text, list):
|
|
161
|
+
text = " ".join(text) if text else ""
|
|
162
|
+
result_text = text.strip()
|
|
163
|
+
logger.debug(f"Local transcription: {len(result_text)} chars")
|
|
164
|
+
return result_text or "(no speech detected)"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Message tree data structures and queue management."""
|
|
2
|
+
|
|
3
|
+
from .data import MessageNode, MessageState, MessageTree
|
|
4
|
+
from .manager import TreeQueueManager
|
|
5
|
+
from .processor import TreeQueueProcessor
|
|
6
|
+
from .repository import TreeRepository
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"MessageNode",
|
|
10
|
+
"MessageState",
|
|
11
|
+
"MessageTree",
|
|
12
|
+
"TreeQueueManager",
|
|
13
|
+
"TreeQueueProcessor",
|
|
14
|
+
"TreeRepository",
|
|
15
|
+
]
|
messaging/trees/data.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
"""Tree data structures for message queue.
|
|
2
|
+
|
|
3
|
+
Contains MessageState, MessageNode, and MessageTree classes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
from collections import deque
|
|
8
|
+
from contextlib import asynccontextmanager
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import UTC, datetime
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from loguru import logger
|
|
15
|
+
|
|
16
|
+
from ..models import IncomingMessage
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class _SnapshotQueue:
|
|
20
|
+
"""Queue with snapshot/remove helpers, backed by a deque and a set index."""
|
|
21
|
+
|
|
22
|
+
def __init__(self) -> None:
|
|
23
|
+
self._deque: deque[str] = deque()
|
|
24
|
+
self._set: set[str] = set()
|
|
25
|
+
|
|
26
|
+
async def put(self, item: str) -> None:
|
|
27
|
+
self._deque.append(item)
|
|
28
|
+
self._set.add(item)
|
|
29
|
+
|
|
30
|
+
def put_nowait(self, item: str) -> None:
|
|
31
|
+
self._deque.append(item)
|
|
32
|
+
self._set.add(item)
|
|
33
|
+
|
|
34
|
+
def get_nowait(self) -> str:
|
|
35
|
+
if not self._deque:
|
|
36
|
+
raise asyncio.QueueEmpty()
|
|
37
|
+
item = self._deque.popleft()
|
|
38
|
+
self._set.discard(item)
|
|
39
|
+
return item
|
|
40
|
+
|
|
41
|
+
def qsize(self) -> int:
|
|
42
|
+
return len(self._deque)
|
|
43
|
+
|
|
44
|
+
def get_snapshot(self) -> list[str]:
|
|
45
|
+
"""Return current queue contents in FIFO order (read-only copy)."""
|
|
46
|
+
return list(self._deque)
|
|
47
|
+
|
|
48
|
+
def remove_if_present(self, item: str) -> bool:
|
|
49
|
+
"""Remove item from queue if present (O(1) membership check). Returns True if removed."""
|
|
50
|
+
if item not in self._set:
|
|
51
|
+
return False
|
|
52
|
+
self._set.discard(item)
|
|
53
|
+
self._deque = deque(x for x in self._deque if x != item)
|
|
54
|
+
return True
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class MessageState(Enum):
|
|
58
|
+
"""State of a message node in the tree."""
|
|
59
|
+
|
|
60
|
+
PENDING = "pending" # Queued, waiting to be processed
|
|
61
|
+
IN_PROGRESS = "in_progress" # Currently being processed by Claude
|
|
62
|
+
COMPLETED = "completed" # Processing finished successfully
|
|
63
|
+
ERROR = "error" # Processing failed
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
|
|
67
|
+
class MessageNode:
|
|
68
|
+
"""
|
|
69
|
+
A node in the message tree.
|
|
70
|
+
|
|
71
|
+
Each node represents a single message and tracks:
|
|
72
|
+
- Its relationship to parent/children
|
|
73
|
+
- Its processing state
|
|
74
|
+
- Claude session information
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
node_id: str # Unique ID (typically message_id)
|
|
78
|
+
incoming: IncomingMessage # The original message
|
|
79
|
+
status_message_id: str # Bot's status message ID
|
|
80
|
+
state: MessageState = MessageState.PENDING
|
|
81
|
+
parent_id: str | None = None # Parent node ID (None for root)
|
|
82
|
+
session_id: str | None = None # Claude session ID (forked from parent)
|
|
83
|
+
children_ids: list[str] = field(default_factory=list)
|
|
84
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
85
|
+
completed_at: datetime | None = None
|
|
86
|
+
error_message: str | None = None
|
|
87
|
+
context: Any = None # Additional context if needed
|
|
88
|
+
|
|
89
|
+
def set_context(self, context: Any) -> None:
|
|
90
|
+
self.context = context
|
|
91
|
+
|
|
92
|
+
def to_dict(self) -> dict:
|
|
93
|
+
"""Convert to dictionary for JSON serialization."""
|
|
94
|
+
return {
|
|
95
|
+
"node_id": self.node_id,
|
|
96
|
+
"incoming": {
|
|
97
|
+
"text": self.incoming.text,
|
|
98
|
+
"chat_id": self.incoming.chat_id,
|
|
99
|
+
"user_id": self.incoming.user_id,
|
|
100
|
+
"message_id": self.incoming.message_id,
|
|
101
|
+
"platform": self.incoming.platform,
|
|
102
|
+
"reply_to_message_id": self.incoming.reply_to_message_id,
|
|
103
|
+
"message_thread_id": self.incoming.message_thread_id,
|
|
104
|
+
"username": self.incoming.username,
|
|
105
|
+
},
|
|
106
|
+
"status_message_id": self.status_message_id,
|
|
107
|
+
"state": self.state.value,
|
|
108
|
+
"parent_id": self.parent_id,
|
|
109
|
+
"session_id": self.session_id,
|
|
110
|
+
"children_ids": self.children_ids,
|
|
111
|
+
"created_at": self.created_at.isoformat(),
|
|
112
|
+
"completed_at": self.completed_at.isoformat()
|
|
113
|
+
if self.completed_at
|
|
114
|
+
else None,
|
|
115
|
+
"error_message": self.error_message,
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
@classmethod
|
|
119
|
+
def from_dict(cls, data: dict) -> MessageNode:
|
|
120
|
+
"""Create from dictionary (JSON deserialization)."""
|
|
121
|
+
incoming_data = data["incoming"]
|
|
122
|
+
incoming = IncomingMessage(
|
|
123
|
+
text=incoming_data["text"],
|
|
124
|
+
chat_id=incoming_data["chat_id"],
|
|
125
|
+
user_id=incoming_data["user_id"],
|
|
126
|
+
message_id=incoming_data["message_id"],
|
|
127
|
+
platform=incoming_data["platform"],
|
|
128
|
+
reply_to_message_id=incoming_data.get("reply_to_message_id"),
|
|
129
|
+
message_thread_id=incoming_data.get("message_thread_id"),
|
|
130
|
+
username=incoming_data.get("username"),
|
|
131
|
+
)
|
|
132
|
+
return cls(
|
|
133
|
+
node_id=data["node_id"],
|
|
134
|
+
incoming=incoming,
|
|
135
|
+
status_message_id=data["status_message_id"],
|
|
136
|
+
state=MessageState(data["state"]),
|
|
137
|
+
parent_id=data.get("parent_id"),
|
|
138
|
+
session_id=data.get("session_id"),
|
|
139
|
+
children_ids=data.get("children_ids", []),
|
|
140
|
+
created_at=datetime.fromisoformat(data["created_at"]),
|
|
141
|
+
completed_at=datetime.fromisoformat(data["completed_at"])
|
|
142
|
+
if data.get("completed_at")
|
|
143
|
+
else None,
|
|
144
|
+
error_message=data.get("error_message"),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class MessageTree:
|
|
149
|
+
"""
|
|
150
|
+
A tree of message nodes with queue functionality.
|
|
151
|
+
|
|
152
|
+
Provides:
|
|
153
|
+
- O(1) node lookup via hashmap
|
|
154
|
+
- Per-tree message queue
|
|
155
|
+
- Thread-safe operations via asyncio.Lock
|
|
156
|
+
"""
|
|
157
|
+
|
|
158
|
+
def __init__(self, root_node: MessageNode):
|
|
159
|
+
"""
|
|
160
|
+
Initialize tree with a root node.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
root_node: The root message node
|
|
164
|
+
"""
|
|
165
|
+
self.root_id = root_node.node_id
|
|
166
|
+
self._nodes: dict[str, MessageNode] = {root_node.node_id: root_node}
|
|
167
|
+
self._status_to_node: dict[str, str] = {
|
|
168
|
+
root_node.status_message_id: root_node.node_id
|
|
169
|
+
}
|
|
170
|
+
self._queue: _SnapshotQueue = _SnapshotQueue()
|
|
171
|
+
self._lock = asyncio.Lock()
|
|
172
|
+
self._is_processing = False
|
|
173
|
+
self._current_node_id: str | None = None
|
|
174
|
+
self._current_task: asyncio.Task | None = None
|
|
175
|
+
|
|
176
|
+
logger.debug(f"Created MessageTree with root {self.root_id}")
|
|
177
|
+
|
|
178
|
+
def set_current_task(self, task: asyncio.Task | None) -> None:
|
|
179
|
+
"""Set the current processing task. Caller must hold lock."""
|
|
180
|
+
self._current_task = task
|
|
181
|
+
|
|
182
|
+
@property
|
|
183
|
+
def is_processing(self) -> bool:
|
|
184
|
+
"""Check if tree is currently processing a message."""
|
|
185
|
+
return self._is_processing
|
|
186
|
+
|
|
187
|
+
async def add_node(
|
|
188
|
+
self,
|
|
189
|
+
node_id: str,
|
|
190
|
+
incoming: IncomingMessage,
|
|
191
|
+
status_message_id: str,
|
|
192
|
+
parent_id: str,
|
|
193
|
+
) -> MessageNode:
|
|
194
|
+
"""
|
|
195
|
+
Add a child node to the tree.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
node_id: Unique ID for the new node
|
|
199
|
+
incoming: The incoming message
|
|
200
|
+
status_message_id: Bot's status message ID
|
|
201
|
+
parent_id: Parent node ID
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
The created MessageNode
|
|
205
|
+
"""
|
|
206
|
+
async with self._lock:
|
|
207
|
+
if parent_id not in self._nodes:
|
|
208
|
+
raise ValueError(f"Parent node {parent_id} not found in tree")
|
|
209
|
+
|
|
210
|
+
node = MessageNode(
|
|
211
|
+
node_id=node_id,
|
|
212
|
+
incoming=incoming,
|
|
213
|
+
status_message_id=status_message_id,
|
|
214
|
+
parent_id=parent_id,
|
|
215
|
+
state=MessageState.PENDING,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
self._nodes[node_id] = node
|
|
219
|
+
self._status_to_node[status_message_id] = node_id
|
|
220
|
+
self._nodes[parent_id].children_ids.append(node_id)
|
|
221
|
+
|
|
222
|
+
logger.debug(f"Added node {node_id} as child of {parent_id}")
|
|
223
|
+
return node
|
|
224
|
+
|
|
225
|
+
def get_node(self, node_id: str) -> MessageNode | None:
|
|
226
|
+
"""Get a node by ID (O(1) lookup)."""
|
|
227
|
+
return self._nodes.get(node_id)
|
|
228
|
+
|
|
229
|
+
def get_root(self) -> MessageNode:
|
|
230
|
+
"""Get the root node."""
|
|
231
|
+
return self._nodes[self.root_id]
|
|
232
|
+
|
|
233
|
+
def get_children(self, node_id: str) -> list[MessageNode]:
|
|
234
|
+
"""Get all child nodes of a given node."""
|
|
235
|
+
node = self._nodes.get(node_id)
|
|
236
|
+
if not node:
|
|
237
|
+
return []
|
|
238
|
+
return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]
|
|
239
|
+
|
|
240
|
+
def get_parent(self, node_id: str) -> MessageNode | None:
|
|
241
|
+
"""Get the parent node."""
|
|
242
|
+
node = self._nodes.get(node_id)
|
|
243
|
+
if not node or not node.parent_id:
|
|
244
|
+
return None
|
|
245
|
+
return self._nodes.get(node.parent_id)
|
|
246
|
+
|
|
247
|
+
def get_parent_session_id(self, node_id: str) -> str | None:
|
|
248
|
+
"""
|
|
249
|
+
Get the parent's session ID for forking.
|
|
250
|
+
|
|
251
|
+
Returns None for root nodes.
|
|
252
|
+
"""
|
|
253
|
+
parent = self.get_parent(node_id)
|
|
254
|
+
return parent.session_id if parent else None
|
|
255
|
+
|
|
256
|
+
async def update_state(
|
|
257
|
+
self,
|
|
258
|
+
node_id: str,
|
|
259
|
+
state: MessageState,
|
|
260
|
+
session_id: str | None = None,
|
|
261
|
+
error_message: str | None = None,
|
|
262
|
+
) -> None:
|
|
263
|
+
"""Update a node's state."""
|
|
264
|
+
async with self._lock:
|
|
265
|
+
node = self._nodes.get(node_id)
|
|
266
|
+
if not node:
|
|
267
|
+
logger.warning(f"Node {node_id} not found for state update")
|
|
268
|
+
return
|
|
269
|
+
|
|
270
|
+
node.state = state
|
|
271
|
+
if session_id:
|
|
272
|
+
node.session_id = session_id
|
|
273
|
+
if error_message:
|
|
274
|
+
node.error_message = error_message
|
|
275
|
+
if state in (MessageState.COMPLETED, MessageState.ERROR):
|
|
276
|
+
node.completed_at = datetime.now(UTC)
|
|
277
|
+
|
|
278
|
+
logger.debug(f"Node {node_id} state -> {state.value}")
|
|
279
|
+
|
|
280
|
+
async def enqueue(self, node_id: str) -> int:
|
|
281
|
+
"""
|
|
282
|
+
Add a node to the processing queue.
|
|
283
|
+
|
|
284
|
+
Returns:
|
|
285
|
+
Queue position (1-indexed)
|
|
286
|
+
"""
|
|
287
|
+
async with self._lock:
|
|
288
|
+
await self._queue.put(node_id)
|
|
289
|
+
position = self._queue.qsize()
|
|
290
|
+
logger.debug(f"Enqueued node {node_id}, position {position}")
|
|
291
|
+
return position
|
|
292
|
+
|
|
293
|
+
async def dequeue(self) -> str | None:
|
|
294
|
+
"""
|
|
295
|
+
Get the next node ID from the queue.
|
|
296
|
+
|
|
297
|
+
Returns None if queue is empty.
|
|
298
|
+
"""
|
|
299
|
+
try:
|
|
300
|
+
return self._queue.get_nowait()
|
|
301
|
+
except asyncio.QueueEmpty:
|
|
302
|
+
return None
|
|
303
|
+
|
|
304
|
+
async def get_queue_snapshot(self) -> list[str]:
|
|
305
|
+
"""
|
|
306
|
+
Get a snapshot of the current queue order.
|
|
307
|
+
|
|
308
|
+
Returns:
|
|
309
|
+
List of node IDs in FIFO order.
|
|
310
|
+
"""
|
|
311
|
+
async with self._lock:
|
|
312
|
+
return self._queue.get_snapshot()
|
|
313
|
+
|
|
314
|
+
def get_queue_size(self) -> int:
|
|
315
|
+
"""Get number of messages waiting in queue."""
|
|
316
|
+
return self._queue.qsize()
|
|
317
|
+
|
|
318
|
+
def remove_from_queue(self, node_id: str) -> bool:
|
|
319
|
+
"""
|
|
320
|
+
Remove node_id from the internal queue if present.
|
|
321
|
+
|
|
322
|
+
Caller must hold the tree lock (e.g. via with_lock).
|
|
323
|
+
Returns True if node was removed, False if not in queue.
|
|
324
|
+
"""
|
|
325
|
+
return self._queue.remove_if_present(node_id)
|
|
326
|
+
|
|
327
|
+
@asynccontextmanager
|
|
328
|
+
async def with_lock(self):
|
|
329
|
+
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
|
|
330
|
+
async with self._lock:
|
|
331
|
+
yield
|
|
332
|
+
|
|
333
|
+
def set_processing_state(self, node_id: str | None, is_processing: bool) -> None:
|
|
334
|
+
"""Set processing state. Caller must hold lock for consistency with queue operations."""
|
|
335
|
+
self._is_processing = is_processing
|
|
336
|
+
self._current_node_id = node_id if is_processing else None
|
|
337
|
+
|
|
338
|
+
def clear_current_node(self) -> None:
|
|
339
|
+
"""Clear the currently processing node ID. Caller must hold lock."""
|
|
340
|
+
self._current_node_id = None
|
|
341
|
+
|
|
342
|
+
def is_current_node(self, node_id: str) -> bool:
|
|
343
|
+
"""Check if node_id is the currently processing node."""
|
|
344
|
+
return self._current_node_id == node_id
|
|
345
|
+
|
|
346
|
+
def put_queue_unlocked(self, node_id: str) -> None:
|
|
347
|
+
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
|
|
348
|
+
self._queue.put_nowait(node_id)
|
|
349
|
+
|
|
350
|
+
def cancel_current_task(self) -> bool:
|
|
351
|
+
"""Cancel the currently running task. Returns True if a task was cancelled."""
|
|
352
|
+
if self._current_task and not self._current_task.done():
|
|
353
|
+
self._current_task.cancel()
|
|
354
|
+
return True
|
|
355
|
+
return False
|
|
356
|
+
|
|
357
|
+
def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
|
|
358
|
+
"""Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
|
|
359
|
+
node.state = MessageState.ERROR
|
|
360
|
+
node.error_message = error_message
|
|
361
|
+
node.completed_at = datetime.now(UTC)
|
|
362
|
+
|
|
363
|
+
def drain_queue_and_mark_cancelled(
|
|
364
|
+
self, error_message: str = "Cancelled by user"
|
|
365
|
+
) -> list[MessageNode]:
|
|
366
|
+
"""
|
|
367
|
+
Drain the queue, mark each node as ERROR, and return affected nodes.
|
|
368
|
+
Does not acquire lock; caller must ensure no concurrent queue access.
|
|
369
|
+
"""
|
|
370
|
+
nodes: list[MessageNode] = []
|
|
371
|
+
while True:
|
|
372
|
+
try:
|
|
373
|
+
node_id = self._queue.get_nowait()
|
|
374
|
+
except asyncio.QueueEmpty:
|
|
375
|
+
break
|
|
376
|
+
node = self._nodes.get(node_id)
|
|
377
|
+
if node:
|
|
378
|
+
self.set_node_error_sync(node, error_message)
|
|
379
|
+
nodes.append(node)
|
|
380
|
+
return nodes
|
|
381
|
+
|
|
382
|
+
def reset_processing_state(self) -> None:
|
|
383
|
+
"""Reset processing flags after cancel/cleanup."""
|
|
384
|
+
self._is_processing = False
|
|
385
|
+
self._current_node_id = None
|
|
386
|
+
|
|
387
|
+
@property
|
|
388
|
+
def current_node_id(self) -> str | None:
|
|
389
|
+
"""Get the ID of the node currently being processed."""
|
|
390
|
+
return self._current_node_id
|
|
391
|
+
|
|
392
|
+
def to_dict(self) -> dict:
|
|
393
|
+
"""Serialize tree to dictionary."""
|
|
394
|
+
return {
|
|
395
|
+
"root_id": self.root_id,
|
|
396
|
+
"nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
def _add_node_from_dict(self, node: MessageNode) -> None:
|
|
400
|
+
"""Register a deserialized node into the tree's internal indices."""
|
|
401
|
+
self._nodes[node.node_id] = node
|
|
402
|
+
self._status_to_node[node.status_message_id] = node.node_id
|
|
403
|
+
|
|
404
|
+
@classmethod
|
|
405
|
+
def from_dict(cls, data: dict) -> MessageTree:
|
|
406
|
+
"""Deserialize tree from dictionary."""
|
|
407
|
+
root_id = data["root_id"]
|
|
408
|
+
nodes_data = data["nodes"]
|
|
409
|
+
|
|
410
|
+
# Create root node first
|
|
411
|
+
root_node = MessageNode.from_dict(nodes_data[root_id])
|
|
412
|
+
tree = cls(root_node)
|
|
413
|
+
|
|
414
|
+
# Add remaining nodes and build status->node index
|
|
415
|
+
for node_id, node_data in nodes_data.items():
|
|
416
|
+
if node_id != root_id:
|
|
417
|
+
node = MessageNode.from_dict(node_data)
|
|
418
|
+
tree._add_node_from_dict(node)
|
|
419
|
+
|
|
420
|
+
return tree
|
|
421
|
+
|
|
422
|
+
def all_nodes(self) -> list[MessageNode]:
|
|
423
|
+
"""Get all nodes in the tree."""
|
|
424
|
+
return list(self._nodes.values())
|
|
425
|
+
|
|
426
|
+
def has_node(self, node_id: str) -> bool:
|
|
427
|
+
"""Check if a node exists in this tree."""
|
|
428
|
+
return node_id in self._nodes
|
|
429
|
+
|
|
430
|
+
def find_node_by_status_message(self, status_msg_id: str) -> MessageNode | None:
|
|
431
|
+
"""Find the node that has this status message ID (O(1) lookup)."""
|
|
432
|
+
node_id = self._status_to_node.get(status_msg_id)
|
|
433
|
+
return self._nodes.get(node_id) if node_id else None
|
|
434
|
+
|
|
435
|
+
def get_descendants(self, node_id: str) -> list[str]:
|
|
436
|
+
"""
|
|
437
|
+
Get node_id and all descendant IDs (subtree).
|
|
438
|
+
|
|
439
|
+
Returns:
|
|
440
|
+
List of node IDs including the given node.
|
|
441
|
+
"""
|
|
442
|
+
if node_id not in self._nodes:
|
|
443
|
+
return []
|
|
444
|
+
result: list[str] = []
|
|
445
|
+
stack = [node_id]
|
|
446
|
+
while stack:
|
|
447
|
+
nid = stack.pop()
|
|
448
|
+
result.append(nid)
|
|
449
|
+
node = self._nodes.get(nid)
|
|
450
|
+
if node:
|
|
451
|
+
stack.extend(node.children_ids)
|
|
452
|
+
return result
|
|
453
|
+
|
|
454
|
+
def remove_branch(self, branch_root_id: str) -> list[MessageNode]:
|
|
455
|
+
"""
|
|
456
|
+
Remove a subtree (branch_root and all descendants) from the tree.
|
|
457
|
+
|
|
458
|
+
Updates parent's children_ids. Caller must hold lock for consistency.
|
|
459
|
+
Does not acquire lock internally.
|
|
460
|
+
|
|
461
|
+
Returns:
|
|
462
|
+
List of removed nodes.
|
|
463
|
+
"""
|
|
464
|
+
if branch_root_id not in self._nodes:
|
|
465
|
+
return []
|
|
466
|
+
|
|
467
|
+
parent = self.get_parent(branch_root_id)
|
|
468
|
+
removed = []
|
|
469
|
+
for nid in self.get_descendants(branch_root_id):
|
|
470
|
+
node = self._nodes.get(nid)
|
|
471
|
+
if node:
|
|
472
|
+
removed.append(node)
|
|
473
|
+
del self._nodes[nid]
|
|
474
|
+
del self._status_to_node[node.status_message_id]
|
|
475
|
+
|
|
476
|
+
if parent and branch_root_id in parent.children_ids:
|
|
477
|
+
parent.children_ids = [
|
|
478
|
+
c for c in parent.children_ids if c != branch_root_id
|
|
479
|
+
]
|
|
480
|
+
|
|
481
|
+
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
|
|
482
|
+
return removed
|