gdmcode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gdmcode-0.1.0.dist-info/METADATA +240 -0
- gdmcode-0.1.0.dist-info/RECORD +131 -0
- gdmcode-0.1.0.dist-info/WHEEL +4 -0
- gdmcode-0.1.0.dist-info/entry_points.txt +2 -0
- src/__init__.py +1 -0
- src/_internal/__init__.py +0 -0
- src/_internal/constants.py +244 -0
- src/_internal/domain_skills.py +339 -0
- src/agent/__init__.py +0 -0
- src/agent/commit_classifier.py +91 -0
- src/agent/context_budget.py +391 -0
- src/agent/daemon.py +681 -0
- src/agent/dag_validator.py +153 -0
- src/agent/debug_loop.py +473 -0
- src/agent/impact_analyzer.py +149 -0
- src/agent/impact_graph.py +117 -0
- src/agent/loop.py +1410 -0
- src/agent/orchestrator.py +141 -0
- src/agent/regression_guard.py +251 -0
- src/agent/review_gate.py +648 -0
- src/agent/risk_scorer.py +169 -0
- src/agent/self_healing.py +145 -0
- src/agent/smart_test_selector.py +89 -0
- src/agent/system_prompt.py +226 -0
- src/agent/task_tracker.py +320 -0
- src/agent/test_validator.py +210 -0
- src/agent/tool_orchestrator.py +402 -0
- src/agent/transcript.py +230 -0
- src/agent/verification_loop.py +133 -0
- src/agent/work_director.py +136 -0
- src/agent/worktree_manager.py +53 -0
- src/artifacts/__init__.py +16 -0
- src/artifacts/artifact_store.py +456 -0
- src/artifacts/verification_graph.py +75 -0
- src/auth.py +411 -0
- src/cli.py +1290 -0
- src/commands.py +1398 -0
- src/config.py +762 -0
- src/cost_tracker.py +348 -0
- src/db/__init__.py +4 -0
- src/db/migrations.py +337 -0
- src/enterprise/__init__.py +3 -0
- src/enterprise/audit_log.py +182 -0
- src/enterprise/identity.py +90 -0
- src/enterprise/rbac.py +100 -0
- src/enterprise/team_config.py +125 -0
- src/enterprise/usage_analytics.py +261 -0
- src/exceptions.py +207 -0
- src/git_workflow.py +651 -0
- src/integrations/__init__.py +6 -0
- src/integrations/github_actions.py +106 -0
- src/integrations/mcp_server.py +333 -0
- src/integrations/sentry_integration.py +100 -0
- src/integrations/sentry_server.py +82 -0
- src/integrations/webhook_security.py +19 -0
- src/main.py +27 -0
- src/memory/__init__.py +0 -0
- src/memory/code_index.py +376 -0
- src/memory/compressor.py +378 -0
- src/memory/context_memory.py +135 -0
- src/memory/continuous_memory.py +234 -0
- src/memory/conventions.py +495 -0
- src/memory/db.py +1119 -0
- src/memory/document_index.py +205 -0
- src/memory/file_cache.py +128 -0
- src/memory/project_scanner.py +178 -0
- src/memory/session_store.py +201 -0
- src/models/__init__.py +0 -0
- src/models/client.py +715 -0
- src/models/definitions.py +459 -0
- src/models/router.py +418 -0
- src/models/schemas.py +389 -0
- src/permissions.py +294 -0
- src/remote/__init__.py +5 -0
- src/remote/command_filter.py +33 -0
- src/remote/models.py +31 -0
- src/remote/permission_handler.py +79 -0
- src/remote/phone_ui.py +48 -0
- src/remote/protocol.py +59 -0
- src/remote/qr.py +65 -0
- src/remote/server.py +586 -0
- src/remote/token_manager.py +61 -0
- src/remote/tunnel.py +212 -0
- src/repl.py +475 -0
- src/runtime/__init__.py +1 -0
- src/runtime/branch_farm.py +372 -0
- src/runtime/replay.py +351 -0
- src/sandbox/__init__.py +2 -0
- src/sandbox/hermetic.py +214 -0
- src/sandbox/policy.py +44 -0
- src/sdk/__init__.py +3 -0
- src/sdk/plugin_base.py +39 -0
- src/sdk/plugin_host.py +100 -0
- src/sdk/plugin_loader.py +101 -0
- src/security.py +409 -0
- src/server/__init__.py +7 -0
- src/server/bridge.py +427 -0
- src/server/bridge_cli.py +103 -0
- src/server/bridge_client.py +170 -0
- src/server/protocol_version.py +103 -0
- src/session/__init__.py +10 -0
- src/session/event_fanout.py +46 -0
- src/session/input_broker.py +38 -0
- src/session/permission_bridge.py +100 -0
- src/tools/__init__.py +160 -0
- src/tools/_atomic.py +72 -0
- src/tools/agent_tools.py +423 -0
- src/tools/ask_user_tool.py +83 -0
- src/tools/bash_tool.py +384 -0
- src/tools/browser_tool.py +352 -0
- src/tools/browser_tools.py +179 -0
- src/tools/dep_tools.py +210 -0
- src/tools/document_reader.py +167 -0
- src/tools/document_tool.py +240 -0
- src/tools/document_writer.py +171 -0
- src/tools/impact_tools.py +240 -0
- src/tools/playwright_tool.py +172 -0
- src/tools/quality_tools.py +366 -0
- src/tools/read_tools.py +318 -0
- src/tools/result_cache.py +157 -0
- src/tools/search_tools.py +310 -0
- src/tools/shell_tools.py +311 -0
- src/tools/write_tools.py +337 -0
- src/voice/__init__.py +25 -0
- src/voice/audio_capture.py +92 -0
- src/voice/audio_playback.py +68 -0
- src/voice/errors.py +14 -0
- src/voice/models.py +35 -0
- src/voice/providers.py +143 -0
- src/voice/vad.py +55 -0
- src/voice/voice_loop.py +156 -0
src/models/client.py
ADDED
|
@@ -0,0 +1,715 @@
|
|
|
1
|
+
"""GdmClient — structured-output-aware wrapper around openai.OpenAI.
|
|
2
|
+
|
|
3
|
+
Provides retry with exponential backoff, structured parsing via
|
|
4
|
+
client.beta.chat.completions.parse(), and provider-specific configuration.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import io
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import random
|
|
12
|
+
import time
|
|
13
|
+
import uuid
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from dataclasses import dataclass, field, replace as _dc_replace
|
|
16
|
+
from typing import Any, TypeVar
|
|
17
|
+
|
|
18
|
+
import openai
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
from src.config import GdmConfig
|
|
22
|
+
from src.exceptions import ApiError, ApiRateLimitError, ConfigError, SchemaError
|
|
23
|
+
from src.models.definitions import PROVIDER_BASE_URLS, Provider
|
|
24
|
+
|
|
25
|
+
# Public surface of this module. The two underscore-prefixed helpers are
# listed on purpose so ``from src.models.client import *`` still exposes them.
__all__ = [
    # core client
    "GdmClient",
    "BaseModelT",
    # local (non-cloud) providers
    "LocalProviderConfig",
    "LocalModelCapabilities",
    "OllamaProvider",
    "VLLMProvider",
    "_validate_local_base_url",
    # batch API
    "should_use_batch",
    "OpenAIBatchBackend",
    "GeminiBatchBackend",
    "GrokSequentialBackend",
    "BatchClient",
    "_map_status",
]

log = logging.getLogger(__name__)

# Return-type variable for GdmClient.parse(): ties the result to the Pydantic
# model class passed as ``response_format``.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)

# --- Transport retry policy (used by GdmClient._call_with_retry) -----------
# Number of retries after the first attempt.
_MAX_RETRIES: int = 3
# Sleep before retry N (0-based) is _BACKOFF_BASE ** N seconds.
_BACKOFF_BASE: float = 2.0

# --- Structured-output retry policy (used by GdmClient.parse) --------------
# Extra attempts, each preceded by _SCHEMA_CORRECTION_PROMPT, after a
# response fails to parse against the requested schema.
_PARSE_MAX_RETRIES: int = 2
_SCHEMA_CORRECTION_PROMPT: str = (
    "Your previous response did not conform to the required JSON schema. "
    "Please try again and ensure your response matches the schema exactly."
)

# openai SDK exceptions treated as transient by _call_with_retry.
_RETRYABLE_ERRORS: tuple[type[Exception], ...] = (
    openai.RateLimitError,
    openai.InternalServerError,
    openai.APIConnectionError,
    openai.APITimeoutError,
)

# finish_reason values inspected by GdmClient.parse().
_FINISH_REASON_TRUNCATED: str = "length"
_FINISH_REASON_FILTERED: str = "content_filter"
|
|
60
|
+
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
# Local provider support
|
|
63
|
+
# ---------------------------------------------------------------------------
|
|
64
|
+
|
|
65
|
+
_ALLOWED_LOCAL_HOSTS: frozenset[str] = frozenset({"localhost", "127.0.0.1", "::1"})
|
|
66
|
+
_LOCAL_PROVIDERS: frozenset[str] = frozenset({"ollama", "vllm"})
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class LocalProviderConfig:
    """Settings for one local (non-cloud) model backend.

    Attributes:
        provider: backend identifier, "ollama" or "vllm".
        base_url: root URL of the backend's HTTP API.
        model: model name sent to the backend.
        local_only: when True, a failing local backend must not fall back to
            the cloud.
        allow_cloud_fallback: stored for configuration purposes; not consulted
            by the code in this module.
        allow_external: permit a base_url whose host is not local.
    """

    provider: str
    base_url: str
    model: str
    local_only: bool = field(default=True)
    allow_cloud_fallback: bool = field(default=False)
    allow_external: bool = field(default=False)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class LocalModelCapabilities:
    """Feature flags a local model backend reports (or conservative defaults).

    Attributes:
        supports_tools: model accepts tool/function definitions.
        supports_json_mode: model supports a JSON output mode.
        context_window: context length in tokens.
        streaming: backend can stream responses.
    """

    supports_tools: bool = field(default=False)
    supports_json_mode: bool = field(default=False)
    context_window: int = field(default=4096)
    streaming: bool = field(default=True)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _validate_local_base_url(url: str, allow_external: bool = False) -> None:
|
|
92
|
+
"""Raise ConfigError if url is non-local and allow_external is False."""
|
|
93
|
+
from urllib.parse import urlparse
|
|
94
|
+
parsed = urlparse(url)
|
|
95
|
+
if not allow_external and parsed.hostname not in _ALLOWED_LOCAL_HOSTS:
|
|
96
|
+
raise ConfigError(
|
|
97
|
+
f"Local provider base_url '{url}' resolves to non-local host "
|
|
98
|
+
f"'{parsed.hostname}'. Set allow_external=true to override."
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OllamaProvider:
    """Native Ollama /api/chat client (NOT OpenAI-compatible).

    Talks HTTP with stdlib ``urllib`` and converts Ollama replies into plain
    dicts shaped like an OpenAI ChatCompletion (dicts, not SDK objects).
    """

    def __init__(self, config: LocalProviderConfig) -> None:
        """Validate ``config.base_url`` and remember *config*.

        Raises:
            ConfigError: if base_url is non-local and allow_external is False.
        """
        _validate_local_base_url(config.base_url, config.allow_external)
        self.config = config
        # Filled lazily by probe_capabilities() and then reused.
        self._capabilities: LocalModelCapabilities | None = None
        import urllib.request

        self._http = urllib.request  # stdlib only — no requests dep

    def _post_json(self, path: str, payload: dict[str, Any], *, timeout: float) -> Any:
        """POST *payload* as JSON to ``base_url + path`` and return the decoded reply."""
        url = f"{self.config.base_url.rstrip('/')}{path}"
        req = self._http.Request(
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with self._http.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read())

    def complete(self, messages: list[dict], *, tools: list | None = None, **kw: Any) -> dict:
        """POST to /api/chat and return adapted OpenAI-shaped dict.

        Args:
            messages: chat messages, forwarded unchanged.
            tools: tool definitions; when given, the model must be reported as
                tool-capable by probe_capabilities().
            **kw: accepted for call-site compatibility; currently ignored.

        Raises:
            ConfigError: if *tools* is given and the model lacks tool support.
        """
        if tools is not None:
            caps = self.probe_capabilities()
            if not caps.supports_tools:
                raise ConfigError(
                    f"Model '{self.config.model}' does not support tool calls "
                    "(supports_tools=False). Remove tools or use a capable model."
                )

        payload: dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "stream": False,
        }
        if tools is not None:
            payload["tools"] = tools

        raw: dict = self._post_json("/api/chat", payload, timeout=60)
        return self.to_openai_response(raw)

    @staticmethod
    def _to_openai_tool_call(index: int, call: Any) -> Any:
        """Convert one Ollama tool call into the OpenAI ``tool_calls`` item shape.

        OpenAI consumers expect ``{"id", "type": "function", "function":
        {"name", "arguments": <JSON string>}}``; Ollama's /api/chat reports
        ``arguments`` as a JSON object and omits ``id``/``type``. Existing
        fields are kept; the input is not mutated. Non-dict items pass through.
        """
        if not isinstance(call, dict):
            return call
        converted = dict(call)
        function = call.get("function")
        if isinstance(function, dict):
            function = dict(function)
            arguments = function.get("arguments")
            if arguments is not None and not isinstance(arguments, str):
                function["arguments"] = json.dumps(arguments)
            converted["function"] = function
        converted.setdefault("type", "function")
        # Index-based ids are unique within one response, which is the scope
        # in which tool-call ids are matched to tool results.
        converted.setdefault("id", f"call_{index}")
        return converted

    def to_openai_response(self, raw: dict) -> dict:
        """Adapt Ollama /api/chat response to OpenAI ChatCompletion shape.

        raw["message"]["content"]    -> choices[0].message.content
        raw["message"]["tool_calls"] -> choices[0].message.tool_calls (if present),
                                        normalised to OpenAI tool_call items
        finish_reason: "tool_calls" when tool calls are present, "length" when
        Ollama's ``done_reason`` is "length", otherwise "stop".
        Missing or null token counts are reported as 0.
        """
        message = raw.get("message") or {}
        content = message.get("content", "")
        tool_calls = message.get("tool_calls")

        msg: dict[str, Any] = {"role": "assistant", "content": content}
        if tool_calls is not None:
            msg["tool_calls"] = [
                self._to_openai_tool_call(i, call) for i, call in enumerate(tool_calls)
            ]

        if tool_calls:
            finish_reason = "tool_calls"
        elif raw.get("done_reason") == "length":
            finish_reason = "length"
        else:
            finish_reason = "stop"

        prompt_tokens = raw.get("prompt_eval_count") or 0
        completion_tokens = raw.get("eval_count") or 0

        return {
            "id": f"ollama-{raw.get('created_at', 'response')}",
            "object": "chat.completion",
            "model": raw.get("model", self.config.model),
            "choices": [
                {
                    "index": 0,
                    "message": msg,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    @staticmethod
    def _capabilities_from_show(show_data: dict) -> LocalModelCapabilities:
        """Build capabilities from an /api/show reply.

        Reads the fields current Ollama versions return (``model_info`` with
        an ``<arch>.context_length`` key, and a ``capabilities`` list that
        contains "tools" for tool-capable models) and still honours the
        legacy ``modelinfo`` / ``supports_tools`` / ``supports_json_mode`` keys.
        """
        model_info = show_data.get("model_info") or show_data.get("modelinfo") or {}
        context_window = model_info.get("llama.context_length") or model_info.get(
            "context_length"
        )
        if not context_window:
            # Ollama prefixes the key with the architecture, e.g. "qwen2.context_length".
            context_window = next(
                (
                    value
                    for key, value in model_info.items()
                    if str(key).endswith(".context_length") and value
                ),
                None,
            )
        reported = show_data.get("capabilities") or []
        return LocalModelCapabilities(
            supports_tools=bool(show_data.get("supports_tools", False)) or "tools" in reported,
            supports_json_mode=bool(show_data.get("supports_json_mode", False)),
            context_window=int(context_window or 4096),
            streaming=True,
        )

    def probe_capabilities(self) -> LocalModelCapabilities:
        """Lazy-probe /api/show for model capabilities. Cached per instance.

        On any failure (network, HTTP error, unexpected payload) a warning is
        logged and default capabilities are cached instead.
        """
        if self._capabilities is not None:
            return self._capabilities

        try:
            show_data = self._post_json("/api/show", {"name": self.config.model}, timeout=10)
            self._capabilities = self._capabilities_from_show(show_data)
        except Exception as exc:
            log.warning(
                "Could not probe capabilities of Ollama model '%s' (%s); using defaults",
                self.config.model,
                exc,
            )
            self._capabilities = LocalModelCapabilities()

        return self._capabilities
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class VLLMProvider:
    """Local vLLM backend.

    vLLM exposes an OpenAI-compatible API, so this class only points a stock
    ``openai.OpenAI`` client at ``config.base_url`` with a placeholder key.
    """

    def __init__(self, config: LocalProviderConfig) -> None:
        """Check that *config*'s base_url is allowed, then store the config.

        Raises:
            ConfigError: if base_url is non-local and allow_external is False.
        """
        cfg = config
        _validate_local_base_url(cfg.base_url, cfg.allow_external)
        self.config = cfg

    def get_client(self) -> "openai.OpenAI":
        """Build a new ``openai.OpenAI`` client aimed at the vLLM server.

        A new client object is created on every call.
        """
        client_kwargs = {"base_url": self.config.base_url, "api_key": "not-needed"}
        return openai.OpenAI(**client_kwargs)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class GdmClient:
|
|
232
|
+
"""OpenAI-compatible client with retry, structured output, and provider abstraction.
|
|
233
|
+
|
|
234
|
+
Wraps openai.OpenAI with:
|
|
235
|
+
- Structured output via client.beta.chat.completions.parse() with Pydantic models
|
|
236
|
+
- Exponential backoff retry (3 retries, 2x backoff) on rate limit / server errors
|
|
237
|
+
- Provider-specific configuration (base URL, API key)
|
|
238
|
+
- Sync interface only (async in Phase 3)
|
|
239
|
+
"""
|
|
240
|
+
|
|
241
|
+
def __init__(self, cfg: GdmConfig) -> None:
|
|
242
|
+
"""Initialise the client using provider config from GdmConfig.
|
|
243
|
+
|
|
244
|
+
Raises:
|
|
245
|
+
ConfigError: if cfg.provider is not a recognised provider.
|
|
246
|
+
"""
|
|
247
|
+
if cfg.provider not in PROVIDER_BASE_URLS:
|
|
248
|
+
raise ConfigError(
|
|
249
|
+
f"Unknown provider '{cfg.provider}'. "
|
|
250
|
+
f"Valid providers: {list(PROVIDER_BASE_URLS)}"
|
|
251
|
+
)
|
|
252
|
+
self._client = openai.OpenAI(
|
|
253
|
+
api_key=cfg.api_key,
|
|
254
|
+
base_url=PROVIDER_BASE_URLS[cfg.provider],
|
|
255
|
+
)
|
|
256
|
+
self._cfg = cfg
|
|
257
|
+
|
|
258
|
+
@classmethod
|
|
259
|
+
def for_provider(cls, provider: str, cfg: GdmConfig) -> "GdmClient":
|
|
260
|
+
"""Create a GdmClient configured for *provider* using keys from *cfg*.
|
|
261
|
+
|
|
262
|
+
Used to switch providers at runtime (e.g. API fallback) without
|
|
263
|
+
reloading config. The returned client has a modified cfg snapshot
|
|
264
|
+
reflecting the fallback provider/key.
|
|
265
|
+
|
|
266
|
+
Raises:
|
|
267
|
+
ConfigError: if no API key is available for *provider*, or if
|
|
268
|
+
*provider* is not a recognised provider.
|
|
269
|
+
"""
|
|
270
|
+
if provider not in PROVIDER_BASE_URLS:
|
|
271
|
+
raise ConfigError(
|
|
272
|
+
f"Unknown provider '{provider}'. "
|
|
273
|
+
f"Valid providers: {list(PROVIDER_BASE_URLS)}"
|
|
274
|
+
)
|
|
275
|
+
key_map: dict[str, str | None] = {
|
|
276
|
+
Provider.GROK: cfg.xai_api_key,
|
|
277
|
+
Provider.GEMINI: cfg.gemini_api_key,
|
|
278
|
+
Provider.CODEX: cfg.openai_api_key,
|
|
279
|
+
}
|
|
280
|
+
api_key = key_map.get(provider)
|
|
281
|
+
if not api_key:
|
|
282
|
+
raise ConfigError(
|
|
283
|
+
f"No API key configured for fallback provider '{provider}'. "
|
|
284
|
+
f"Set XAI_API_KEY / GEMINI_API_KEY / OPENAI_API_KEY as appropriate."
|
|
285
|
+
)
|
|
286
|
+
instance = cls.__new__(cls)
|
|
287
|
+
instance._client = openai.OpenAI(
|
|
288
|
+
api_key=api_key,
|
|
289
|
+
base_url=PROVIDER_BASE_URLS[provider],
|
|
290
|
+
)
|
|
291
|
+
instance._cfg = _dc_replace(cfg, provider=provider, api_key=api_key)
|
|
292
|
+
return instance
|
|
293
|
+
|
|
294
|
+
@classmethod
|
|
295
|
+
def for_proxy(cls, proxy_url: str, proxy_token: str) -> "GdmClient":
|
|
296
|
+
"""Route all LLM calls through a proxy server.
|
|
297
|
+
|
|
298
|
+
Used when direct API access is geo-restricted (e.g. Grok unavailable
|
|
299
|
+
in DRC). The proxy must expose an OpenAI-compatible /v1 endpoint.
|
|
300
|
+
``proxy_token`` is sent as the Bearer auth header; HTTPS is strongly
|
|
301
|
+
recommended so the token is protected in transit.
|
|
302
|
+
"""
|
|
303
|
+
if not proxy_url.startswith("https://") and not proxy_url.startswith("http://localhost"):
|
|
304
|
+
log.warning(
|
|
305
|
+
"Proxy URL '%s' does not use HTTPS — token will be sent unencrypted",
|
|
306
|
+
proxy_url,
|
|
307
|
+
)
|
|
308
|
+
instance = cls.__new__(cls)
|
|
309
|
+
instance._client = openai.OpenAI(
|
|
310
|
+
api_key=proxy_token,
|
|
311
|
+
base_url=proxy_url.rstrip("/"),
|
|
312
|
+
)
|
|
313
|
+
# _cfg = None is safe: _dispatch() reads it via getattr(self._cfg, ..., None)
|
|
314
|
+
instance._cfg = None
|
|
315
|
+
return instance
|
|
316
|
+
|
|
317
|
+
def _call_with_retry(self, fn: Callable[[], Any]) -> Any:
|
|
318
|
+
"""Execute fn with exponential backoff retry on transient API errors.
|
|
319
|
+
|
|
320
|
+
Retries up to _MAX_RETRIES times on rate-limit, server, connection,
|
|
321
|
+
and timeout errors. Non-transient APIErrors are re-raised immediately
|
|
322
|
+
as ApiError.
|
|
323
|
+
"""
|
|
324
|
+
last_exc: BaseException = RuntimeError("unreachable sentinel")
|
|
325
|
+
for attempt in range(_MAX_RETRIES + 1):
|
|
326
|
+
try:
|
|
327
|
+
return fn()
|
|
328
|
+
except _RETRYABLE_ERRORS as exc:
|
|
329
|
+
last_exc = exc
|
|
330
|
+
if attempt < _MAX_RETRIES:
|
|
331
|
+
wait = _BACKOFF_BASE ** attempt
|
|
332
|
+
log.warning(
|
|
333
|
+
"Retryable API error (%s), backing off %.0fs (retry %d/%d)",
|
|
334
|
+
type(exc).__name__,
|
|
335
|
+
wait,
|
|
336
|
+
attempt + 1,
|
|
337
|
+
_MAX_RETRIES,
|
|
338
|
+
)
|
|
339
|
+
time.sleep(wait)
|
|
340
|
+
except openai.APIError as exc:
|
|
341
|
+
raise ApiError(
|
|
342
|
+
str(exc), status_code=getattr(exc, "status_code", None)
|
|
343
|
+
) from exc
|
|
344
|
+
status = getattr(last_exc, "status_code", None)
|
|
345
|
+
if isinstance(last_exc, openai.RateLimitError):
|
|
346
|
+
raise ApiRateLimitError(
|
|
347
|
+
f"Rate limit exceeded after {_MAX_RETRIES} retries: {last_exc}",
|
|
348
|
+
status_code=status,
|
|
349
|
+
) from last_exc
|
|
350
|
+
raise ApiError(
|
|
351
|
+
f"API call failed after {_MAX_RETRIES} retries: {last_exc}",
|
|
352
|
+
status_code=status,
|
|
353
|
+
) from last_exc
|
|
354
|
+
|
|
355
|
+
def complete(
|
|
356
|
+
self,
|
|
357
|
+
messages: list[dict[str, Any]],
|
|
358
|
+
*,
|
|
359
|
+
model: str,
|
|
360
|
+
tools: list[dict[str, Any]] | None = None,
|
|
361
|
+
tool_choice: str = "auto",
|
|
362
|
+
max_tokens: int | None = None,
|
|
363
|
+
temperature: float | None = None,
|
|
364
|
+
) -> Any:
|
|
365
|
+
"""Make a non-structured completion request. Used by the agent loop.
|
|
366
|
+
|
|
367
|
+
Returns an openai.types.chat.ChatCompletion object.
|
|
368
|
+
"""
|
|
369
|
+
_tools = tools if tools is not None else openai.NOT_GIVEN
|
|
370
|
+
_tool_choice: Any = tool_choice if tools is not None else openai.NOT_GIVEN
|
|
371
|
+
_max_tokens: Any = max_tokens if max_tokens is not None else openai.NOT_GIVEN
|
|
372
|
+
_temperature: Any = temperature if temperature is not None else openai.NOT_GIVEN
|
|
373
|
+
|
|
374
|
+
def _call() -> Any:
|
|
375
|
+
return self._client.chat.completions.create(
|
|
376
|
+
model=model,
|
|
377
|
+
messages=messages, # type: ignore[arg-type] # SDK type is narrower
|
|
378
|
+
tools=_tools,
|
|
379
|
+
tool_choice=_tool_choice,
|
|
380
|
+
max_tokens=_max_tokens,
|
|
381
|
+
temperature=_temperature,
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
return self._call_with_retry(_call)
|
|
385
|
+
|
|
386
|
+
def parse(
|
|
387
|
+
self,
|
|
388
|
+
messages: list[dict[str, Any]],
|
|
389
|
+
*,
|
|
390
|
+
model: str,
|
|
391
|
+
response_format: type[BaseModelT],
|
|
392
|
+
max_tokens: int | None = None,
|
|
393
|
+
) -> BaseModelT:
|
|
394
|
+
"""Make a structured output request via client.beta.chat.completions.parse().
|
|
395
|
+
|
|
396
|
+
Returns the parsed Pydantic model instance.
|
|
397
|
+
Retries up to _PARSE_MAX_RETRIES times on schema parse failure with an
|
|
398
|
+
explicit correction prompt.
|
|
399
|
+
|
|
400
|
+
Raises:
|
|
401
|
+
ApiError: on API failure, refusal, or content filter.
|
|
402
|
+
SchemaError: on repeated schema parse failure or response truncation.
|
|
403
|
+
"""
|
|
404
|
+
msgs: list[dict[str, Any]] = list(messages)
|
|
405
|
+
_max_tokens: Any = max_tokens if max_tokens is not None else openai.NOT_GIVEN
|
|
406
|
+
|
|
407
|
+
for attempt in range(_PARSE_MAX_RETRIES + 1):
|
|
408
|
+
def _call(m: list[dict[str, Any]] = msgs) -> Any:
|
|
409
|
+
return self._client.beta.chat.completions.parse(
|
|
410
|
+
model=model,
|
|
411
|
+
messages=m, # type: ignore[arg-type]
|
|
412
|
+
response_format=response_format,
|
|
413
|
+
max_tokens=_max_tokens,
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
result = self._call_with_retry(_call)
|
|
417
|
+
if not result.choices:
|
|
418
|
+
raise ApiError("API returned no completion choices")
|
|
419
|
+
|
|
420
|
+
choice = result.choices[0]
|
|
421
|
+
refusal = getattr(choice.message, "refusal", None)
|
|
422
|
+
if refusal:
|
|
423
|
+
raise ApiError(f"Model refused request: {refusal}")
|
|
424
|
+
if choice.finish_reason == _FINISH_REASON_TRUNCATED:
|
|
425
|
+
raise SchemaError(
|
|
426
|
+
f"Response truncated (max_tokens too low) for {response_format.__name__}"
|
|
427
|
+
)
|
|
428
|
+
if choice.finish_reason == _FINISH_REASON_FILTERED:
|
|
429
|
+
raise ApiError("Response blocked by content policy")
|
|
430
|
+
if choice.message.parsed is not None:
|
|
431
|
+
return choice.message.parsed # type: ignore[return-value]
|
|
432
|
+
|
|
433
|
+
if attempt < _PARSE_MAX_RETRIES:
|
|
434
|
+
log.warning(
|
|
435
|
+
"Schema parse failed (attempt %d/%d), retrying with correction",
|
|
436
|
+
attempt + 1,
|
|
437
|
+
_PARSE_MAX_RETRIES + 1,
|
|
438
|
+
)
|
|
439
|
+
msgs = msgs + [
|
|
440
|
+
{"role": "assistant", "content": choice.message.content or ""},
|
|
441
|
+
{"role": "user", "content": _SCHEMA_CORRECTION_PROMPT},
|
|
442
|
+
]
|
|
443
|
+
|
|
444
|
+
raise SchemaError(
|
|
445
|
+
f"Failed to parse {response_format.__name__} "
|
|
446
|
+
f"after {_PARSE_MAX_RETRIES + 1} attempts"
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
def _dispatch(
|
|
450
|
+
self,
|
|
451
|
+
messages: list[dict[str, Any]],
|
|
452
|
+
*,
|
|
453
|
+
model: str,
|
|
454
|
+
tools: list[dict[str, Any]] | None = None,
|
|
455
|
+
**kwargs: Any,
|
|
456
|
+
) -> Any:
|
|
457
|
+
"""Route request to a local provider if configured, else return None.
|
|
458
|
+
|
|
459
|
+
Returns the provider response when a local provider handles the
|
|
460
|
+
request, or None to signal the caller should use the cloud path.
|
|
461
|
+
|
|
462
|
+
1. If a local provider is configured and reachable, use it.
|
|
463
|
+
2. If local_only=True and local is unreachable -> raise ConfigError.
|
|
464
|
+
3. Otherwise fall back to cloud with a WARNING log and return None.
|
|
465
|
+
|
|
466
|
+
Raises:
|
|
467
|
+
ConfigError: if local_only=True and local provider is unreachable.
|
|
468
|
+
ConfigError: if local provider rejects the request (e.g. no tools).
|
|
469
|
+
"""
|
|
470
|
+
local_providers = getattr(self._cfg, "local_providers", None) or []
|
|
471
|
+
if not local_providers:
|
|
472
|
+
return None
|
|
473
|
+
|
|
474
|
+
local_cfg: LocalProviderConfig = local_providers[0]
|
|
475
|
+
|
|
476
|
+
try:
|
|
477
|
+
if local_cfg.provider == "ollama":
|
|
478
|
+
provider = OllamaProvider(local_cfg)
|
|
479
|
+
return provider.complete(messages, tools=tools)
|
|
480
|
+
elif local_cfg.provider == "vllm":
|
|
481
|
+
vllm_provider = VLLMProvider(local_cfg)
|
|
482
|
+
client = vllm_provider.get_client()
|
|
483
|
+
_tools: Any = tools if tools is not None else openai.NOT_GIVEN
|
|
484
|
+
return client.chat.completions.create(
|
|
485
|
+
model=model,
|
|
486
|
+
messages=messages, # type: ignore[arg-type]
|
|
487
|
+
tools=_tools,
|
|
488
|
+
)
|
|
489
|
+
except ConfigError:
|
|
490
|
+
raise # Re-raise config errors (e.g. tool support rejection)
|
|
491
|
+
except Exception as exc:
|
|
492
|
+
if local_cfg.local_only:
|
|
493
|
+
raise ConfigError(
|
|
494
|
+
"Local provider unreachable and local_only=true"
|
|
495
|
+
" -- refusing cloud fallback"
|
|
496
|
+
) from exc
|
|
497
|
+
log.warning(
|
|
498
|
+
"Local provider unreachable (%s), falling back to cloud: %s",
|
|
499
|
+
local_cfg.provider,
|
|
500
|
+
exc,
|
|
501
|
+
)
|
|
502
|
+
return None
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def complete_text(
|
|
506
|
+
self,
|
|
507
|
+
prompt: str,
|
|
508
|
+
*,
|
|
509
|
+
model: str,
|
|
510
|
+
max_tokens: int = 200,
|
|
511
|
+
system: str | None = None,
|
|
512
|
+
) -> str:
|
|
513
|
+
"""Convenience wrapper for simple text completion. Returns response text.
|
|
514
|
+
|
|
515
|
+
Raises:
|
|
516
|
+
ApiError: if the API returns no choices.
|
|
517
|
+
"""
|
|
518
|
+
messages: list[dict[str, Any]] = []
|
|
519
|
+
if system:
|
|
520
|
+
messages.append({"role": "system", "content": system})
|
|
521
|
+
messages.append({"role": "user", "content": prompt})
|
|
522
|
+
result = self.complete(messages, model=model, max_tokens=max_tokens)
|
|
523
|
+
if not result.choices:
|
|
524
|
+
raise ApiError("API returned no completion choices")
|
|
525
|
+
return result.choices[0].message.content or ""
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
# ---------------------------------------------------------------------------
|
|
529
|
+
# Batch API
|
|
530
|
+
# ---------------------------------------------------------------------------
|
|
531
|
+
|
|
532
|
+
_BATCH_DISCOUNT_RATE: dict[str, float] = {"gemini": 0.50, "codex": 0.50}
|
|
533
|
+
_BATCH_MIN_SAVINGS_USD: float = 0.01
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def should_use_batch(requests: list[dict], provider: str) -> bool:
    """Decide whether submitting *requests* as one batch is worthwhile.

    Grok (no batch API) and sets of fewer than five requests never qualify.
    Otherwise each request's token count is estimated as len(json) // 4, the
    total (in millions) is scaled by the provider's entry in
    ``_BATCH_DISCOUNT_RATE``, and the result must exceed
    ``_BATCH_MIN_SAVINGS_USD``.

    NOTE(review): the "savings" figure does not include a per-token price,
    so it acts as a size threshold rather than a real dollar amount.
    """
    if provider == "grok" or len(requests) < 5:
        return False
    token_estimate = 0
    for request in requests:
        token_estimate += len(json.dumps(request)) // 4
    discount = _BATCH_DISCOUNT_RATE.get(provider, 0)
    return token_estimate / 1_000_000 * discount > _BATCH_MIN_SAVINGS_USD
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
class OpenAIBatchBackend:
    """OpenAI Batch API backend — uploads JSONL, creates batch, polls, retrieves.

    NOTE(review): each request dict is written verbatim as one JSONL line. The
    Batch API expects lines of the form {"custom_id", "method", "url", "body"};
    confirm callers already build that shape.
    """

    def __init__(self, client: Any) -> None:
        """Keep *client* (an ``openai.OpenAI``-like object) for all calls."""
        self._client = client

    def submit(self, requests: list[dict]) -> str:
        """Upload *requests* as JSONL, create a 24h batch, and return its id."""
        jsonl = "\n".join(json.dumps(r) for r in requests)
        uploaded = self._client.files.create(
            file=("batch.jsonl", io.BytesIO(jsonl.encode())), purpose="batch"
        )
        batch = self._client.batches.create(
            input_file_id=uploaded.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
        return batch.id

    def poll(self, batch_id: str) -> str:
        """Return the provider's raw status string for *batch_id*."""
        return self._client.batches.retrieve(batch_id).status

    def retrieve(self, batch_id: str) -> list[dict]:
        """Download and decode the batch's output file, one dict per JSONL line.

        Returns [] when the batch has no output file (e.g. it failed, expired,
        or is not finished), instead of requesting a file with id None.
        """
        batch = self._client.batches.retrieve(batch_id)
        output_file_id = getattr(batch, "output_file_id", None)
        if not output_file_id:
            return []
        content = self._client.files.content(output_file_id).text
        return [json.loads(line) for line in content.splitlines() if line.strip()]
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class GeminiBatchBackend:
    """Gemini batchGenerateContent backend.

    The API key is sent in the ``x-goog-api-key`` header rather than the URL
    query string, so it does not end up in request URLs, proxy logs, or
    error messages that echo the URL.

    NOTE(review): the endpoint paths (``v1beta/models:batchGenerateContent``
    with no model segment, and ``v1beta/{batch}/responses``) are kept as they
    were; verify them against the current Gemini Batch API reference.
    """

    def __init__(self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com") -> None:
        self._api_key = api_key
        self._base_url = base_url

    def _request(self, url: str, payload: dict | None = None) -> dict:
        """GET *url* (no payload) or POST *payload* as JSON; return the decoded reply."""
        import urllib.request as _ur

        headers = {"x-goog-api-key": self._api_key}
        data = None
        method = "GET"
        if payload is not None:
            data = json.dumps(payload).encode()
            headers["Content-Type"] = "application/json"
            method = "POST"
        req = _ur.Request(url, data=data, headers=headers, method=method)
        with _ur.urlopen(req, timeout=30) as r:
            return json.loads(r.read())

    @staticmethod
    def _status_from_operation(body: dict) -> str:
        """Map a long-running-operation body to "in_progress"/"completed"/"failed".

        An operation that is ``done`` but carries an ``error`` is a failure.
        """
        if not body.get("done"):
            return "in_progress"
        return "failed" if body.get("error") else "completed"

    def submit(self, requests: list[dict]) -> str:
        """Submit *requests* and return the batch operation name.

        Raises:
            ApiError: if the response carries no batch name. A made-up id could
                never be polled, so failing here is clearer.
        """
        body = self._request(
            f"{self._base_url}/v1beta/models:batchGenerateContent",
            {"requests": requests},
        )
        name = body.get("name")
        if not name:
            raise ApiError("Gemini batch submission returned no batch name")
        return name

    def poll(self, batch_id: str) -> str:
        """Return "in_progress", "completed" or "failed" for *batch_id*."""
        return self._status_from_operation(self._request(f"{self._base_url}/v1beta/{batch_id}"))

    def retrieve(self, batch_id: str) -> list[dict]:
        """Return the ``responses`` list for *batch_id* ([] if absent)."""
        body = self._request(f"{self._base_url}/v1beta/{batch_id}/responses")
        return body.get("responses", [])
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class GrokSequentialBackend:
    """Grok has no batch API — sequential calls with jittered exponential backoff.

    ``submit`` runs every request immediately and keeps the results in memory
    under a synthetic batch id, so ``poll`` always reports "completed". Stored
    results are never evicted.
    """

    def __init__(self, client: Any, max_attempts: int = 5) -> None:
        """Store *client* and the per-request attempt budget (at least 1)."""
        self._client = client
        self._results: dict[str, list[dict]] = {}
        # With zero attempts a request would be skipped without any result.
        self._max_attempts = max(1, max_attempts)

    @staticmethod
    def _is_rate_limited(exc: Exception) -> bool:
        """Heuristic: HTTP 429 status, or "rate" appearing in the error text."""
        return getattr(exc, "status_code", None) == 429 or "rate" in str(exc).lower()

    def submit(self, requests: list[dict]) -> str:
        """Run each request in order and return a batch id for ``retrieve``.

        Each request dict is passed as keyword arguments to
        ``client.chat.completions.create``. Rate-limited calls are retried with
        a ``2**attempt + U(0, 1)`` second backoff.

        Raises:
            Exception: the original error for any non-rate-limit failure, or
                the last rate-limit error once attempts are exhausted, so
                results always line up one-to-one with *requests*.
        """
        results: list[dict] = []
        for i, req in enumerate(requests):
            for attempt in range(self._max_attempts):
                try:
                    result = self._client.chat.completions.create(**req)
                except Exception as e:
                    final_attempt = attempt == self._max_attempts - 1
                    if final_attempt or not self._is_rate_limited(e):
                        raise
                    wait = (2 ** attempt) + random.uniform(0, 1)
                    log.warning("Rate limit on request %d, backing off %.1fs", i, wait)
                    time.sleep(wait)
                else:
                    results.append(
                        result.model_dump() if hasattr(result, "model_dump") else dict(result)
                    )
                    break
        batch_id = f"grok-seq-{uuid.uuid4().hex[:8]}"
        self._results[batch_id] = results
        return batch_id

    def poll(self, batch_id: str) -> str:
        """Always "completed": ``submit`` finishes all work synchronously."""
        return "completed"

    def retrieve(self, batch_id: str) -> list[dict]:
        """Return the stored results for *batch_id* ([] if unknown)."""
        return self._results.get(batch_id, [])
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
class BatchClient:
    """Provider-agnostic batch submission client with optional DB persistence.

    ``submit`` returns an internal job UUID. The job -> backend batch id mapping
    is kept in memory for this instance and, when ``_db_conn`` is set, also in
    the ``batch_jobs`` table. ``poll``/``retrieve`` therefore work with or
    without a database; the database copy is the one meant for crash recovery.
    """

    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(self, provider: str, backend: Any) -> None:
        self.provider = provider
        self._backend = backend
        self._db_conn: Any = None  # set externally if crash-recovery needed
        # job_id -> backend batch_id for jobs submitted by this instance.
        self._jobs: dict[str, str] = {}

    @staticmethod
    def _utcnow_iso() -> str:
        """Naive-UTC ISO-8601 timestamp (same format datetime.utcnow() produced)."""
        import datetime

        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()

    def _lookup(self, job_id: str) -> tuple[str | None, str | None]:
        """Return (batch_id, stored_status) for *job_id*; (None, None) if unknown.

        The DB row wins when present; otherwise the in-memory map is used and
        no stored status is known.
        """
        if self._db_conn is not None:
            row = self._db_conn.execute(
                "SELECT batch_id, status FROM batch_jobs WHERE id=?", (job_id,)
            ).fetchone()
            if row:
                return row[0], row[1]
        return self._jobs.get(job_id), None

    def submit(self, requests: list[dict], job_type: str = "index") -> str:
        """Redact secrets, submit batch, persist to DB. Returns job UUID."""
        from src.security import redact

        safe = [json.loads(redact(json.dumps(r))) for r in requests]
        log.info("Submitting batch of %d requests (%s/%s)", len(safe), self.provider, job_type)
        batch_id = self._backend.submit(safe)
        job_id = str(uuid.uuid4())
        self._jobs[job_id] = batch_id
        if self._db_conn is not None:
            now = self._utcnow_iso()
            self._db_conn.execute(
                "INSERT INTO batch_jobs "
                "(id, provider, batch_id, job_type, status, request_count, created_at, submitted_at) "
                "VALUES (?, ?, ?, ?, 'submitted', ?, ?, ?)",
                (job_id, self.provider, batch_id, job_type, len(safe), now, now),
            )
            self._db_conn.commit()
        log.info("Batch submitted: job_id=%s batch_id=%s", job_id, batch_id)
        return job_id

    def poll(self, job_id: str) -> str:
        """Poll provider status, update DB, return internal status string.

        Returns "unknown" for an unrecognised job, the stored status if the DB
        already records a terminal one, otherwise the mapped provider status.
        """
        batch_id, current_status = self._lookup(job_id)
        if batch_id is None:
            return "unknown"
        if current_status in self._TERMINAL_STATUSES:
            return current_status
        status = _map_status(self._backend.poll(batch_id), self.provider)
        if self._db_conn is not None:
            if status in ("completed", "failed"):
                self._db_conn.execute(
                    "UPDATE batch_jobs SET status=?, completed_at=? WHERE id=?",
                    (status, self._utcnow_iso(), job_id),
                )
            else:
                self._db_conn.execute(
                    "UPDATE batch_jobs SET status=? WHERE id=?", (status, job_id)
                )
            self._db_conn.commit()
        return status

    def retrieve(self, job_id: str) -> list[dict]:
        """Retrieve completed batch results ([] for an unknown job)."""
        batch_id, _ = self._lookup(job_id)
        if batch_id is None:
            return []
        return self._backend.retrieve(batch_id)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def _map_status(provider_status: str, provider: str) -> str:
|
|
704
|
+
"""Normalise provider-specific status strings to internal status."""
|
|
705
|
+
_completed = {"completed", "succeeded", "done"}
|
|
706
|
+
_failed = {"failed", "error", "expired"}
|
|
707
|
+
_cancelled = {"cancelled", "canceled"}
|
|
708
|
+
if provider_status in _completed:
|
|
709
|
+
return "completed"
|
|
710
|
+
if provider_status in _failed:
|
|
711
|
+
return "failed"
|
|
712
|
+
if provider_status in _cancelled:
|
|
713
|
+
return "cancelled"
|
|
714
|
+
return "polling"
|
|
715
|
+
|