gdmcode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. gdmcode-0.1.0.dist-info/METADATA +240 -0
  2. gdmcode-0.1.0.dist-info/RECORD +131 -0
  3. gdmcode-0.1.0.dist-info/WHEEL +4 -0
  4. gdmcode-0.1.0.dist-info/entry_points.txt +2 -0
  5. src/__init__.py +1 -0
  6. src/_internal/__init__.py +0 -0
  7. src/_internal/constants.py +244 -0
  8. src/_internal/domain_skills.py +339 -0
  9. src/agent/__init__.py +0 -0
  10. src/agent/commit_classifier.py +91 -0
  11. src/agent/context_budget.py +391 -0
  12. src/agent/daemon.py +681 -0
  13. src/agent/dag_validator.py +153 -0
  14. src/agent/debug_loop.py +473 -0
  15. src/agent/impact_analyzer.py +149 -0
  16. src/agent/impact_graph.py +117 -0
  17. src/agent/loop.py +1410 -0
  18. src/agent/orchestrator.py +141 -0
  19. src/agent/regression_guard.py +251 -0
  20. src/agent/review_gate.py +648 -0
  21. src/agent/risk_scorer.py +169 -0
  22. src/agent/self_healing.py +145 -0
  23. src/agent/smart_test_selector.py +89 -0
  24. src/agent/system_prompt.py +226 -0
  25. src/agent/task_tracker.py +320 -0
  26. src/agent/test_validator.py +210 -0
  27. src/agent/tool_orchestrator.py +402 -0
  28. src/agent/transcript.py +230 -0
  29. src/agent/verification_loop.py +133 -0
  30. src/agent/work_director.py +136 -0
  31. src/agent/worktree_manager.py +53 -0
  32. src/artifacts/__init__.py +16 -0
  33. src/artifacts/artifact_store.py +456 -0
  34. src/artifacts/verification_graph.py +75 -0
  35. src/auth.py +411 -0
  36. src/cli.py +1290 -0
  37. src/commands.py +1398 -0
  38. src/config.py +762 -0
  39. src/cost_tracker.py +348 -0
  40. src/db/__init__.py +4 -0
  41. src/db/migrations.py +337 -0
  42. src/enterprise/__init__.py +3 -0
  43. src/enterprise/audit_log.py +182 -0
  44. src/enterprise/identity.py +90 -0
  45. src/enterprise/rbac.py +100 -0
  46. src/enterprise/team_config.py +125 -0
  47. src/enterprise/usage_analytics.py +261 -0
  48. src/exceptions.py +207 -0
  49. src/git_workflow.py +651 -0
  50. src/integrations/__init__.py +6 -0
  51. src/integrations/github_actions.py +106 -0
  52. src/integrations/mcp_server.py +333 -0
  53. src/integrations/sentry_integration.py +100 -0
  54. src/integrations/sentry_server.py +82 -0
  55. src/integrations/webhook_security.py +19 -0
  56. src/main.py +27 -0
  57. src/memory/__init__.py +0 -0
  58. src/memory/code_index.py +376 -0
  59. src/memory/compressor.py +378 -0
  60. src/memory/context_memory.py +135 -0
  61. src/memory/continuous_memory.py +234 -0
  62. src/memory/conventions.py +495 -0
  63. src/memory/db.py +1119 -0
  64. src/memory/document_index.py +205 -0
  65. src/memory/file_cache.py +128 -0
  66. src/memory/project_scanner.py +178 -0
  67. src/memory/session_store.py +201 -0
  68. src/models/__init__.py +0 -0
  69. src/models/client.py +715 -0
  70. src/models/definitions.py +459 -0
  71. src/models/router.py +418 -0
  72. src/models/schemas.py +389 -0
  73. src/permissions.py +294 -0
  74. src/remote/__init__.py +5 -0
  75. src/remote/command_filter.py +33 -0
  76. src/remote/models.py +31 -0
  77. src/remote/permission_handler.py +79 -0
  78. src/remote/phone_ui.py +48 -0
  79. src/remote/protocol.py +59 -0
  80. src/remote/qr.py +65 -0
  81. src/remote/server.py +586 -0
  82. src/remote/token_manager.py +61 -0
  83. src/remote/tunnel.py +212 -0
  84. src/repl.py +475 -0
  85. src/runtime/__init__.py +1 -0
  86. src/runtime/branch_farm.py +372 -0
  87. src/runtime/replay.py +351 -0
  88. src/sandbox/__init__.py +2 -0
  89. src/sandbox/hermetic.py +214 -0
  90. src/sandbox/policy.py +44 -0
  91. src/sdk/__init__.py +3 -0
  92. src/sdk/plugin_base.py +39 -0
  93. src/sdk/plugin_host.py +100 -0
  94. src/sdk/plugin_loader.py +101 -0
  95. src/security.py +409 -0
  96. src/server/__init__.py +7 -0
  97. src/server/bridge.py +427 -0
  98. src/server/bridge_cli.py +103 -0
  99. src/server/bridge_client.py +170 -0
  100. src/server/protocol_version.py +103 -0
  101. src/session/__init__.py +10 -0
  102. src/session/event_fanout.py +46 -0
  103. src/session/input_broker.py +38 -0
  104. src/session/permission_bridge.py +100 -0
  105. src/tools/__init__.py +160 -0
  106. src/tools/_atomic.py +72 -0
  107. src/tools/agent_tools.py +423 -0
  108. src/tools/ask_user_tool.py +83 -0
  109. src/tools/bash_tool.py +384 -0
  110. src/tools/browser_tool.py +352 -0
  111. src/tools/browser_tools.py +179 -0
  112. src/tools/dep_tools.py +210 -0
  113. src/tools/document_reader.py +167 -0
  114. src/tools/document_tool.py +240 -0
  115. src/tools/document_writer.py +171 -0
  116. src/tools/impact_tools.py +240 -0
  117. src/tools/playwright_tool.py +172 -0
  118. src/tools/quality_tools.py +366 -0
  119. src/tools/read_tools.py +318 -0
  120. src/tools/result_cache.py +157 -0
  121. src/tools/search_tools.py +310 -0
  122. src/tools/shell_tools.py +311 -0
  123. src/tools/write_tools.py +337 -0
  124. src/voice/__init__.py +25 -0
  125. src/voice/audio_capture.py +92 -0
  126. src/voice/audio_playback.py +68 -0
  127. src/voice/errors.py +14 -0
  128. src/voice/models.py +35 -0
  129. src/voice/providers.py +143 -0
  130. src/voice/vad.py +55 -0
  131. src/voice/voice_loop.py +156 -0
