devcopilot 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/__init__.py +17 -0
- api/admin_config.py +1303 -0
- api/admin_routes.py +287 -0
- api/admin_static/admin.css +459 -0
- api/admin_static/admin.js +497 -0
- api/admin_static/index.html +77 -0
- api/admin_urls.py +34 -0
- api/app.py +194 -0
- api/command_utils.py +164 -0
- api/dependencies.py +144 -0
- api/detection.py +152 -0
- api/gateway_model_ids.py +54 -0
- api/model_catalog.py +133 -0
- api/model_router.py +125 -0
- api/models/__init__.py +45 -0
- api/models/anthropic.py +234 -0
- api/models/openai_responses.py +28 -0
- api/models/responses.py +60 -0
- api/optimization_handlers.py +154 -0
- api/request_pipeline.py +424 -0
- api/routes.py +156 -0
- api/runtime.py +334 -0
- api/validation_log.py +48 -0
- api/web_server_tools.py +22 -0
- api/web_tools/__init__.py +17 -0
- api/web_tools/constants.py +15 -0
- api/web_tools/egress.py +99 -0
- api/web_tools/outbound.py +278 -0
- api/web_tools/parsers.py +104 -0
- api/web_tools/request.py +87 -0
- api/web_tools/streaming.py +206 -0
- cli/__init__.py +5 -0
- cli/claude_env.py +12 -0
- cli/entrypoints.py +166 -0
- cli/env.example +209 -0
- cli/launchers/__init__.py +1 -0
- cli/launchers/claude.py +84 -0
- cli/launchers/codex.py +204 -0
- cli/launchers/codex_model_catalog.py +186 -0
- cli/launchers/common.py +93 -0
- cli/managed/__init__.py +6 -0
- cli/managed/claude.py +215 -0
- cli/managed/manager.py +157 -0
- cli/managed/session.py +260 -0
- cli/process_registry.py +78 -0
- config/__init__.py +5 -0
- config/constants.py +13 -0
- config/logging_config.py +159 -0
- config/nim.py +118 -0
- config/paths.py +91 -0
- config/provider_catalog.py +259 -0
- config/provider_ids.py +7 -0
- config/settings.py +538 -0
- core/__init__.py +1 -0
- core/anthropic/__init__.py +46 -0
- core/anthropic/content.py +31 -0
- core/anthropic/conversion.py +587 -0
- core/anthropic/emitted_sse_tracker.py +346 -0
- core/anthropic/errors.py +70 -0
- core/anthropic/native_messages_request.py +280 -0
- core/anthropic/native_sse_block_policy.py +313 -0
- core/anthropic/provider_stream_error.py +34 -0
- core/anthropic/server_tool_sse.py +14 -0
- core/anthropic/sse.py +440 -0
- core/anthropic/stream_contracts.py +205 -0
- core/anthropic/stream_recovery.py +346 -0
- core/anthropic/stream_recovery_session.py +133 -0
- core/anthropic/thinking.py +140 -0
- core/anthropic/tokens.py +117 -0
- core/anthropic/tools.py +212 -0
- core/anthropic/utils.py +9 -0
- core/openai_responses/__init__.py +5 -0
- core/openai_responses/adapter.py +31 -0
- core/openai_responses/anthropic_sse.py +59 -0
- core/openai_responses/errors.py +22 -0
- core/openai_responses/events.py +19 -0
- core/openai_responses/ids.py +21 -0
- core/openai_responses/input.py +258 -0
- core/openai_responses/items.py +37 -0
- core/openai_responses/reasoning.py +52 -0
- core/openai_responses/stream.py +25 -0
- core/openai_responses/stream_state.py +654 -0
- core/openai_responses/tools.py +374 -0
- core/openai_responses/usage.py +37 -0
- core/rate_limit.py +60 -0
- core/trace.py +216 -0
- devcopilot-0.2.0.dist-info/METADATA +687 -0
- devcopilot-0.2.0.dist-info/RECORD +189 -0
- devcopilot-0.2.0.dist-info/WHEEL +4 -0
- devcopilot-0.2.0.dist-info/entry_points.txt +6 -0
- devcopilot-0.2.0.dist-info/licenses/LICENSE +21 -0
- messaging/__init__.py +26 -0
- messaging/cli_event_constants.py +67 -0
- messaging/command_context.py +66 -0
- messaging/command_dispatcher.py +37 -0
- messaging/commands.py +275 -0
- messaging/event_parser.py +181 -0
- messaging/limiter.py +300 -0
- messaging/models.py +36 -0
- messaging/node_event_pipeline.py +127 -0
- messaging/node_runner.py +342 -0
- messaging/platforms/__init__.py +15 -0
- messaging/platforms/base.py +228 -0
- messaging/platforms/discord.py +567 -0
- messaging/platforms/factory.py +103 -0
- messaging/platforms/outbox.py +144 -0
- messaging/platforms/telegram.py +688 -0
- messaging/platforms/voice_flow.py +295 -0
- messaging/rendering/__init__.py +3 -0
- messaging/rendering/discord_markdown.py +318 -0
- messaging/rendering/markdown_tables.py +49 -0
- messaging/rendering/profiles.py +55 -0
- messaging/rendering/telegram_markdown.py +327 -0
- messaging/safe_diagnostics.py +17 -0
- messaging/session.py +334 -0
- messaging/transcript.py +581 -0
- messaging/transcription.py +164 -0
- messaging/trees/__init__.py +15 -0
- messaging/trees/data.py +482 -0
- messaging/trees/manager.py +433 -0
- messaging/trees/processor.py +179 -0
- messaging/trees/repository.py +177 -0
- messaging/turn_intake.py +235 -0
- messaging/ui_updates.py +101 -0
- messaging/voice.py +76 -0
- messaging/workflow.py +200 -0
- providers/__init__.py +31 -0
- providers/base.py +152 -0
- providers/cerebras/__init__.py +7 -0
- providers/cerebras/client.py +31 -0
- providers/cerebras/request.py +55 -0
- providers/codestral/__init__.py +7 -0
- providers/codestral/client.py +34 -0
- providers/deepseek/__init__.py +11 -0
- providers/deepseek/client.py +51 -0
- providers/deepseek/request.py +475 -0
- providers/defaults.py +41 -0
- providers/error_mapping.py +309 -0
- providers/exceptions.py +113 -0
- providers/fireworks/__init__.py +5 -0
- providers/fireworks/client.py +45 -0
- providers/fireworks/request.py +48 -0
- providers/gemini/__init__.py +7 -0
- providers/gemini/client.py +49 -0
- providers/gemini/request.py +199 -0
- providers/groq/__init__.py +7 -0
- providers/groq/client.py +31 -0
- providers/groq/request.py +83 -0
- providers/kimi/__init__.py +10 -0
- providers/kimi/client.py +53 -0
- providers/kimi/request.py +42 -0
- providers/llamacpp/__init__.py +3 -0
- providers/llamacpp/client.py +16 -0
- providers/lmstudio/__init__.py +5 -0
- providers/lmstudio/client.py +16 -0
- providers/mistral/__init__.py +7 -0
- providers/mistral/client.py +31 -0
- providers/mistral/request.py +37 -0
- providers/model_listing.py +133 -0
- providers/nvidia_nim/__init__.py +7 -0
- providers/nvidia_nim/client.py +91 -0
- providers/nvidia_nim/request.py +430 -0
- providers/nvidia_nim/voice.py +95 -0
- providers/ollama/__init__.py +7 -0
- providers/ollama/client.py +39 -0
- providers/open_router/__init__.py +7 -0
- providers/open_router/client.py +124 -0
- providers/open_router/request.py +42 -0
- providers/opencode/__init__.py +11 -0
- providers/opencode/client.py +31 -0
- providers/opencode/request.py +35 -0
- providers/rate_limit.py +300 -0
- providers/registry.py +527 -0
- providers/transports/__init__.py +1 -0
- providers/transports/anthropic_messages/__init__.py +5 -0
- providers/transports/anthropic_messages/http.py +118 -0
- providers/transports/anthropic_messages/recovery.py +206 -0
- providers/transports/anthropic_messages/stream.py +295 -0
- providers/transports/anthropic_messages/transport.py +236 -0
- providers/transports/openai_chat/__init__.py +5 -0
- providers/transports/openai_chat/recovery.py +217 -0
- providers/transports/openai_chat/stream.py +384 -0
- providers/transports/openai_chat/tool_calls.py +293 -0
- providers/transports/openai_chat/transport.py +156 -0
- providers/wafer/__init__.py +10 -0
- providers/wafer/client.py +50 -0
- providers/zai/__init__.py +10 -0
- providers/zai/client.py +46 -0
- providers/zai/request.py +42 -0
providers/registry.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
"""Provider descriptors, factory, and runtime registry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from collections.abc import Callable, Iterable, MutableMapping
|
|
8
|
+
from contextlib import suppress
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
from config.provider_catalog import (
|
|
14
|
+
PROVIDER_CATALOG,
|
|
15
|
+
SUPPORTED_PROVIDER_IDS,
|
|
16
|
+
ProviderDescriptor,
|
|
17
|
+
)
|
|
18
|
+
from config.settings import ConfiguredChatModelRef, Settings
|
|
19
|
+
from providers.base import BaseProvider, ProviderConfig
|
|
20
|
+
from providers.exceptions import (
|
|
21
|
+
AuthenticationError,
|
|
22
|
+
ModelListResponseError,
|
|
23
|
+
ProviderError,
|
|
24
|
+
ServiceUnavailableError,
|
|
25
|
+
UnknownProviderTypeError,
|
|
26
|
+
)
|
|
27
|
+
from providers.model_listing import ProviderModelInfo, model_infos_from_ids
|
|
28
|
+
|
|
29
|
+
# Call signature of a provider factory: receives the resolved ``ProviderConfig``
# and the full ``Settings`` object, and returns a ``BaseProvider``.
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]

# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
# This binds the same dict object as ``PROVIDER_CATALOG``; it is not a copy.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the NVIDIA NIM provider, passing ``settings.nim`` as its NIM settings.

    The provider module is imported only when this factory is called.
    """
    from providers.nvidia_nim import NvidiaNimProvider

    nim_settings = settings.nim
    return NvidiaNimProvider(config, nim_settings=nim_settings)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the OpenRouter provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.open_router import OpenRouterProvider

    provider: BaseProvider = OpenRouterProvider(config)
    return provider
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _create_mistral(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Mistral provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.mistral import MistralProvider

    provider: BaseProvider = MistralProvider(config)
    return provider
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _create_mistral_codestral(
    config: ProviderConfig, _settings: Settings
) -> BaseProvider:
    """Build the Codestral provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.codestral import CodestralProvider

    provider: BaseProvider = CodestralProvider(config)
    return provider
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the DeepSeek provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.deepseek import DeepSeekProvider

    provider: BaseProvider = DeepSeekProvider(config)
    return provider
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the LM Studio provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.lmstudio import LMStudioProvider

    provider: BaseProvider = LMStudioProvider(config)
    return provider
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the llama.cpp provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.llamacpp import LlamaCppProvider

    provider: BaseProvider = LlamaCppProvider(config)
    return provider
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Ollama provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.ollama import OllamaProvider

    provider: BaseProvider = OllamaProvider(config)
    return provider
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _create_kimi(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Kimi provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.kimi import KimiProvider

    provider: BaseProvider = KimiProvider(config)
    return provider
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _create_wafer(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Wafer provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.wafer import WaferProvider

    provider: BaseProvider = WaferProvider(config)
    return provider
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _create_opencode(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the OpenCode provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.opencode import OpenCodeProvider

    provider: BaseProvider = OpenCodeProvider(config)
    return provider
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _create_opencode_go(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build an ``OpenCodeProvider`` with ``provider_name="OPENCODE_GO"``.

    Uses the same provider class as ``_create_opencode``; only the
    ``provider_name`` keyword differs. ``_settings`` is unused. The provider
    module is imported only when this factory is called.
    """
    from providers.opencode import OpenCodeProvider

    provider: BaseProvider = OpenCodeProvider(config, provider_name="OPENCODE_GO")
    return provider
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _create_zai(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Z.ai provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.zai import ZaiProvider

    provider: BaseProvider = ZaiProvider(config)
    return provider
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _create_fireworks(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Fireworks provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.fireworks import FireworksProvider

    provider: BaseProvider = FireworksProvider(config)
    return provider
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _create_gemini(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Gemini provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.gemini import GeminiProvider

    provider: BaseProvider = GeminiProvider(config)
    return provider
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _create_groq(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Groq provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.groq import GroqProvider

    provider: BaseProvider = GroqProvider(config)
    return provider
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _create_cerebras(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Cerebras provider from ``config``; ``_settings`` is unused.

    The provider module is imported only when this factory is called.
    """
    from providers.cerebras import CerebrasProvider

    provider: BaseProvider = CerebrasProvider(config)
    return provider
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# Maps each provider id to the factory that constructs its provider instance.
# ``create_provider`` looks factories up here by id. The key set is compared
# against ``SUPPORTED_PROVIDER_IDS`` at import time, so adding a provider
# requires an entry here as well as in the catalog.
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
    "nvidia_nim": _create_nvidia_nim,
    "open_router": _create_open_router,
    "gemini": _create_gemini,
    "deepseek": _create_deepseek,
    "mistral": _create_mistral,
    "mistral_codestral": _create_mistral_codestral,
    "opencode": _create_opencode,
    # Same provider class as "opencode", constructed with a different provider_name.
    "opencode_go": _create_opencode_go,
    "wafer": _create_wafer,
    "kimi": _create_kimi,
    "cerebras": _create_cerebras,
    "groq": _create_groq,
    "fireworks": _create_fireworks,
    "zai": _create_zai,
    "lmstudio": _create_lmstudio,
    "llamacpp": _create_llamacpp,
    "ollama": _create_ollama,
}
|
|
158
|
+
|
|
159
|
+
# Import-time consistency check: the descriptor catalog, the factory table and
# the supported-id list must all name exactly the same provider ids. This is
# an explicit ``raise`` rather than an ``assert`` statement, so it is not
# stripped when Python runs with ``-O``.
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
    PROVIDER_FACTORIES
) != set(SUPPORTED_PROVIDER_IDS):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
    """Read a string-valued attribute from ``settings``.

    Returns ``default`` when ``attr_name`` is ``None``, when the attribute is
    missing, or when its value is not a ``str``. An empty string value is
    returned unchanged.
    """
    if attr_name is not None:
        candidate = getattr(settings, attr_name, default)
        if isinstance(candidate, str):
            return candidate
    return default
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
    """Resolve the credential string for a provider descriptor.

    A descriptor's ``static_credential`` wins when it is not ``None``.
    Otherwise the value is read from the settings attribute named by
    ``credential_attr``; with no such attribute name the result is ``""``.
    """
    fixed = descriptor.static_credential
    if fixed is not None:
        return fixed
    attr_name = descriptor.credential_attr
    return _string_attr(settings, attr_name) if attr_name else ""
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
    """Raise ``AuthenticationError`` when a required credential is blank.

    Descriptors without a ``credential_env`` need no credential and always
    pass. Otherwise ``credential`` must contain at least one non-whitespace
    character. The error message names the environment variable and, when the
    descriptor has a ``credential_url``, where to obtain a key.
    """
    env_name = descriptor.credential_env
    # Short-circuit keeps ``credential`` untouched when no credential is required.
    if env_name is None or (credential and credential.strip()):
        return
    sentences = [f"{env_name} is not set. Add it to your .env file."]
    if descriptor.credential_url:
        sentences.append(f"Get a key at {descriptor.credential_url}")
    raise AuthenticationError(" ".join(sentences))
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble the ``ProviderConfig`` for one provider descriptor.

    The credential is resolved and validated first, so a missing required
    credential raises before anything else is read. The base URL comes from
    the settings attribute named by the descriptor, falling back to the
    descriptor's ``default_base_url`` when that attribute is absent or empty.
    Rate-limit, timeout, thinking and logging options are copied from
    ``settings``.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)

    fallback_url = descriptor.default_base_url
    configured_url = _string_attr(
        settings, descriptor.base_url_attr, fallback_url or ""
    )
    proxy_url = _string_attr(settings, descriptor.proxy_attr)

    return ProviderConfig(
        api_key=api_key,
        # An empty configured URL falls through to the descriptor default.
        base_url=configured_url or fallback_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=proxy_url,
        log_raw_sse_events=settings.log_raw_sse_events,
        log_api_error_tracebacks=settings.log_api_error_tracebacks,
    )
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Construct a new provider instance for ``provider_id``.

    Raises:
        UnknownProviderTypeError: ``provider_id`` has no descriptor; the
            message lists the supported ids.
        AssertionError: a descriptor exists but no factory is registered.

    Errors raised while building the provider config (for example a missing
    credential) propagate unchanged.
    """
    if (descriptor := PROVIDER_DESCRIPTORS.get(provider_id)) is None:
        known_ids = "', '".join(PROVIDER_DESCRIPTORS)
        raise UnknownProviderTypeError(
            f"Unknown provider_type: '{provider_id}'. Supported: '{known_ids}'"
        )

    provider_config = build_provider_config(descriptor, settings)
    if (build := PROVIDER_FACTORIES.get(provider_id)) is None:
        raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
    return build(provider_config, settings)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _format_provider_query_failures(
    refs: list[ConfiguredChatModelRef],
    exc: BaseException,
    settings: Settings,
) -> list[str]:
    """Return one failure line per ref, all sharing the reason derived from ``exc``."""
    shared_reason = _provider_query_failure_reason(exc, settings)
    lines: list[str] = []
    for ref in refs:
        lines.append(_format_model_validation_failure(ref, shared_reason))
    return lines
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
    """Return the failure line for a ref whose model is reported as missing."""
    problem = "missing model"
    return _format_model_validation_failure(ref, problem)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
    """Render one validation failure as space-separated ``key=value`` fields.

    Fields are, in order: the ref's comma-joined ``sources``, ``provider``,
    ``model`` and the ``problem`` text.
    """
    joined_sources = ",".join(ref.sources)
    fields = (
        f"sources={joined_sources}",
        f"provider={ref.provider_id}",
        f"model={ref.model_id}",
        f"problem={problem}",
    )
    return " ".join(fields)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _provider_query_failure_reason(
    exc: BaseException,
    settings: Settings,
) -> str:
    """Turn a model-list query exception into a short reason string.

    The exception's own message is included for malformed model-list
    responses and authentication errors. For any other ``ProviderError`` it
    is included only when ``settings.log_api_error_tracebacks`` is enabled.
    HTTP status errors report just the status code, and everything else
    reports just the exception class name.
    """
    if isinstance(exc, ModelListResponseError):
        return f"malformed model-list response: {exc.message}"

    if isinstance(exc, httpx.HTTPStatusError):
        detail = f"HTTP {exc.response.status_code}"
    elif isinstance(exc, AuthenticationError):
        detail = exc.message
    elif isinstance(exc, ProviderError) and settings.log_api_error_tracebacks:
        detail = exc.message
    else:
        detail = type(exc).__name__
    return f"query failure: {detail}"
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _referenced_provider_ids(settings: Settings) -> frozenset[str]:
    """Return the distinct provider ids named by the configured chat model refs."""
    seen: set[str] = set()
    for ref in settings.configured_chat_model_refs():
        seen.add(ref.provider_id)
    return frozenset(seen)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _model_list_provider_ids_for_settings(settings: Settings) -> tuple[str, ...]:
    """Return providers worth discovering for this process configuration.

    Ids are returned in ``PROVIDER_DESCRIPTORS`` order. A provider with a
    static credential is included only when a configured chat model refers to
    it. Any other provider is included only when it declares a credential
    environment variable and its resolved credential is non-blank.
    """
    referenced = _referenced_provider_ids(settings)

    def _is_discoverable(provider_id: str, descriptor: ProviderDescriptor) -> bool:
        if descriptor.static_credential is not None:
            return provider_id in referenced
        if descriptor.credential_env is None:
            return False
        return bool(_credential_for(descriptor, settings).strip())

    return tuple(
        provider_id
        for provider_id, descriptor in PROVIDER_DESCRIPTORS.items()
        if _is_discoverable(provider_id, descriptor)
    )
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _log_model_discovery_failure(
    provider_id: str, exc: BaseException, settings: Settings
) -> None:
    """Log a warning that model discovery was skipped for ``provider_id``.

    The logged reason is the string produced by
    ``_provider_query_failure_reason`` for ``exc``.
    """
    reason = _provider_query_failure_reason(exc, settings)
    logger.warning(
        "Provider model discovery skipped: provider={} reason={}",
        provider_id,
        reason,
    )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class ProviderRegistry:
    """Cache and clean up provider instances by provider id.

    Besides provider instances, the registry keeps the most recent model-list
    result per provider (model ids and per-model metadata) so later lookups
    can be answered from memory.
    """

    def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
        """Create a registry, optionally backed by a caller-supplied mapping.

        A supplied mapping is used directly (not copied), so the caller sees
        providers added by ``get`` and the ``clear()`` done by ``cleanup``.
        """
        self._providers = providers if providers is not None else {}
        self._model_ids_by_provider: dict[str, frozenset[str]] = {}
        self._model_infos_by_provider: dict[str, dict[str, ProviderModelInfo]] = {}
        # Handle of the background warmup started by ``start_model_list_refresh``.
        self._model_list_refresh_task: asyncio.Task[None] | None = None

    def is_cached(self, provider_id: str) -> bool:
        """Return whether a provider for this id is already in the cache."""
        return provider_id in self._providers

    def get(self, provider_id: str, settings: Settings) -> BaseProvider:
        """Return the cached provider for ``provider_id``, creating it on first use.

        Errors from ``create_provider`` propagate and nothing is cached.
        """
        if provider_id not in self._providers:
            self._providers[provider_id] = create_provider(provider_id, settings)
        return self._providers[provider_id]

    def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None:
        """Store a provider model-list result for later instant API responses."""
        self.cache_model_infos(provider_id, model_infos_from_ids(model_ids))

    def cache_model_infos(
        self, provider_id: str, model_infos: Iterable[ProviderModelInfo]
    ) -> None:
        """Store provider model metadata for later instant API responses.

        Entries whose ``model_id`` is empty or whitespace-only are dropped.
        Any previous result for ``provider_id`` is replaced.
        """
        clean_infos = {
            info.model_id: info for info in model_infos if info.model_id.strip()
        }
        self._model_infos_by_provider[provider_id] = clean_infos
        self._model_ids_by_provider[provider_id] = frozenset(clean_infos)

    def cached_model_ids(self) -> dict[str, frozenset[str]]:
        """Return a copy of cached raw provider model ids."""
        return dict(self._model_ids_by_provider)

    def cached_model_supports_thinking(
        self, provider_id: str, model_id: str
    ) -> bool | None:
        """Return cached thinking support when a provider exposes it.

        Returns ``None`` when the provider or the model is not in the cache.
        """
        info = self._model_infos_by_provider.get(provider_id, {}).get(model_id)
        if info is None:
            return None
        return info.supports_thinking

    def cached_prefixed_model_refs(self) -> tuple[str, ...]:
        """Return cached provider models in user-selectable ``provider/model`` form."""
        return tuple(info.model_id for info in self.cached_prefixed_model_infos())

    def cached_prefixed_model_infos(self) -> tuple[ProviderModelInfo, ...]:
        """Return cached provider models with user-selectable prefixed ids.

        Providers appear in ``SUPPORTED_PROVIDER_IDS`` order; within a
        provider, models are sorted by their unprefixed id.
        """
        infos: list[ProviderModelInfo] = []
        for provider_id in SUPPORTED_PROVIDER_IDS:
            provider_infos = self._model_infos_by_provider.get(provider_id, {})
            infos.extend(
                ProviderModelInfo(
                    model_id=f"{provider_id}/{info.model_id}",
                    supports_thinking=info.supports_thinking,
                )
                for info in sorted(
                    provider_infos.values(), key=lambda item: item.model_id
                )
            )
        return tuple(infos)

    def _missing_model_list_provider_ids(self, settings: Settings) -> tuple[str, ...]:
        """Return eligible provider ids that have no cached model list yet."""
        return tuple(
            provider_id
            for provider_id in _model_list_provider_ids_for_settings(settings)
            if provider_id not in self._model_ids_by_provider
        )

    async def refresh_model_list_cache(
        self, settings: Settings, *, only_missing: bool = False
    ) -> None:
        """Best-effort refresh of model lists for providers usable in this process.

        With ``only_missing=True`` providers that already have a cached model
        list are skipped.
        """
        provider_ids = (
            self._missing_model_list_provider_ids(settings)
            if only_missing
            else _model_list_provider_ids_for_settings(settings)
        )
        await self._refresh_model_ids(settings, provider_ids)

    def start_model_list_refresh(self, settings: Settings) -> None:
        """Start a non-blocking cache warmup for missing eligible provider lists.

        Does nothing while a previously started warmup is still running, or
        when every eligible provider already has a cached list.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            return

        provider_ids = self._missing_model_list_provider_ids(settings)
        if not provider_ids:
            logger.info(
                "Provider model discovery cache already warm: providers={}",
                len(self._model_ids_by_provider),
            )
            return

        self._model_list_refresh_task = asyncio.create_task(
            self._run_model_list_refresh(settings, provider_ids)
        )

    async def _run_model_list_refresh(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Body of the warmup task: refresh, logging any failure except cancellation."""
        try:
            await self._refresh_model_ids(settings, provider_ids)
        except asyncio.CancelledError:
            # Spelled out so cancellation is visibly never swallowed here.
            raise
        except Exception as exc:
            logger.warning(
                "Provider model discovery task failed: exc_type={}",
                type(exc).__name__,
            )

    async def _refresh_model_ids(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Query the given providers concurrently and cache every successful result.

        A provider that cannot be created or whose query fails is logged and
        skipped. A cancelled query re-raises ``asyncio.CancelledError``.
        """
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id in provider_ids:
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                _log_model_discovery_failure(provider_id, exc, settings)
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())

        if not tasks:
            return

        # gather() returns results in argument order, which matches dict order.
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        for provider_id, result in zip(tasks, results, strict=True):
            if isinstance(result, BaseException):
                if isinstance(result, asyncio.CancelledError):
                    raise result
                _log_model_discovery_failure(provider_id, result, settings)
                continue
            self.cache_model_infos(provider_id, result)
            logger.info(
                "Provider model discovery cached: provider={} models={}",
                provider_id,
                len(result),
            )

    async def validate_configured_models(self, settings: Settings) -> None:
        """Fail fast unless every configured chat model exists upstream.

        Each referenced provider is queried once, concurrently, and its
        result is cached. All problems are collected before raising.

        Raises:
            ServiceUnavailableError: a provider could not be created or
                queried, or a configured model is absent from its provider's
                model list. The message has one line per failing ref.
        """
        refs = settings.configured_chat_model_refs()
        refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
        for ref in refs:
            refs_by_provider[ref.provider_id].append(ref)

        failures: list[str] = []
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id, provider_refs in refs_by_provider.items():
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                failures.extend(
                    _format_provider_query_failures(provider_refs, exc, settings)
                )
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())

        if tasks:
            # gather() returns results in argument order, which matches dict order.
            results = await asyncio.gather(*tasks.values(), return_exceptions=True)
            for provider_id, result in zip(tasks, results, strict=True):
                provider_refs = refs_by_provider[provider_id]
                if isinstance(result, BaseException):
                    if isinstance(result, asyncio.CancelledError):
                        raise result
                    failures.extend(
                        _format_provider_query_failures(provider_refs, result, settings)
                    )
                    continue
                self.cache_model_infos(provider_id, result)
                model_ids = self._model_ids_by_provider[provider_id]
                failures.extend(
                    _format_missing_model_failure(ref)
                    for ref in provider_refs
                    if ref.model_id not in model_ids
                )

        if failures:
            message = "Configured model validation failed:\n" + "\n".join(
                f"- {failure}" for failure in failures
            )
            raise ServiceUnavailableError(message)

        logger.info(
            "Configured provider models validated: models={} providers={}",
            len(refs),
            len(refs_by_provider),
        )

    async def cleanup(self) -> None:
        """Call ``cleanup`` on every cached provider, then clear the cache.

        A still-running model-list warmup task is cancelled and awaited first.
        Attempts all providers even if one fails. A single failure is re-raised
        as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            self._model_list_refresh_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._model_list_refresh_task
        # The task is finished or absent at this point; drop the handle.
        self._model_list_refresh_task = None

        items = list(self._providers.items())
        errors: list[Exception] = []
        try:
            for _pid, provider in items:
                try:
                    await provider.cleanup()
                except Exception as e:
                    errors.append(e)
        finally:
            # Caches are emptied even if cleanup is interrupted part-way.
            self._providers.clear()
            self._model_ids_by_provider.clear()
            self._model_infos_by_provider.clear()
        if len(errors) == 1:
            raise errors[0]
        if len(errors) > 1:
            msg = "One or more provider cleanups failed"
            raise ExceptionGroup(msg, errors)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Provider transport families."""
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""HTTP helpers for native Anthropic Messages transports."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from loguru import logger
|
|
10
|
+
|
|
11
|
+
from config.constants import (
|
|
12
|
+
NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES,
|
|
13
|
+
PROVIDER_ERROR_BODY_DISPLAY_CAP_BYTES,
|
|
14
|
+
)
|
|
15
|
+
from providers.error_mapping import attach_provider_error_body
|
|
16
|
+
from providers.exceptions import ModelListResponseError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def maybe_await_aclose(response: Any) -> None:
    """Close ``response`` through ``aclose`` if it has a callable one.

    Works with async ``aclose`` (the result is awaited) and with synchronous
    test doubles (the call's non-awaitable result is ignored). Objects whose
    ``aclose`` is missing or not callable are left untouched.
    """
    closer = getattr(response, "aclose", None)
    if callable(closer):
        outcome = closer()
        if inspect.isawaitable(outcome):
            await outcome
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def model_list_json(response: httpx.Response, *, provider_name: str) -> Any:
    """Return the decoded JSON body of a model-list response.

    ``response.raise_for_status()`` is called first, so whatever it raises
    for an error status propagates unchanged.

    Raises:
        ModelListResponseError: the body is not valid JSON (``response.json()``
            raised ``ValueError``); the message starts with ``provider_name``.
    """
    response.raise_for_status()
    try:
        payload = response.json()
    except ValueError as exc:
        detail = f"{provider_name} model-list response is malformed: invalid JSON"
        raise ModelListResponseError(detail) from exc
    return payload
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
async def read_error_body_preview(
    response: httpx.Response, max_bytes: int
) -> tuple[bytes, bool]:
    """Read at most ``max_bytes`` from an error response body.

    Args:
        response: Object exposing ``aiter_bytes(chunk_size=...)`` as an async
            iterator of ``bytes`` chunks.
        max_bytes: Maximum number of body bytes to keep. Values ``<= 0`` read
            nothing.

    Returns:
        ``(preview, truncated)``. ``truncated`` is ``True`` exactly when the
        body contained more data than was kept.

    When the kept data ends exactly on a chunk boundary at ``max_bytes``, one
    further chunk is requested to find out whether the body continues.
    Previously the loop stopped as soon as the cap was reached, so such a
    body was reported as not truncated even when more data followed.
    """
    if max_bytes <= 0:
        return b"", False

    parts: list[bytes] = []
    received = 0
    truncated = False
    async for chunk in response.aiter_bytes(chunk_size=65_536):
        if not chunk:
            # An empty chunk carries no data and is not evidence of truncation.
            continue
        remaining = max_bytes - received
        if len(chunk) > remaining:
            # Keep what still fits (possibly nothing) and stop reading.
            if remaining:
                parts.append(chunk[:remaining])
            truncated = True
            break
        parts.append(chunk)
        received += len(chunk)
    return b"".join(parts), truncated
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def raise_for_status_with_body(
    response: httpx.Response,
    *,
    provider_name: str,
    req_tag: str,
    log_api_error_tracebacks: bool,
) -> None:
    """Re-raise an HTTP status error after attaching a capped body preview.

    Returns ``None`` when ``response.raise_for_status()`` does not raise.
    When it raises ``httpx.HTTPStatusError``, a preview of the body (capped
    at ``PROVIDER_ERROR_BODY_DISPLAY_CAP_BYTES``) is read and attached to the
    error, one error line is logged, and the same error is re-raised.

    With ``log_api_error_tracebacks`` enabled the log line includes the body
    text, further capped at ``NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES``.
    Otherwise only size metadata is logged, never the body itself.
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as error:
        body, body_truncated = await read_error_body_preview(
            response, PROVIDER_ERROR_BODY_DISPLAY_CAP_BYTES
        )
        attach_provider_error_body(error, body, truncated=body_truncated)

        if not log_api_error_tracebacks:
            # Metadata-only logging: declared length (if numeric) and bytes read.
            declared = response.headers.get("content-length", "").strip()
            length_note = ""
            if declared.isdigit():
                length_note = f" content_length_declared={declared}"
            if body:
                body_note = f" error_body_bytes_read={len(body)}"
            else:
                body_note = " empty_error_body"
            logger.error(
                "{}_ERROR:{} HTTP {}{}{}",
                provider_name,
                req_tag,
                response.status_code,
                length_note,
                body_note,
            )
        else:
            logged = body[:NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES]
            if not logged:
                logger.error(
                    "{}_ERROR:{} HTTP {} (empty error body)",
                    provider_name,
                    req_tag,
                    response.status_code,
                )
            else:
                # Truncated if the read was cut short or the log cap cut further.
                logged_truncated = body_truncated or len(body) > len(logged)
                logger.error(
                    "{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}",
                    provider_name,
                    req_tag,
                    response.status_code,
                    len(logged),
                    logged_truncated,
                    logged.decode("utf-8", errors="replace"),
                )
        raise error
|