src/models/client.py ADDED
@@ -0,0 +1,715 @@
1
+ """GdmClient — structured-output-aware wrapper around openai.OpenAI.
2
+
3
+ Provides retry with exponential backoff, structured parsing via
4
+ client.beta.chat.completions.parse(), and provider-specific configuration.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import random
12
+ import time
13
+ import uuid
14
+ from collections.abc import Callable
15
+ from dataclasses import dataclass, field, replace as _dc_replace
16
+ from typing import Any, TypeVar
17
+
18
+ import openai
19
+ from pydantic import BaseModel
20
+
21
+ from src.config import GdmConfig
22
+ from src.exceptions import ApiError, ApiRateLimitError, ConfigError, SchemaError
23
+ from src.models.definitions import PROVIDER_BASE_URLS, Provider
24
+
25
+ __all__ = [
26
+ "GdmClient",
27
+ "BaseModelT",
28
+ "LocalProviderConfig",
29
+ "LocalModelCapabilities",
30
+ "OllamaProvider",
31
+ "VLLMProvider",
32
+ "_validate_local_base_url",
33
+ "should_use_batch",
34
+ "OpenAIBatchBackend",
35
+ "GeminiBatchBackend",
36
+ "GrokSequentialBackend",
37
+ "BatchClient",
38
+ "_map_status",
39
+ ]
40
+
41
+ log = logging.getLogger(__name__)
42
+
43
+ BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
44
+
45
+ _MAX_RETRIES: int = 3
46
+ _BACKOFF_BASE: float = 2.0
47
+ _PARSE_MAX_RETRIES: int = 2
48
+ _SCHEMA_CORRECTION_PROMPT: str = (
49
+ "Your previous response did not conform to the required JSON schema. "
50
+ "Please try again and ensure your response matches the schema exactly."
51
+ )
52
# openai exception classes that GdmClient._call_with_retry treats as transient
# and retries; any other openai.APIError is converted to ApiError immediately.
_RETRYABLE_ERRORS: tuple[type[Exception], ...] = (
    openai.RateLimitError,
    openai.InternalServerError,
    openai.APIConnectionError,
    openai.APITimeoutError,
)

# finish_reason values that GdmClient.parse turns into errors:
# "length" -> SchemaError (truncated), "content_filter" -> ApiError.
_FINISH_REASON_TRUNCATED: str = "length"
_FINISH_REASON_FILTERED: str = "content_filter"
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Local provider support
63
+ # ---------------------------------------------------------------------------
64
+
65
+ _ALLOWED_LOCAL_HOSTS: frozenset[str] = frozenset({"localhost", "127.0.0.1", "::1"})
66
+ _LOCAL_PROVIDERS: frozenset[str] = frozenset({"ollama", "vllm"})
67
+
68
+
69
@dataclass
class LocalProviderConfig:
    """Settings describing one locally hosted (non-cloud) model endpoint.

    The first three fields are required; field order defines the positional
    constructor signature.
    """

    # Backend kind; this module branches on "ollama" and "vllm".
    provider: str
    # Root URL of the local server (validated by _validate_local_base_url).
    base_url: str
    # Model identifier sent to the backend.
    model: str
    # When True, GdmClient._dispatch raises instead of falling back to cloud.
    local_only: bool = True
    # Not read inside this module — presumably consumed by callers; verify.
    allow_cloud_fallback: bool = False
    # When True, a non-loopback base_url is accepted.
    allow_external: bool = False
79
+
80
+
81
@dataclass
class LocalModelCapabilities:
    """Feature flags and limits reported for a local model.

    The defaults are the conservative values used when a capability probe
    cannot be completed.
    """

    # Whether the model accepts tool/function definitions.
    supports_tools: bool = False
    # Whether the model offers a JSON-constrained output mode.
    supports_json_mode: bool = False
    # Context window size, in tokens.
    context_window: int = 4096
    # Whether streamed responses are available.
    streaming: bool = True
89
+
90
+
91
+ def _validate_local_base_url(url: str, allow_external: bool = False) -> None:
92
+ """Raise ConfigError if url is non-local and allow_external is False."""
93
+ from urllib.parse import urlparse
94
+ parsed = urlparse(url)
95
+ if not allow_external and parsed.hostname not in _ALLOWED_LOCAL_HOSTS:
96
+ raise ConfigError(
97
+ f"Local provider base_url '{url}' resolves to non-local host "
98
+ f"'{parsed.hostname}'. Set allow_external=true to override."
99
+ )
100
+
101
+
102
class OllamaProvider:
    """Native Ollama /api/chat client (NOT OpenAI-compatible).

    Talks to the Ollama HTTP API with the standard library only and adapts the
    reply to an OpenAI ChatCompletion-shaped ``dict`` so callers can treat both
    backends alike. Capabilities are probed lazily and cached per instance, so
    reuse one instance for repeated calls against the same model.
    """

    def __init__(self, config: LocalProviderConfig) -> None:
        """Validate ``config.base_url`` and store the configuration.

        Raises:
            ConfigError: if the base URL is non-local and external hosts are
                not allowed by the config.
        """
        _validate_local_base_url(config.base_url, config.allow_external)
        self.config = config
        self._capabilities: LocalModelCapabilities | None = None
        import urllib.request

        self._http = urllib.request  # stdlib only — no requests dep

    def _post_json(self, path: str, payload: dict[str, Any], *, timeout: float) -> dict:
        """POST *payload* as JSON to ``base_url + path`` and decode the JSON reply."""
        import json

        req = self._http.Request(
            f"{self.config.base_url.rstrip('/')}{path}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with self._http.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read())

    def complete(self, messages: list[dict], *, tools: list | None = None, **kw: Any) -> dict:
        """POST to /api/chat and return adapted OpenAI-shaped dict.

        Args:
            messages: Chat messages, forwarded unchanged.
            tools: Optional tool definitions; only sent when not None.
            **kw: Accepted for signature compatibility and ignored.

        Raises:
            ConfigError: if *tools* is given but the probed model does not
                advertise tool support.
        """
        if tools is not None:
            caps = self.probe_capabilities()
            if not caps.supports_tools:
                raise ConfigError(
                    f"Model '{self.config.model}' does not support tool calls "
                    "(supports_tools=False). Remove tools or use a capable model."
                )

        payload: dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "stream": False,
        }
        if tools is not None:
            payload["tools"] = tools

        raw = self._post_json("/api/chat", payload, timeout=60)
        return self.to_openai_response(raw)

    def to_openai_response(self, raw: dict) -> dict:
        """Adapt Ollama /api/chat response to OpenAI ChatCompletion shape.

        raw["message"]["content"]    -> choices[0].message.content
        raw["message"]["tool_calls"] -> choices[0].message.tool_calls (if present)
        raw["done_reason"] == "length" -> choices[0].finish_reason == "length"
        """
        # ``or {}`` also covers an explicit ``"message": null``.
        message = raw.get("message") or {}
        content = message.get("content", "")
        tool_calls = message.get("tool_calls")

        msg: dict[str, Any] = {"role": "assistant", "content": content}
        if tool_calls is not None:
            msg["tool_calls"] = tool_calls

        # Surface truncation so downstream finish_reason checks can see it;
        # every other done_reason keeps the previous "stop" value.
        finish_reason = "length" if raw.get("done_reason") == "length" else "stop"

        prompt_tokens = raw.get("prompt_eval_count", 0) or 0
        completion_tokens = raw.get("eval_count", 0) or 0

        return {
            "id": f"ollama-{raw.get('created_at', 'response')}",
            "object": "chat.completion",
            "model": raw.get("model", self.config.model),
            "choices": [
                {
                    "index": 0,
                    "message": msg,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    @staticmethod
    def _capability_fields(show_data: dict) -> dict[str, Any]:
        """Translate an /api/show reply into LocalModelCapabilities keyword args.

        Reads the documented ``model_info`` / ``capabilities`` fields and keeps
        the legacy ``modelinfo`` / ``supports_tools`` / ``supports_json_mode``
        keys as fallbacks.
        """
        model_info = show_data.get("model_info") or show_data.get("modelinfo") or {}
        context_window = (
            model_info.get("llama.context_length")
            or model_info.get("context_length")
            # model_info keys are prefixed with the model architecture, so
            # accept any "<arch>.context_length" entry.
            or next(
                (
                    value
                    for key, value in model_info.items()
                    if str(key).endswith(".context_length") and value
                ),
                None,
            )
            or 4096
        )
        advertised = show_data.get("capabilities")
        if not isinstance(advertised, (list, tuple, set, frozenset)):
            advertised = ()
        return {
            "supports_tools": "tools" in advertised
            or bool(show_data.get("supports_tools", False)),
            "supports_json_mode": bool(show_data.get("supports_json_mode", False)),
            "context_window": int(context_window),
            "streaming": True,
        }

    def probe_capabilities(self) -> LocalModelCapabilities:
        """Lazy-probe /api/show for model capabilities. Cached per instance.

        Any failure (network, bad JSON, unexpected field types) is treated as
        best-effort: the conservative defaults are cached and returned.
        """
        if self._capabilities is not None:
            return self._capabilities

        try:
            # Send both the current ("model") and the older ("name") request key.
            show_data = self._post_json(
                "/api/show",
                {"model": self.config.model, "name": self.config.model},
                timeout=10,
            )
            self._capabilities = LocalModelCapabilities(
                **self._capability_fields(show_data)
            )
        except Exception as exc:
            log.debug(
                "Capability probe failed for model '%s', using defaults: %s",
                self.config.model,
                exc,
            )
            self._capabilities = LocalModelCapabilities()

        return self._capabilities
213
+
214
+
215
class VLLMProvider:
    """vLLM provider -- OpenAI-compatible, so it reuses ``openai.OpenAI``.

    Only the base URL differs from a cloud client; the API key is a
    placeholder string.
    """

    def __init__(self, config: LocalProviderConfig) -> None:
        """Validate the configured base URL, then keep the config.

        Raises:
            ConfigError: if the base URL is non-local and external hosts are
                not allowed by the config.
        """
        _validate_local_base_url(config.base_url, config.allow_external)
        self.config = config

    def get_client(self) -> "openai.OpenAI":
        """Build a fresh ``openai.OpenAI`` client pointed at the vLLM server."""
        target = self.config.base_url
        return openai.OpenAI(api_key="not-needed", base_url=target)
228
+
229
+
230
+
231
+ class GdmClient:
232
+ """OpenAI-compatible client with retry, structured output, and provider abstraction.
233
+
234
+ Wraps openai.OpenAI with:
235
+ - Structured output via client.beta.chat.completions.parse() with Pydantic models
236
+ - Exponential backoff retry (3 retries, 2x backoff) on rate limit / server errors
237
+ - Provider-specific configuration (base URL, API key)
238
+ - Sync interface only (async in Phase 3)
239
+ """
240
+
241
+ def __init__(self, cfg: GdmConfig) -> None:
242
+ """Initialise the client using provider config from GdmConfig.
243
+
244
+ Raises:
245
+ ConfigError: if cfg.provider is not a recognised provider.
246
+ """
247
+ if cfg.provider not in PROVIDER_BASE_URLS:
248
+ raise ConfigError(
249
+ f"Unknown provider '{cfg.provider}'. "
250
+ f"Valid providers: {list(PROVIDER_BASE_URLS)}"
251
+ )
252
+ self._client = openai.OpenAI(
253
+ api_key=cfg.api_key,
254
+ base_url=PROVIDER_BASE_URLS[cfg.provider],
255
+ )
256
+ self._cfg = cfg
257
+
258
+ @classmethod
259
+ def for_provider(cls, provider: str, cfg: GdmConfig) -> "GdmClient":
260
+ """Create a GdmClient configured for *provider* using keys from *cfg*.
261
+
262
+ Used to switch providers at runtime (e.g. API fallback) without
263
+ reloading config. The returned client has a modified cfg snapshot
264
+ reflecting the fallback provider/key.
265
+
266
+ Raises:
267
+ ConfigError: if no API key is available for *provider*, or if
268
+ *provider* is not a recognised provider.
269
+ """
270
+ if provider not in PROVIDER_BASE_URLS:
271
+ raise ConfigError(
272
+ f"Unknown provider '{provider}'. "
273
+ f"Valid providers: {list(PROVIDER_BASE_URLS)}"
274
+ )
275
+ key_map: dict[str, str | None] = {
276
+ Provider.GROK: cfg.xai_api_key,
277
+ Provider.GEMINI: cfg.gemini_api_key,
278
+ Provider.CODEX: cfg.openai_api_key,
279
+ }
280
+ api_key = key_map.get(provider)
281
+ if not api_key:
282
+ raise ConfigError(
283
+ f"No API key configured for fallback provider '{provider}'. "
284
+ f"Set XAI_API_KEY / GEMINI_API_KEY / OPENAI_API_KEY as appropriate."
285
+ )
286
+ instance = cls.__new__(cls)
287
+ instance._client = openai.OpenAI(
288
+ api_key=api_key,
289
+ base_url=PROVIDER_BASE_URLS[provider],
290
+ )
291
+ instance._cfg = _dc_replace(cfg, provider=provider, api_key=api_key)
292
+ return instance
293
+
294
+ @classmethod
295
+ def for_proxy(cls, proxy_url: str, proxy_token: str) -> "GdmClient":
296
+ """Route all LLM calls through a proxy server.
297
+
298
+ Used when direct API access is geo-restricted (e.g. Grok unavailable
299
+ in DRC). The proxy must expose an OpenAI-compatible /v1 endpoint.
300
+ ``proxy_token`` is sent as the Bearer auth header; HTTPS is strongly
301
+ recommended so the token is protected in transit.
302
+ """
303
+ if not proxy_url.startswith("https://") and not proxy_url.startswith("http://localhost"):
304
+ log.warning(
305
+ "Proxy URL '%s' does not use HTTPS — token will be sent unencrypted",
306
+ proxy_url,
307
+ )
308
+ instance = cls.__new__(cls)
309
+ instance._client = openai.OpenAI(
310
+ api_key=proxy_token,
311
+ base_url=proxy_url.rstrip("/"),
312
+ )
313
+ # _cfg = None is safe: _dispatch() reads it via getattr(self._cfg, ..., None)
314
+ instance._cfg = None
315
+ return instance
316
+
317
+ def _call_with_retry(self, fn: Callable[[], Any]) -> Any:
318
+ """Execute fn with exponential backoff retry on transient API errors.
319
+
320
+ Retries up to _MAX_RETRIES times on rate-limit, server, connection,
321
+ and timeout errors. Non-transient APIErrors are re-raised immediately
322
+ as ApiError.
323
+ """
324
+ last_exc: BaseException = RuntimeError("unreachable sentinel")
325
+ for attempt in range(_MAX_RETRIES + 1):
326
+ try:
327
+ return fn()
328
+ except _RETRYABLE_ERRORS as exc:
329
+ last_exc = exc
330
+ if attempt < _MAX_RETRIES:
331
+ wait = _BACKOFF_BASE ** attempt
332
+ log.warning(
333
+ "Retryable API error (%s), backing off %.0fs (retry %d/%d)",
334
+ type(exc).__name__,
335
+ wait,
336
+ attempt + 1,
337
+ _MAX_RETRIES,
338
+ )
339
+ time.sleep(wait)
340
+ except openai.APIError as exc:
341
+ raise ApiError(
342
+ str(exc), status_code=getattr(exc, "status_code", None)
343
+ ) from exc
344
+ status = getattr(last_exc, "status_code", None)
345
+ if isinstance(last_exc, openai.RateLimitError):
346
+ raise ApiRateLimitError(
347
+ f"Rate limit exceeded after {_MAX_RETRIES} retries: {last_exc}",
348
+ status_code=status,
349
+ ) from last_exc
350
+ raise ApiError(
351
+ f"API call failed after {_MAX_RETRIES} retries: {last_exc}",
352
+ status_code=status,
353
+ ) from last_exc
354
+
355
+ def complete(
356
+ self,
357
+ messages: list[dict[str, Any]],
358
+ *,
359
+ model: str,
360
+ tools: list[dict[str, Any]] | None = None,
361
+ tool_choice: str = "auto",
362
+ max_tokens: int | None = None,
363
+ temperature: float | None = None,
364
+ ) -> Any:
365
+ """Make a non-structured completion request. Used by the agent loop.
366
+
367
+ Returns an openai.types.chat.ChatCompletion object.
368
+ """
369
+ _tools = tools if tools is not None else openai.NOT_GIVEN
370
+ _tool_choice: Any = tool_choice if tools is not None else openai.NOT_GIVEN
371
+ _max_tokens: Any = max_tokens if max_tokens is not None else openai.NOT_GIVEN
372
+ _temperature: Any = temperature if temperature is not None else openai.NOT_GIVEN
373
+
374
+ def _call() -> Any:
375
+ return self._client.chat.completions.create(
376
+ model=model,
377
+ messages=messages, # type: ignore[arg-type] # SDK type is narrower
378
+ tools=_tools,
379
+ tool_choice=_tool_choice,
380
+ max_tokens=_max_tokens,
381
+ temperature=_temperature,
382
+ )
383
+
384
+ return self._call_with_retry(_call)
385
+
386
+ def parse(
387
+ self,
388
+ messages: list[dict[str, Any]],
389
+ *,
390
+ model: str,
391
+ response_format: type[BaseModelT],
392
+ max_tokens: int | None = None,
393
+ ) -> BaseModelT:
394
+ """Make a structured output request via client.beta.chat.completions.parse().
395
+
396
+ Returns the parsed Pydantic model instance.
397
+ Retries up to _PARSE_MAX_RETRIES times on schema parse failure with an
398
+ explicit correction prompt.
399
+
400
+ Raises:
401
+ ApiError: on API failure, refusal, or content filter.
402
+ SchemaError: on repeated schema parse failure or response truncation.
403
+ """
404
+ msgs: list[dict[str, Any]] = list(messages)
405
+ _max_tokens: Any = max_tokens if max_tokens is not None else openai.NOT_GIVEN
406
+
407
+ for attempt in range(_PARSE_MAX_RETRIES + 1):
408
+ def _call(m: list[dict[str, Any]] = msgs) -> Any:
409
+ return self._client.beta.chat.completions.parse(
410
+ model=model,
411
+ messages=m, # type: ignore[arg-type]
412
+ response_format=response_format,
413
+ max_tokens=_max_tokens,
414
+ )
415
+
416
+ result = self._call_with_retry(_call)
417
+ if not result.choices:
418
+ raise ApiError("API returned no completion choices")
419
+
420
+ choice = result.choices[0]
421
+ refusal = getattr(choice.message, "refusal", None)
422
+ if refusal:
423
+ raise ApiError(f"Model refused request: {refusal}")
424
+ if choice.finish_reason == _FINISH_REASON_TRUNCATED:
425
+ raise SchemaError(
426
+ f"Response truncated (max_tokens too low) for {response_format.__name__}"
427
+ )
428
+ if choice.finish_reason == _FINISH_REASON_FILTERED:
429
+ raise ApiError("Response blocked by content policy")
430
+ if choice.message.parsed is not None:
431
+ return choice.message.parsed # type: ignore[return-value]
432
+
433
+ if attempt < _PARSE_MAX_RETRIES:
434
+ log.warning(
435
+ "Schema parse failed (attempt %d/%d), retrying with correction",
436
+ attempt + 1,
437
+ _PARSE_MAX_RETRIES + 1,
438
+ )
439
+ msgs = msgs + [
440
+ {"role": "assistant", "content": choice.message.content or ""},
441
+ {"role": "user", "content": _SCHEMA_CORRECTION_PROMPT},
442
+ ]
443
+
444
+ raise SchemaError(
445
+ f"Failed to parse {response_format.__name__} "
446
+ f"after {_PARSE_MAX_RETRIES + 1} attempts"
447
+ )
448
+
449
+ def _dispatch(
450
+ self,
451
+ messages: list[dict[str, Any]],
452
+ *,
453
+ model: str,
454
+ tools: list[dict[str, Any]] | None = None,
455
+ **kwargs: Any,
456
+ ) -> Any:
457
+ """Route request to a local provider if configured, else return None.
458
+
459
+ Returns the provider response when a local provider handles the
460
+ request, or None to signal the caller should use the cloud path.
461
+
462
+ 1. If a local provider is configured and reachable, use it.
463
+ 2. If local_only=True and local is unreachable -> raise ConfigError.
464
+ 3. Otherwise fall back to cloud with a WARNING log and return None.
465
+
466
+ Raises:
467
+ ConfigError: if local_only=True and local provider is unreachable.
468
+ ConfigError: if local provider rejects the request (e.g. no tools).
469
+ """
470
+ local_providers = getattr(self._cfg, "local_providers", None) or []
471
+ if not local_providers:
472
+ return None
473
+
474
+ local_cfg: LocalProviderConfig = local_providers[0]
475
+
476
+ try:
477
+ if local_cfg.provider == "ollama":
478
+ provider = OllamaProvider(local_cfg)
479
+ return provider.complete(messages, tools=tools)
480
+ elif local_cfg.provider == "vllm":
481
+ vllm_provider = VLLMProvider(local_cfg)
482
+ client = vllm_provider.get_client()
483
+ _tools: Any = tools if tools is not None else openai.NOT_GIVEN
484
+ return client.chat.completions.create(
485
+ model=model,
486
+ messages=messages, # type: ignore[arg-type]
487
+ tools=_tools,
488
+ )
489
+ except ConfigError:
490
+ raise # Re-raise config errors (e.g. tool support rejection)
491
+ except Exception as exc:
492
+ if local_cfg.local_only:
493
+ raise ConfigError(
494
+ "Local provider unreachable and local_only=true"
495
+ " -- refusing cloud fallback"
496
+ ) from exc
497
+ log.warning(
498
+ "Local provider unreachable (%s), falling back to cloud: %s",
499
+ local_cfg.provider,
500
+ exc,
501
+ )
502
+ return None
503
+
504
+
505
+ def complete_text(
506
+ self,
507
+ prompt: str,
508
+ *,
509
+ model: str,
510
+ max_tokens: int = 200,
511
+ system: str | None = None,
512
+ ) -> str:
513
+ """Convenience wrapper for simple text completion. Returns response text.
514
+
515
+ Raises:
516
+ ApiError: if the API returns no choices.
517
+ """
518
+ messages: list[dict[str, Any]] = []
519
+ if system:
520
+ messages.append({"role": "system", "content": system})
521
+ messages.append({"role": "user", "content": prompt})
522
+ result = self.complete(messages, model=model, max_tokens=max_tokens)
523
+ if not result.choices:
524
+ raise ApiError("API returned no completion choices")
525
+ return result.choices[0].message.content or ""
526
+
527
+
528
+ # ---------------------------------------------------------------------------
529
+ # Batch API
530
+ # ---------------------------------------------------------------------------
531
+
532
+ _BATCH_DISCOUNT_RATE: dict[str, float] = {"gemini": 0.50, "codex": 0.50}
533
+ _BATCH_MIN_SAVINGS_USD: float = 0.01
534
+
535
+
536
def should_use_batch(requests: list[dict], provider: str) -> bool:
    """Decide whether *requests* should go through the provider's batch path.

    Always False for "grok" and for fewer than five requests. Otherwise the
    token volume is estimated as one token per four characters of each
    request's JSON encoding, scaled by the provider's discount rate, and
    compared against ``_BATCH_MIN_SAVINGS_USD``.
    """
    if provider == "grok" or len(requests) < 5:
        return False

    estimated_tokens = 0
    for request in requests:
        estimated_tokens += len(json.dumps(request)) // 4

    discount = _BATCH_DISCOUNT_RATE.get(provider, 0)
    return estimated_tokens / 1_000_000 * discount > _BATCH_MIN_SAVINGS_USD
545
+
546
+
547
class OpenAIBatchBackend:
    """OpenAI Batch API backend — uploads JSONL, creates batch, polls, retrieves."""

    def __init__(self, client: Any) -> None:
        """Store *client*, which must expose ``files`` and ``batches`` like openai.OpenAI."""
        self._client = client

    def submit(self, requests: list[dict]) -> str:
        """Upload *requests* as a JSONL file, create a batch, and return its id."""
        # One JSON document per line, as an in-memory upload.
        jsonl = "\n".join(json.dumps(r) for r in requests)
        buf = io.BytesIO(jsonl.encode())
        uploaded = self._client.files.create(file=("batch.jsonl", buf), purpose="batch")
        batch = self._client.batches.create(
            input_file_id=uploaded.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
        return batch.id

    def poll(self, batch_id: str) -> str:
        """Return the provider's raw status string for *batch_id*."""
        return self._client.batches.retrieve(batch_id).status

    def retrieve(self, batch_id: str) -> list[dict]:
        """Download and decode the batch output file, one dict per JSONL line.

        Returns an empty list when the batch has no output file (its
        ``output_file_id`` is unset), instead of requesting a file with id
        ``None``.
        """
        batch = self._client.batches.retrieve(batch_id)
        output_file_id = getattr(batch, "output_file_id", None)
        if not output_file_id:
            return []
        content = self._client.files.content(output_file_id).text
        return [json.loads(line) for line in content.splitlines() if line.strip()]
570
+
571
+
572
class GeminiBatchBackend:
    """Gemini batchGenerateContent backend.

    The API key is sent in the ``x-goog-api-key`` request header rather than
    as a ``?key=`` query parameter, so it does not end up in URLs that get
    logged or attached to error objects.
    """

    def __init__(self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com") -> None:
        self._api_key = api_key
        self._base_url = base_url

    def _build_request(self, path: str, payload: dict | None = None) -> Any:
        """Build an authenticated request for *path* (POST when *payload* is given)."""
        import urllib.request as _ur

        headers = {"x-goog-api-key": self._api_key}
        data = None
        if payload is not None:
            data = json.dumps(payload).encode()
            headers["Content-Type"] = "application/json"
        return _ur.Request(
            f"{self._base_url}{path}",
            data=data,
            headers=headers,
            method="POST" if payload is not None else "GET",
        )

    def _fetch_json(self, path: str, payload: dict | None = None) -> dict:
        """Send the request for *path* and decode the JSON response body."""
        import urllib.request as _ur

        with _ur.urlopen(self._build_request(path, payload), timeout=30) as r:
            return json.loads(r.read())

    def submit(self, requests: list[dict]) -> str:
        """POST *requests* and return the operation name from the response.

        Falls back to a random UUID string when the response has no ``name``.
        """
        body = self._fetch_json(
            "/v1beta/models:batchGenerateContent", {"requests": requests}
        )
        return body.get("name", str(uuid.uuid4()))

    def poll(self, batch_id: str) -> str:
        """Return "completed" when the operation reports ``done``, else "in_progress"."""
        body = self._fetch_json(f"/v1beta/{batch_id}")
        return "completed" if body.get("done") else "in_progress"

    def retrieve(self, batch_id: str) -> list[dict]:
        """Return the ``responses`` list for *batch_id* (empty when absent)."""
        return self._fetch_json(f"/v1beta/{batch_id}/responses").get("responses", [])
604
+
605
+
606
class GrokSequentialBackend:
    """Grok has no batch API — sequential calls with jittered exponential backoff.

    ``submit`` runs every request synchronously and stores the results in
    memory under a generated batch id, so ``poll`` always reports completion.
    """

    def __init__(self, client: Any, *, max_attempts: int = 5, base_delay: float = 1.0) -> None:
        """Store the client and retry policy.

        Args:
            client: Object exposing ``chat.completions.create(**request)``.
            max_attempts: Attempts per request, including the first (min 1).
            base_delay: Backoff unit in seconds; attempt k waits
                ``base_delay * 2**k`` plus up to ``base_delay`` of jitter.
        """
        self._client = client
        self._max_attempts = max(1, max_attempts)
        self._base_delay = base_delay
        self._results: dict[str, list[dict]] = {}

    def _call_one(self, index: int, req: dict) -> dict:
        """Execute one request, retrying only errors whose text mentions "rate".

        Raises the provider's exception for a non-rate-limit error, or once
        the attempt budget is used up, so a request is never silently dropped
        and results stay aligned with the submitted requests.
        """
        final_attempt = self._max_attempts - 1
        for attempt in range(self._max_attempts):
            try:
                result = self._client.chat.completions.create(**req)
            except Exception as e:
                if "rate" not in str(e).lower() or attempt == final_attempt:
                    raise
                wait = self._base_delay * (2 ** attempt) + random.uniform(0, self._base_delay)
                log.warning("Rate limit on request %d, backing off %.1fs", index, wait)
                time.sleep(wait)
            else:
                return result.model_dump() if hasattr(result, "model_dump") else dict(result)
        raise RuntimeError("unreachable: attempt budget is at least 1")

    def submit(self, requests: list[dict]) -> str:
        """Run *requests* in order and return the id their results are stored under."""
        results = [self._call_one(i, req) for i, req in enumerate(requests)]
        batch_id = f"grok-seq-{uuid.uuid4().hex[:8]}"
        self._results[batch_id] = results
        return batch_id

    def poll(self, batch_id: str) -> str:
        """Always "completed": the work is finished before ``submit`` returns."""
        return "completed"

    def retrieve(self, batch_id: str) -> list[dict]:
        """Return the stored results for *batch_id*, or an empty list if unknown."""
        return self._results.get(batch_id, [])
639
+
640
+
641
class BatchClient:
    """Provider-agnostic batch submission client with DB persistence.

    Every submitted job is also tracked in memory (job id -> provider batch
    id), so ``poll`` and ``retrieve`` work within the submitting process even
    when no DB connection has been attached. The DB adds persistence across
    restarts and records status transitions.
    """

    def __init__(self, provider: str, backend: Any) -> None:
        self.provider = provider
        self._backend = backend
        self._db_conn: Any = None  # set externally if crash-recovery needed
        self._jobs: dict[str, str] = {}  # job_id -> provider batch_id

    @staticmethod
    def _utc_now_iso() -> str:
        """Current UTC time as a naive ISO-8601 string (same format as utcnow().isoformat())."""
        import datetime

        return (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )

    def _stored_batch_id(self, job_id: str) -> str | None:
        """Look up the provider batch id for *job_id* in the DB, then in memory."""
        if self._db_conn is not None:
            row = self._db_conn.execute(
                "SELECT batch_id FROM batch_jobs WHERE id=?", (job_id,)
            ).fetchone()
            if row:
                return row[0]
        return self._jobs.get(job_id)

    def submit(self, requests: list[dict], job_type: str = "index") -> str:
        """Redact secrets, submit batch, persist to DB. Returns job UUID."""
        from src.security import redact

        # Redaction runs on the JSON text of each request before it leaves
        # the process.
        safe = [json.loads(redact(json.dumps(r))) for r in requests]
        log.info("Submitting batch of %d requests (%s/%s)", len(safe), self.provider, job_type)
        batch_id = self._backend.submit(safe)
        job_id = str(uuid.uuid4())
        self._jobs[job_id] = batch_id
        if self._db_conn is not None:
            now = self._utc_now_iso()
            self._db_conn.execute(
                "INSERT INTO batch_jobs "
                "(id, provider, batch_id, job_type, status, request_count, created_at, submitted_at) "
                "VALUES (?, ?, ?, ?, 'submitted', ?, ?, ?)",
                (job_id, self.provider, batch_id, job_type, len(safe), now, now),
            )
            self._db_conn.commit()
        log.info("Batch submitted: job_id=%s batch_id=%s", job_id, batch_id)
        return job_id

    def poll(self, job_id: str) -> str:
        """Poll provider status, update DB, return internal status string.

        Returns "unknown" when the job id is not known. A job already
        recorded in the DB as completed, failed or cancelled is returned
        without contacting the provider.
        """
        if self._db_conn is None:
            batch_id = self._jobs.get(job_id)
            if batch_id is None:
                return "unknown"
            return _map_status(self._backend.poll(batch_id), self.provider)

        row = self._db_conn.execute(
            "SELECT batch_id, status FROM batch_jobs WHERE id=?", (job_id,)
        ).fetchone()
        if not row:
            return "unknown"
        batch_id, current_status = row[0], row[1]
        if current_status in ("completed", "failed", "cancelled"):
            return current_status

        status = _map_status(self._backend.poll(batch_id), self.provider)
        if status in ("completed", "failed"):
            self._db_conn.execute(
                "UPDATE batch_jobs SET status=?, completed_at=? WHERE id=?",
                (status, self._utc_now_iso(), job_id),
            )
        else:
            self._db_conn.execute(
                "UPDATE batch_jobs SET status=? WHERE id=?", (status, job_id)
            )
        self._db_conn.commit()
        return status

    def retrieve(self, job_id: str) -> list[dict]:
        """Retrieve completed batch results, or an empty list for an unknown job."""
        batch_id = self._stored_batch_id(job_id)
        if batch_id is None:
            return []
        return self._backend.retrieve(batch_id)
701
+
702
+
703
+ def _map_status(provider_status: str, provider: str) -> str:
704
+ """Normalise provider-specific status strings to internal status."""
705
+ _completed = {"completed", "succeeded", "done"}
706
+ _failed = {"failed", "error", "expired"}
707
+ _cancelled = {"cancelled", "canceled"}
708
+ if provider_status in _completed:
709
+ return "completed"
710
+ if provider_status in _failed:
711
+ return "failed"
712
+ if provider_status in _cancelled:
713
+ return "cancelled"
714
+ return "polling"
715
+