hindsight-api 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +1 -1
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +16 -2
- hindsight_api/api/http.py +83 -1
- hindsight_api/banner.py +3 -0
- hindsight_api/config.py +44 -6
- hindsight_api/daemon.py +18 -112
- hindsight_api/engine/llm_interface.py +146 -0
- hindsight_api/engine/llm_wrapper.py +304 -1327
- hindsight_api/engine/memory_engine.py +125 -41
- hindsight_api/engine/providers/__init__.py +14 -0
- hindsight_api/engine/providers/anthropic_llm.py +434 -0
- hindsight_api/engine/providers/claude_code_llm.py +352 -0
- hindsight_api/engine/providers/codex_llm.py +527 -0
- hindsight_api/engine/providers/gemini_llm.py +502 -0
- hindsight_api/engine/providers/mock_llm.py +234 -0
- hindsight_api/engine/providers/openai_compatible_llm.py +745 -0
- hindsight_api/engine/retain/fact_extraction.py +13 -9
- hindsight_api/engine/retain/fact_storage.py +5 -3
- hindsight_api/extensions/__init__.py +10 -0
- hindsight_api/extensions/builtin/tenant.py +36 -0
- hindsight_api/extensions/operation_validator.py +129 -0
- hindsight_api/main.py +6 -21
- hindsight_api/migrations.py +75 -0
- hindsight_api/worker/main.py +41 -11
- hindsight_api/worker/poller.py +26 -14
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/METADATA +2 -1
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/RECORD +29 -21
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/WHEEL +0 -0
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/entry_points.txt +0 -0
|
@@ -8,15 +8,14 @@ import logging
|
|
|
8
8
|
import os
|
|
9
9
|
import re
|
|
10
10
|
import time
|
|
11
|
+
import uuid
|
|
12
|
+
from pathlib import Path
|
|
11
13
|
from typing import Any
|
|
12
14
|
|
|
13
15
|
import httpx
|
|
14
|
-
from google import genai
|
|
15
|
-
from google.genai import errors as genai_errors
|
|
16
|
-
from google.genai import types as genai_types
|
|
17
16
|
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
|
|
18
17
|
|
|
19
|
-
# Vertex AI imports (conditional)
|
|
18
|
+
# Vertex AI imports (conditional - for LLMProvider to pass credentials to GeminiLLM)
|
|
20
19
|
try:
|
|
21
20
|
import google.auth
|
|
22
21
|
from google.oauth2 import service_account
|
|
@@ -61,6 +60,108 @@ class OutputTooLongError(Exception):
|
|
|
61
60
|
pass
|
|
62
61
|
|
|
63
62
|
|
|
63
|
+
def create_llm_provider(
|
|
64
|
+
provider: str,
|
|
65
|
+
api_key: str,
|
|
66
|
+
base_url: str,
|
|
67
|
+
model: str,
|
|
68
|
+
reasoning_effort: str,
|
|
69
|
+
groq_service_tier: str | None = None,
|
|
70
|
+
vertexai_project_id: str | None = None,
|
|
71
|
+
vertexai_region: str | None = None,
|
|
72
|
+
vertexai_credentials: Any = None,
|
|
73
|
+
) -> Any: # Returns LLMInterface
|
|
74
|
+
"""
|
|
75
|
+
Factory function to create the appropriate LLM provider implementation.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
provider: Provider name ("openai", "groq", "ollama", "gemini", "anthropic", etc.).
|
|
79
|
+
api_key: API key (may be None for local providers or OAuth providers).
|
|
80
|
+
base_url: Base URL for the API.
|
|
81
|
+
model: Model name.
|
|
82
|
+
reasoning_effort: Reasoning effort level for supported providers.
|
|
83
|
+
groq_service_tier: Groq service tier (for Groq provider).
|
|
84
|
+
vertexai_project_id: Vertex AI project ID (for VertexAI provider).
|
|
85
|
+
vertexai_region: Vertex AI region (for VertexAI provider).
|
|
86
|
+
vertexai_credentials: Vertex AI credentials object (for VertexAI provider).
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
LLMInterface implementation for the specified provider.
|
|
90
|
+
"""
|
|
91
|
+
from .llm_interface import LLMInterface
|
|
92
|
+
from .providers import (
|
|
93
|
+
AnthropicLLM,
|
|
94
|
+
ClaudeCodeLLM,
|
|
95
|
+
CodexLLM,
|
|
96
|
+
GeminiLLM,
|
|
97
|
+
MockLLM,
|
|
98
|
+
OpenAICompatibleLLM,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
provider_lower = provider.lower()
|
|
102
|
+
|
|
103
|
+
if provider_lower == "openai-codex":
|
|
104
|
+
return CodexLLM(
|
|
105
|
+
provider=provider,
|
|
106
|
+
api_key=api_key,
|
|
107
|
+
base_url=base_url,
|
|
108
|
+
model=model,
|
|
109
|
+
reasoning_effort=reasoning_effort,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
elif provider_lower == "claude-code":
|
|
113
|
+
return ClaudeCodeLLM(
|
|
114
|
+
provider=provider,
|
|
115
|
+
api_key=api_key,
|
|
116
|
+
base_url=base_url,
|
|
117
|
+
model=model,
|
|
118
|
+
reasoning_effort=reasoning_effort,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
elif provider_lower == "mock":
|
|
122
|
+
return MockLLM(
|
|
123
|
+
provider=provider,
|
|
124
|
+
api_key=api_key,
|
|
125
|
+
base_url=base_url,
|
|
126
|
+
model=model,
|
|
127
|
+
reasoning_effort=reasoning_effort,
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
elif provider_lower in ("gemini", "vertexai"):
|
|
131
|
+
return GeminiLLM(
|
|
132
|
+
provider=provider,
|
|
133
|
+
api_key=api_key,
|
|
134
|
+
base_url=base_url,
|
|
135
|
+
model=model,
|
|
136
|
+
reasoning_effort=reasoning_effort,
|
|
137
|
+
vertexai_project_id=vertexai_project_id,
|
|
138
|
+
vertexai_region=vertexai_region,
|
|
139
|
+
vertexai_credentials=vertexai_credentials,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
elif provider_lower == "anthropic":
|
|
143
|
+
return AnthropicLLM(
|
|
144
|
+
provider=provider,
|
|
145
|
+
api_key=api_key,
|
|
146
|
+
base_url=base_url,
|
|
147
|
+
model=model,
|
|
148
|
+
reasoning_effort=reasoning_effort,
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
elif provider_lower in ("openai", "groq", "ollama", "lmstudio"):
|
|
152
|
+
return OpenAICompatibleLLM(
|
|
153
|
+
provider=provider,
|
|
154
|
+
api_key=api_key,
|
|
155
|
+
base_url=base_url,
|
|
156
|
+
model=model,
|
|
157
|
+
reasoning_effort=reasoning_effort,
|
|
158
|
+
groq_service_tier=groq_service_tier,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
else:
|
|
162
|
+
raise ValueError(f"Unknown provider: {provider}")
|
|
163
|
+
|
|
164
|
+
|
|
64
165
|
class LLMProvider:
|
|
65
166
|
"""
|
|
66
167
|
Unified LLM provider.
|
|
@@ -97,14 +198,21 @@ class LLMProvider:
|
|
|
97
198
|
self.groq_service_tier = groq_service_tier or os.getenv(ENV_LLM_GROQ_SERVICE_TIER, "auto")
|
|
98
199
|
|
|
99
200
|
# Validate provider
|
|
100
|
-
valid_providers = [
|
|
201
|
+
valid_providers = [
|
|
202
|
+
"openai",
|
|
203
|
+
"groq",
|
|
204
|
+
"ollama",
|
|
205
|
+
"gemini",
|
|
206
|
+
"anthropic",
|
|
207
|
+
"lmstudio",
|
|
208
|
+
"vertexai",
|
|
209
|
+
"openai-codex",
|
|
210
|
+
"claude-code",
|
|
211
|
+
"mock",
|
|
212
|
+
]
|
|
101
213
|
if self.provider not in valid_providers:
|
|
102
214
|
raise ValueError(f"Invalid LLM provider: {self.provider}. Must be one of: {', '.join(valid_providers)}")
|
|
103
215
|
|
|
104
|
-
# Mock provider tracking (for testing)
|
|
105
|
-
self._mock_calls: list[dict] = []
|
|
106
|
-
self._mock_response: Any = None
|
|
107
|
-
|
|
108
216
|
# Set default base URLs
|
|
109
217
|
if not self.base_url:
|
|
110
218
|
if self.provider == "groq":
|
|
@@ -114,24 +222,24 @@ class LLMProvider:
|
|
|
114
222
|
elif self.provider == "lmstudio":
|
|
115
223
|
self.base_url = "http://localhost:1234/v1"
|
|
116
224
|
|
|
117
|
-
# Vertex AI config
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
225
|
+
# Prepare Vertex AI config (if applicable)
|
|
226
|
+
vertexai_project_id = None
|
|
227
|
+
vertexai_region = None
|
|
228
|
+
vertexai_credentials = None
|
|
121
229
|
|
|
122
230
|
if self.provider == "vertexai":
|
|
123
231
|
from ..config import get_config
|
|
124
232
|
|
|
125
233
|
config = get_config()
|
|
126
234
|
|
|
127
|
-
|
|
128
|
-
if not
|
|
235
|
+
vertexai_project_id = config.llm_vertexai_project_id
|
|
236
|
+
if not vertexai_project_id:
|
|
129
237
|
raise ValueError(
|
|
130
238
|
"HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID is required for Vertex AI provider. "
|
|
131
239
|
"Set it to your GCP project ID."
|
|
132
240
|
)
|
|
133
241
|
|
|
134
|
-
|
|
242
|
+
vertexai_region = config.llm_vertexai_region or "us-central1"
|
|
135
243
|
service_account_key = config.llm_vertexai_service_account_key
|
|
136
244
|
|
|
137
245
|
# Load explicit service account credentials if provided
|
|
@@ -141,75 +249,71 @@ class LLMProvider:
|
|
|
141
249
|
"Vertex AI service account auth requires 'google-auth' package. "
|
|
142
250
|
"Install with: pip install google-auth"
|
|
143
251
|
)
|
|
144
|
-
|
|
252
|
+
vertexai_credentials = service_account.Credentials.from_service_account_file(
|
|
145
253
|
service_account_key,
|
|
146
254
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
147
255
|
)
|
|
148
256
|
logger.info(f"Vertex AI: Using service account key: {service_account_key}")
|
|
149
257
|
|
|
150
258
|
# Strip google/ prefix from model name — native SDK uses bare names
|
|
151
|
-
# e.g. "google/gemini-2.0-flash-lite-001" -> "gemini-2.0-flash-lite-001"
|
|
152
259
|
if self.model.startswith("google/"):
|
|
153
260
|
self.model = self.model[len("google/") :]
|
|
154
261
|
|
|
155
262
|
logger.info(
|
|
156
|
-
f"Vertex AI: project={
|
|
263
|
+
f"Vertex AI: project={vertexai_project_id}, region={vertexai_region}, "
|
|
157
264
|
f"model={self.model}, auth={'service_account' if service_account_key else 'ADC'}"
|
|
158
265
|
)
|
|
159
266
|
|
|
160
|
-
#
|
|
161
|
-
|
|
162
|
-
|
|
267
|
+
# Create provider implementation using factory
|
|
268
|
+
self._provider_impl = create_llm_provider(
|
|
269
|
+
provider=self.provider,
|
|
270
|
+
api_key=self.api_key,
|
|
271
|
+
base_url=self.base_url,
|
|
272
|
+
model=self.model,
|
|
273
|
+
reasoning_effort=self.reasoning_effort,
|
|
274
|
+
groq_service_tier=self.groq_service_tier,
|
|
275
|
+
vertexai_project_id=vertexai_project_id,
|
|
276
|
+
vertexai_region=vertexai_region,
|
|
277
|
+
vertexai_credentials=vertexai_credentials,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
# Backward compatibility: Keep mock provider properties
|
|
281
|
+
self._mock_calls: list[dict] = []
|
|
282
|
+
self._mock_response: Any = None
|
|
163
283
|
|
|
164
|
-
|
|
165
|
-
|
|
284
|
+
@property
|
|
285
|
+
def _client(self) -> Any:
|
|
286
|
+
"""
|
|
287
|
+
Get the OpenAI client for OpenAI-compatible providers.
|
|
166
288
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
self._gemini_client = None
|
|
170
|
-
self._anthropic_client = None
|
|
289
|
+
This property provides backward compatibility for code that directly accesses
|
|
290
|
+
the _client attribute (e.g., benchmarks, memory_engine).
|
|
171
291
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
self._gemini_client = genai.Client(**client_kwargs)
|
|
198
|
-
elif self.provider in ("ollama", "lmstudio"):
|
|
199
|
-
# Use dummy key if not provided for local
|
|
200
|
-
api_key = self.api_key or "local"
|
|
201
|
-
client_kwargs = {"api_key": api_key, "base_url": self.base_url, "max_retries": 0}
|
|
202
|
-
if self.timeout:
|
|
203
|
-
client_kwargs["timeout"] = self.timeout
|
|
204
|
-
self._client = AsyncOpenAI(**client_kwargs)
|
|
205
|
-
else:
|
|
206
|
-
# Only pass base_url if it's set (OpenAI uses default URL otherwise)
|
|
207
|
-
client_kwargs = {"api_key": self.api_key, "max_retries": 0}
|
|
208
|
-
if self.base_url:
|
|
209
|
-
client_kwargs["base_url"] = self.base_url
|
|
210
|
-
if self.timeout:
|
|
211
|
-
client_kwargs["timeout"] = self.timeout
|
|
212
|
-
self._client = AsyncOpenAI(**client_kwargs)
|
|
292
|
+
Returns:
|
|
293
|
+
AsyncOpenAI client instance for OpenAI-compatible providers, or None for other providers.
|
|
294
|
+
"""
|
|
295
|
+
from .providers.openai_compatible_llm import OpenAICompatibleLLM
|
|
296
|
+
|
|
297
|
+
if isinstance(self._provider_impl, OpenAICompatibleLLM):
|
|
298
|
+
return self._provider_impl._client
|
|
299
|
+
return None
|
|
300
|
+
|
|
301
|
+
@property
|
|
302
|
+
def _gemini_client(self) -> Any:
|
|
303
|
+
"""
|
|
304
|
+
Get the Gemini client for Gemini/VertexAI providers.
|
|
305
|
+
|
|
306
|
+
This property provides backward compatibility for code that directly accesses
|
|
307
|
+
the _gemini_client attribute.
|
|
308
|
+
|
|
309
|
+
Returns:
|
|
310
|
+
genai.Client instance for Gemini/VertexAI providers, or None for other providers.
|
|
311
|
+
"""
|
|
312
|
+
from .providers.gemini_llm import GeminiLLM
|
|
313
|
+
|
|
314
|
+
if isinstance(self._provider_impl, GeminiLLM):
|
|
315
|
+
return self._provider_impl._client
|
|
316
|
+
return None
|
|
213
317
|
|
|
214
318
|
async def verify_connection(self) -> None:
|
|
215
319
|
"""
|
|
@@ -218,21 +322,7 @@ class LLMProvider:
|
|
|
218
322
|
Raises:
|
|
219
323
|
RuntimeError: If the connection test fails.
|
|
220
324
|
"""
|
|
221
|
-
|
|
222
|
-
logger.info(
|
|
223
|
-
f"Verifying LLM: provider={self.provider}, model={self.model}, base_url={self.base_url or 'default'}..."
|
|
224
|
-
)
|
|
225
|
-
await self.call(
|
|
226
|
-
messages=[{"role": "user", "content": "Say 'ok'"}],
|
|
227
|
-
max_completion_tokens=100,
|
|
228
|
-
max_retries=2,
|
|
229
|
-
initial_backoff=0.5,
|
|
230
|
-
max_backoff=2.0,
|
|
231
|
-
)
|
|
232
|
-
# If we get here without exception, the connection is working
|
|
233
|
-
logger.info(f"LLM verified: {self.provider}/{self.model}")
|
|
234
|
-
except Exception as e:
|
|
235
|
-
raise RuntimeError(f"LLM connection verification failed for {self.provider}/{self.model}: {e}") from e
|
|
325
|
+
await self._provider_impl.verify_connection()
|
|
236
326
|
|
|
237
327
|
async def call(
|
|
238
328
|
self,
|
|
@@ -272,340 +362,32 @@ class LLMProvider:
|
|
|
272
362
|
OutputTooLongError: If output exceeds token limits.
|
|
273
363
|
Exception: Re-raises API errors after retries exhausted.
|
|
274
364
|
"""
|
|
275
|
-
semaphore_start = time.time()
|
|
276
365
|
async with _global_llm_semaphore:
|
|
277
|
-
|
|
278
|
-
|
|
366
|
+
# Delegate to provider implementation
|
|
367
|
+
result = await self._provider_impl.call(
|
|
368
|
+
messages=messages,
|
|
369
|
+
response_format=response_format,
|
|
370
|
+
max_completion_tokens=max_completion_tokens,
|
|
371
|
+
temperature=temperature,
|
|
372
|
+
scope=scope,
|
|
373
|
+
max_retries=max_retries,
|
|
374
|
+
initial_backoff=initial_backoff,
|
|
375
|
+
max_backoff=max_backoff,
|
|
376
|
+
skip_validation=skip_validation,
|
|
377
|
+
strict_schema=strict_schema,
|
|
378
|
+
return_usage=return_usage,
|
|
379
|
+
)
|
|
279
380
|
|
|
280
|
-
#
|
|
381
|
+
# Backward compatibility: Update mock call tracking for mock provider
|
|
382
|
+
# This allows existing tests using LLMProvider._mock_calls to continue working
|
|
281
383
|
if self.provider == "mock":
|
|
282
|
-
|
|
283
|
-
messages,
|
|
284
|
-
response_format,
|
|
285
|
-
scope,
|
|
286
|
-
return_usage,
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
# Handle Gemini and Vertex AI providers (both use native genai SDK)
|
|
290
|
-
if self.provider in ("gemini", "vertexai"):
|
|
291
|
-
return await self._call_gemini(
|
|
292
|
-
messages,
|
|
293
|
-
response_format,
|
|
294
|
-
max_retries,
|
|
295
|
-
initial_backoff,
|
|
296
|
-
max_backoff,
|
|
297
|
-
skip_validation,
|
|
298
|
-
start_time,
|
|
299
|
-
scope,
|
|
300
|
-
return_usage,
|
|
301
|
-
semaphore_wait_time,
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
# Handle Anthropic provider separately
|
|
305
|
-
if self.provider == "anthropic":
|
|
306
|
-
return await self._call_anthropic(
|
|
307
|
-
messages,
|
|
308
|
-
response_format,
|
|
309
|
-
max_completion_tokens,
|
|
310
|
-
max_retries,
|
|
311
|
-
initial_backoff,
|
|
312
|
-
max_backoff,
|
|
313
|
-
skip_validation,
|
|
314
|
-
start_time,
|
|
315
|
-
scope,
|
|
316
|
-
return_usage,
|
|
317
|
-
semaphore_wait_time,
|
|
318
|
-
)
|
|
384
|
+
from .providers.mock_llm import MockLLM
|
|
319
385
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
messages,
|
|
324
|
-
response_format,
|
|
325
|
-
max_completion_tokens,
|
|
326
|
-
temperature,
|
|
327
|
-
max_retries,
|
|
328
|
-
initial_backoff,
|
|
329
|
-
max_backoff,
|
|
330
|
-
skip_validation,
|
|
331
|
-
start_time,
|
|
332
|
-
scope,
|
|
333
|
-
return_usage,
|
|
334
|
-
semaphore_wait_time,
|
|
335
|
-
)
|
|
386
|
+
if isinstance(self._provider_impl, MockLLM):
|
|
387
|
+
# Sync the mock calls from provider implementation to wrapper
|
|
388
|
+
self._mock_calls = self._provider_impl.get_mock_calls()
|
|
336
389
|
|
|
337
|
-
|
|
338
|
-
"model": self.model,
|
|
339
|
-
"messages": messages,
|
|
340
|
-
}
|
|
341
|
-
|
|
342
|
-
# Check if model supports reasoning parameter (o1, o3, gpt-5 families)
|
|
343
|
-
model_lower = self.model.lower()
|
|
344
|
-
is_reasoning_model = any(x in model_lower for x in ["gpt-5", "o1", "o3", "deepseek"])
|
|
345
|
-
|
|
346
|
-
# For GPT-4 and GPT-4.1 models, cap max_completion_tokens to 32000
|
|
347
|
-
# For GPT-4o models, cap to 16384
|
|
348
|
-
is_gpt4_model = any(x in model_lower for x in ["gpt-4.1", "gpt-4-"])
|
|
349
|
-
is_gpt4o_model = "gpt-4o" in model_lower
|
|
350
|
-
if max_completion_tokens is not None:
|
|
351
|
-
if is_gpt4o_model and max_completion_tokens > 16384:
|
|
352
|
-
max_completion_tokens = 16384
|
|
353
|
-
elif is_gpt4_model and max_completion_tokens > 32000:
|
|
354
|
-
max_completion_tokens = 32000
|
|
355
|
-
# For reasoning models, max_completion_tokens includes reasoning + output tokens
|
|
356
|
-
# Enforce minimum of 16000 to ensure enough space for both
|
|
357
|
-
if is_reasoning_model and max_completion_tokens < 16000:
|
|
358
|
-
max_completion_tokens = 16000
|
|
359
|
-
call_params["max_completion_tokens"] = max_completion_tokens
|
|
360
|
-
|
|
361
|
-
# GPT-5/o1/o3 family doesn't support custom temperature (only default 1)
|
|
362
|
-
if temperature is not None and not is_reasoning_model:
|
|
363
|
-
call_params["temperature"] = temperature
|
|
364
|
-
|
|
365
|
-
# Set reasoning_effort for reasoning models (OpenAI gpt-5, o1, o3)
|
|
366
|
-
if is_reasoning_model:
|
|
367
|
-
call_params["reasoning_effort"] = self.reasoning_effort
|
|
368
|
-
|
|
369
|
-
# Provider-specific parameters
|
|
370
|
-
if self.provider == "groq":
|
|
371
|
-
call_params["seed"] = DEFAULT_LLM_SEED
|
|
372
|
-
extra_body: dict[str, Any] = {}
|
|
373
|
-
# Add service_tier if configured (requires paid plan for flex/auto)
|
|
374
|
-
if self.groq_service_tier:
|
|
375
|
-
extra_body["service_tier"] = self.groq_service_tier
|
|
376
|
-
# Add reasoning parameters for reasoning models
|
|
377
|
-
if is_reasoning_model:
|
|
378
|
-
extra_body["include_reasoning"] = False
|
|
379
|
-
if extra_body:
|
|
380
|
-
call_params["extra_body"] = extra_body
|
|
381
|
-
|
|
382
|
-
last_exception = None
|
|
383
|
-
|
|
384
|
-
# Prepare response format ONCE before the retry loop
|
|
385
|
-
# (to avoid appending schema to messages on every retry)
|
|
386
|
-
if response_format is not None:
|
|
387
|
-
schema = None
|
|
388
|
-
if hasattr(response_format, "model_json_schema"):
|
|
389
|
-
schema = response_format.model_json_schema()
|
|
390
|
-
|
|
391
|
-
if strict_schema and schema is not None:
|
|
392
|
-
# Use OpenAI's strict JSON schema enforcement
|
|
393
|
-
# This guarantees all required fields are returned
|
|
394
|
-
call_params["response_format"] = {
|
|
395
|
-
"type": "json_schema",
|
|
396
|
-
"json_schema": {
|
|
397
|
-
"name": "response",
|
|
398
|
-
"strict": True,
|
|
399
|
-
"schema": schema,
|
|
400
|
-
},
|
|
401
|
-
}
|
|
402
|
-
else:
|
|
403
|
-
# Soft enforcement: add schema to prompt and use json_object mode
|
|
404
|
-
if schema is not None:
|
|
405
|
-
schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
|
|
406
|
-
|
|
407
|
-
if call_params["messages"] and call_params["messages"][0].get("role") == "system":
|
|
408
|
-
first_msg = call_params["messages"][0]
|
|
409
|
-
if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
|
|
410
|
-
first_msg["content"] += schema_msg
|
|
411
|
-
elif call_params["messages"]:
|
|
412
|
-
first_msg = call_params["messages"][0]
|
|
413
|
-
if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
|
|
414
|
-
first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]
|
|
415
|
-
if self.provider not in ("lmstudio", "ollama"):
|
|
416
|
-
# LM Studio and Ollama don't support json_object response format reliably
|
|
417
|
-
# We rely on the schema in the system message instead
|
|
418
|
-
call_params["response_format"] = {"type": "json_object"}
|
|
419
|
-
|
|
420
|
-
for attempt in range(max_retries + 1):
|
|
421
|
-
try:
|
|
422
|
-
if response_format is not None:
|
|
423
|
-
response = await self._client.chat.completions.create(**call_params)
|
|
424
|
-
|
|
425
|
-
content = response.choices[0].message.content
|
|
426
|
-
|
|
427
|
-
# Strip reasoning model thinking tags
|
|
428
|
-
# Supports: <think>, <thinking>, <reasoning>, |startthink|/|endthink|
|
|
429
|
-
# for reasoning models that embed thinking in their output (e.g., Qwen3, DeepSeek)
|
|
430
|
-
if content:
|
|
431
|
-
original_len = len(content)
|
|
432
|
-
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
|
|
433
|
-
content = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL)
|
|
434
|
-
content = re.sub(r"<reasoning>.*?</reasoning>", "", content, flags=re.DOTALL)
|
|
435
|
-
content = re.sub(r"\|startthink\|.*?\|endthink\|", "", content, flags=re.DOTALL)
|
|
436
|
-
content = content.strip()
|
|
437
|
-
if len(content) < original_len:
|
|
438
|
-
logger.debug(f"Stripped {original_len - len(content)} chars of reasoning tokens")
|
|
439
|
-
|
|
440
|
-
# For local models, they may wrap JSON in markdown code blocks
|
|
441
|
-
if self.provider in ("lmstudio", "ollama"):
|
|
442
|
-
clean_content = content
|
|
443
|
-
if "```json" in content:
|
|
444
|
-
clean_content = content.split("```json")[1].split("```")[0].strip()
|
|
445
|
-
elif "```" in content:
|
|
446
|
-
clean_content = content.split("```")[1].split("```")[0].strip()
|
|
447
|
-
try:
|
|
448
|
-
json_data = json.loads(clean_content)
|
|
449
|
-
except json.JSONDecodeError:
|
|
450
|
-
# Fallback to parsing raw content
|
|
451
|
-
json_data = json.loads(content)
|
|
452
|
-
else:
|
|
453
|
-
# Log raw LLM response for debugging JSON parse issues
|
|
454
|
-
try:
|
|
455
|
-
json_data = json.loads(content)
|
|
456
|
-
except json.JSONDecodeError as json_err:
|
|
457
|
-
# Truncate content for logging (first 500 and last 200 chars)
|
|
458
|
-
content_preview = content[:500] if content else "<empty>"
|
|
459
|
-
if content and len(content) > 700:
|
|
460
|
-
content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
|
|
461
|
-
logger.warning(
|
|
462
|
-
f"JSON parse error from LLM response (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
|
|
463
|
-
f" Model: {self.provider}/{self.model}\n"
|
|
464
|
-
f" Content length: {len(content) if content else 0} chars\n"
|
|
465
|
-
f" Content preview: {content_preview!r}\n"
|
|
466
|
-
f" Finish reason: {response.choices[0].finish_reason if response.choices else 'unknown'}"
|
|
467
|
-
)
|
|
468
|
-
# Retry on JSON parse errors - LLM may return valid JSON on next attempt
|
|
469
|
-
if attempt < max_retries:
|
|
470
|
-
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
471
|
-
await asyncio.sleep(backoff)
|
|
472
|
-
last_exception = json_err
|
|
473
|
-
continue
|
|
474
|
-
else:
|
|
475
|
-
logger.error(f"JSON parse error after {max_retries + 1} attempts, giving up")
|
|
476
|
-
raise
|
|
477
|
-
|
|
478
|
-
if skip_validation:
|
|
479
|
-
result = json_data
|
|
480
|
-
else:
|
|
481
|
-
result = response_format.model_validate(json_data)
|
|
482
|
-
else:
|
|
483
|
-
response = await self._client.chat.completions.create(**call_params)
|
|
484
|
-
result = response.choices[0].message.content
|
|
485
|
-
|
|
486
|
-
# Record token usage metrics
|
|
487
|
-
duration = time.time() - start_time
|
|
488
|
-
usage = response.usage
|
|
489
|
-
input_tokens = usage.prompt_tokens or 0 if usage else 0
|
|
490
|
-
output_tokens = usage.completion_tokens or 0 if usage else 0
|
|
491
|
-
total_tokens = usage.total_tokens or 0 if usage else 0
|
|
492
|
-
|
|
493
|
-
# Record LLM metrics
|
|
494
|
-
metrics = get_metrics_collector()
|
|
495
|
-
metrics.record_llm_call(
|
|
496
|
-
provider=self.provider,
|
|
497
|
-
model=self.model,
|
|
498
|
-
scope=scope,
|
|
499
|
-
duration=duration,
|
|
500
|
-
input_tokens=input_tokens,
|
|
501
|
-
output_tokens=output_tokens,
|
|
502
|
-
success=True,
|
|
503
|
-
)
|
|
504
|
-
|
|
505
|
-
# Log slow calls
|
|
506
|
-
if duration > 10.0 and usage:
|
|
507
|
-
ratio = max(1, output_tokens) / max(1, input_tokens)
|
|
508
|
-
cached_tokens = 0
|
|
509
|
-
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
|
|
510
|
-
cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
|
|
511
|
-
cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
|
|
512
|
-
wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
|
|
513
|
-
logger.info(
|
|
514
|
-
f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
|
|
515
|
-
f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
|
|
516
|
-
f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s{wait_info}, ratio out/in={ratio:.2f}"
|
|
517
|
-
)
|
|
518
|
-
|
|
519
|
-
if return_usage:
|
|
520
|
-
token_usage = TokenUsage(
|
|
521
|
-
input_tokens=input_tokens,
|
|
522
|
-
output_tokens=output_tokens,
|
|
523
|
-
total_tokens=total_tokens,
|
|
524
|
-
)
|
|
525
|
-
return result, token_usage
|
|
526
|
-
return result
|
|
527
|
-
|
|
528
|
-
except LengthFinishReasonError as e:
|
|
529
|
-
logger.warning(f"LLM output exceeded token limits: {str(e)}")
|
|
530
|
-
raise OutputTooLongError(
|
|
531
|
-
"LLM output exceeded token limits. Input may need to be split into smaller chunks."
|
|
532
|
-
) from e
|
|
533
|
-
|
|
534
|
-
except APIConnectionError as e:
|
|
535
|
-
last_exception = e
|
|
536
|
-
status_code = getattr(e, "status_code", None) or getattr(
|
|
537
|
-
getattr(e, "response", None), "status_code", None
|
|
538
|
-
)
|
|
539
|
-
logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
|
|
540
|
-
if attempt < max_retries:
|
|
541
|
-
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
542
|
-
await asyncio.sleep(backoff)
|
|
543
|
-
continue
|
|
544
|
-
else:
|
|
545
|
-
logger.error(f"Connection error after {max_retries + 1} attempts: {str(e)}")
|
|
546
|
-
raise
|
|
547
|
-
|
|
548
|
-
except APIStatusError as e:
|
|
549
|
-
# Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
|
|
550
|
-
if e.status_code in (401, 403):
|
|
551
|
-
logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
|
|
552
|
-
raise
|
|
553
|
-
|
|
554
|
-
# Handle tool_use_failed error - model outputted in tool call format
|
|
555
|
-
# Convert to expected JSON format and continue
|
|
556
|
-
if e.status_code == 400 and response_format is not None:
|
|
557
|
-
try:
|
|
558
|
-
error_body = e.body if hasattr(e, "body") else {}
|
|
559
|
-
if isinstance(error_body, dict):
|
|
560
|
-
error_info: dict[str, Any] = error_body.get("error") or {}
|
|
561
|
-
if error_info.get("code") == "tool_use_failed":
|
|
562
|
-
failed_gen = error_info.get("failed_generation", "")
|
|
563
|
-
if failed_gen:
|
|
564
|
-
# Parse the tool call format and convert to actions format
|
|
565
|
-
tool_call = json.loads(failed_gen)
|
|
566
|
-
tool_name = tool_call.get("name", "")
|
|
567
|
-
tool_args = tool_call.get("arguments", {})
|
|
568
|
-
# Convert to actions format: {"actions": [{"tool": "name", ...args}]}
|
|
569
|
-
converted = {"actions": [{"tool": tool_name, **tool_args}]}
|
|
570
|
-
if skip_validation:
|
|
571
|
-
result = converted
|
|
572
|
-
else:
|
|
573
|
-
result = response_format.model_validate(converted)
|
|
574
|
-
|
|
575
|
-
# Record metrics for this successful recovery
|
|
576
|
-
duration = time.time() - start_time
|
|
577
|
-
metrics = get_metrics_collector()
|
|
578
|
-
metrics.record_llm_call(
|
|
579
|
-
provider=self.provider,
|
|
580
|
-
model=self.model,
|
|
581
|
-
scope=scope,
|
|
582
|
-
duration=duration,
|
|
583
|
-
input_tokens=0,
|
|
584
|
-
output_tokens=0,
|
|
585
|
-
success=True,
|
|
586
|
-
)
|
|
587
|
-
if return_usage:
|
|
588
|
-
return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
|
|
589
|
-
return result
|
|
590
|
-
except (json.JSONDecodeError, KeyError, TypeError):
|
|
591
|
-
pass # Failed to parse tool_use_failed, continue with normal retry
|
|
592
|
-
|
|
593
|
-
last_exception = e
|
|
594
|
-
if attempt < max_retries:
|
|
595
|
-
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
596
|
-
jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
|
|
597
|
-
sleep_time = backoff + jitter
|
|
598
|
-
await asyncio.sleep(sleep_time)
|
|
599
|
-
else:
|
|
600
|
-
logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
|
|
601
|
-
raise
|
|
602
|
-
|
|
603
|
-
except Exception:
|
|
604
|
-
raise
|
|
605
|
-
|
|
606
|
-
if last_exception:
|
|
607
|
-
raise last_exception
|
|
608
|
-
raise RuntimeError("LLM call failed after all retries with no exception captured")
|
|
390
|
+
return result
|
|
609
391
|
|
|
610
392
|
async def call_with_tools(
|
|
611
393
|
self,
|
|
@@ -636,940 +418,122 @@ class LLMProvider:
|
|
|
636
418
|
Returns:
|
|
637
419
|
LLMToolCallResult with content and/or tool_calls.
|
|
638
420
|
"""
|
|
639
|
-
from .response_models import LLMToolCall, LLMToolCallResult
|
|
640
|
-
|
|
641
421
|
async with _global_llm_semaphore:
|
|
642
|
-
|
|
422
|
+
# Delegate to provider implementation
|
|
423
|
+
result = await self._provider_impl.call_with_tools(
|
|
424
|
+
messages=messages,
|
|
425
|
+
tools=tools,
|
|
426
|
+
max_completion_tokens=max_completion_tokens,
|
|
427
|
+
temperature=temperature,
|
|
428
|
+
scope=scope,
|
|
429
|
+
max_retries=max_retries,
|
|
430
|
+
initial_backoff=initial_backoff,
|
|
431
|
+
max_backoff=max_backoff,
|
|
432
|
+
tool_choice=tool_choice,
|
|
433
|
+
)
|
|
643
434
|
|
|
644
|
-
#
|
|
435
|
+
# Backward compatibility: Update mock call tracking for mock provider
|
|
436
|
+
# This allows existing tests using LLMProvider._mock_calls to continue working
|
|
645
437
|
if self.provider == "mock":
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
# Handle Anthropic separately (uses different tool format)
|
|
649
|
-
if self.provider == "anthropic":
|
|
650
|
-
return await self._call_with_tools_anthropic(
|
|
651
|
-
messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
|
|
652
|
-
)
|
|
653
|
-
|
|
654
|
-
# Handle Gemini and Vertex AI (convert to Gemini tool format)
|
|
655
|
-
if self.provider in ("gemini", "vertexai"):
|
|
656
|
-
return await self._call_with_tools_gemini(
|
|
657
|
-
messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
|
|
658
|
-
)
|
|
438
|
+
from .providers.mock_llm import MockLLM
|
|
659
439
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
"messages": messages,
|
|
664
|
-
"tools": tools,
|
|
665
|
-
"tool_choice": tool_choice,
|
|
666
|
-
}
|
|
440
|
+
if isinstance(self._provider_impl, MockLLM):
|
|
441
|
+
# Sync the mock calls from provider implementation to wrapper
|
|
442
|
+
self._mock_calls = self._provider_impl.get_mock_calls()
|
|
667
443
|
|
|
668
|
-
|
|
669
|
-
call_params["max_completion_tokens"] = max_completion_tokens
|
|
670
|
-
if temperature is not None:
|
|
671
|
-
call_params["temperature"] = temperature
|
|
444
|
+
return result
|
|
672
445
|
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
for attempt in range(max_retries + 1):
|
|
680
|
-
try:
|
|
681
|
-
response = await self._client.chat.completions.create(**call_params)
|
|
682
|
-
|
|
683
|
-
message = response.choices[0].message
|
|
684
|
-
finish_reason = response.choices[0].finish_reason
|
|
685
|
-
|
|
686
|
-
# Extract tool calls if present
|
|
687
|
-
tool_calls: list[LLMToolCall] = []
|
|
688
|
-
if message.tool_calls:
|
|
689
|
-
for tc in message.tool_calls:
|
|
690
|
-
try:
|
|
691
|
-
args = json.loads(tc.function.arguments) if tc.function.arguments else {}
|
|
692
|
-
except json.JSONDecodeError:
|
|
693
|
-
args = {"_raw": tc.function.arguments}
|
|
694
|
-
tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
|
|
695
|
-
|
|
696
|
-
content = message.content
|
|
697
|
-
|
|
698
|
-
# Record metrics
|
|
699
|
-
duration = time.time() - start_time
|
|
700
|
-
usage = response.usage
|
|
701
|
-
input_tokens = usage.prompt_tokens or 0 if usage else 0
|
|
702
|
-
output_tokens = usage.completion_tokens or 0 if usage else 0
|
|
703
|
-
|
|
704
|
-
metrics = get_metrics_collector()
|
|
705
|
-
metrics.record_llm_call(
|
|
706
|
-
provider=self.provider,
|
|
707
|
-
model=self.model,
|
|
708
|
-
scope=scope,
|
|
709
|
-
duration=duration,
|
|
710
|
-
input_tokens=input_tokens,
|
|
711
|
-
output_tokens=output_tokens,
|
|
712
|
-
success=True,
|
|
713
|
-
)
|
|
446
|
+
def set_mock_response(self, response: Any) -> None:
|
|
447
|
+
"""Set the response to return from mock calls."""
|
|
448
|
+
# Backward compatibility: Store in both wrapper and provider implementation
|
|
449
|
+
self._mock_response = response
|
|
450
|
+
if self.provider == "mock":
|
|
451
|
+
from .providers.mock_llm import MockLLM
|
|
714
452
|
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
tool_calls=tool_calls,
|
|
718
|
-
finish_reason=finish_reason,
|
|
719
|
-
input_tokens=input_tokens,
|
|
720
|
-
output_tokens=output_tokens,
|
|
721
|
-
)
|
|
453
|
+
if isinstance(self._provider_impl, MockLLM):
|
|
454
|
+
self._provider_impl.set_mock_response(response)
|
|
722
455
|
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
raise
|
|
729
|
-
|
|
730
|
-
except APIStatusError as e:
|
|
731
|
-
if e.status_code in (401, 403):
|
|
732
|
-
raise
|
|
733
|
-
last_exception = e
|
|
734
|
-
if attempt < max_retries:
|
|
735
|
-
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
|
|
736
|
-
continue
|
|
737
|
-
raise
|
|
738
|
-
|
|
739
|
-
except Exception:
|
|
740
|
-
raise
|
|
741
|
-
|
|
742
|
-
if last_exception:
|
|
743
|
-
raise last_exception
|
|
744
|
-
raise RuntimeError("Tool call failed after all retries")
|
|
745
|
-
|
|
746
|
-
async def _call_with_tools_mock(
|
|
747
|
-
self,
|
|
748
|
-
messages: list[dict[str, Any]],
|
|
749
|
-
tools: list[dict[str, Any]],
|
|
750
|
-
scope: str,
|
|
751
|
-
) -> "LLMToolCallResult":
|
|
752
|
-
"""Handle mock tool calls for testing."""
|
|
753
|
-
from .response_models import LLMToolCallResult
|
|
754
|
-
|
|
755
|
-
call_record = {
|
|
756
|
-
"provider": self.provider,
|
|
757
|
-
"model": self.model,
|
|
758
|
-
"messages": messages,
|
|
759
|
-
"tools": [t.get("function", {}).get("name") for t in tools],
|
|
760
|
-
"scope": scope,
|
|
761
|
-
}
|
|
762
|
-
self._mock_calls.append(call_record)
|
|
763
|
-
|
|
764
|
-
if self._mock_response is not None:
|
|
765
|
-
if isinstance(self._mock_response, LLMToolCallResult):
|
|
766
|
-
return self._mock_response
|
|
767
|
-
# Allow setting just tool calls as a list
|
|
768
|
-
if isinstance(self._mock_response, list):
|
|
769
|
-
from .response_models import LLMToolCall
|
|
770
|
-
|
|
771
|
-
return LLMToolCallResult(
|
|
772
|
-
tool_calls=[
|
|
773
|
-
LLMToolCall(id=f"mock_{i}", name=tc["name"], arguments=tc.get("arguments", {}))
|
|
774
|
-
for i, tc in enumerate(self._mock_response)
|
|
775
|
-
],
|
|
776
|
-
finish_reason="tool_calls",
|
|
777
|
-
)
|
|
456
|
+
def get_mock_calls(self) -> list[dict]:
|
|
457
|
+
"""Get the list of recorded mock calls."""
|
|
458
|
+
# Backward compatibility: Read from provider implementation if mock provider
|
|
459
|
+
if self.provider == "mock":
|
|
460
|
+
from .providers.mock_llm import MockLLM
|
|
778
461
|
|
|
779
|
-
|
|
462
|
+
if isinstance(self._provider_impl, MockLLM):
|
|
463
|
+
return self._provider_impl.get_mock_calls()
|
|
464
|
+
return self._mock_calls
|
|
780
465
|
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
initial_backoff: float,
|
|
788
|
-
max_backoff: float,
|
|
789
|
-
start_time: float,
|
|
790
|
-
scope: str,
|
|
791
|
-
) -> "LLMToolCallResult":
|
|
792
|
-
"""Handle Anthropic tool calling."""
|
|
793
|
-
from anthropic import APIConnectionError, APIStatusError
|
|
794
|
-
|
|
795
|
-
from .response_models import LLMToolCall, LLMToolCallResult
|
|
796
|
-
|
|
797
|
-
# Convert OpenAI tool format to Anthropic format
|
|
798
|
-
anthropic_tools = []
|
|
799
|
-
for tool in tools:
|
|
800
|
-
func = tool.get("function", {})
|
|
801
|
-
anthropic_tools.append(
|
|
802
|
-
{
|
|
803
|
-
"name": func.get("name", ""),
|
|
804
|
-
"description": func.get("description", ""),
|
|
805
|
-
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
|
806
|
-
}
|
|
807
|
-
)
|
|
466
|
+
def clear_mock_calls(self) -> None:
|
|
467
|
+
"""Clear the recorded mock calls."""
|
|
468
|
+
# Backward compatibility: Clear in both wrapper and provider implementation
|
|
469
|
+
self._mock_calls = []
|
|
470
|
+
if self.provider == "mock":
|
|
471
|
+
from .providers.mock_llm import MockLLM
|
|
808
472
|
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
anthropic_messages = []
|
|
812
|
-
for msg in messages:
|
|
813
|
-
role = msg.get("role", "user")
|
|
814
|
-
content = msg.get("content", "")
|
|
815
|
-
|
|
816
|
-
if role == "system":
|
|
817
|
-
system_prompt = (system_prompt + "\n\n" + content) if system_prompt else content
|
|
818
|
-
elif role == "tool":
|
|
819
|
-
# Anthropic uses tool_result blocks
|
|
820
|
-
anthropic_messages.append(
|
|
821
|
-
{
|
|
822
|
-
"role": "user",
|
|
823
|
-
"content": [
|
|
824
|
-
{"type": "tool_result", "tool_use_id": msg.get("tool_call_id", ""), "content": content}
|
|
825
|
-
],
|
|
826
|
-
}
|
|
827
|
-
)
|
|
828
|
-
elif role == "assistant" and msg.get("tool_calls"):
|
|
829
|
-
# Convert assistant tool calls
|
|
830
|
-
tool_use_blocks = []
|
|
831
|
-
for tc in msg["tool_calls"]:
|
|
832
|
-
tool_use_blocks.append(
|
|
833
|
-
{
|
|
834
|
-
"type": "tool_use",
|
|
835
|
-
"id": tc.get("id", ""),
|
|
836
|
-
"name": tc.get("function", {}).get("name", ""),
|
|
837
|
-
"input": json.loads(tc.get("function", {}).get("arguments", "{}")),
|
|
838
|
-
}
|
|
839
|
-
)
|
|
840
|
-
anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
|
|
841
|
-
else:
|
|
842
|
-
anthropic_messages.append({"role": role, "content": content})
|
|
843
|
-
|
|
844
|
-
call_params: dict[str, Any] = {
|
|
845
|
-
"model": self.model,
|
|
846
|
-
"messages": anthropic_messages,
|
|
847
|
-
"tools": anthropic_tools,
|
|
848
|
-
"max_tokens": max_completion_tokens or 4096,
|
|
849
|
-
}
|
|
850
|
-
if system_prompt:
|
|
851
|
-
call_params["system"] = system_prompt
|
|
852
|
-
|
|
853
|
-
last_exception = None
|
|
854
|
-
for attempt in range(max_retries + 1):
|
|
855
|
-
try:
|
|
856
|
-
response = await self._anthropic_client.messages.create(**call_params)
|
|
857
|
-
|
|
858
|
-
# Extract content and tool calls
|
|
859
|
-
content_parts = []
|
|
860
|
-
tool_calls: list[LLMToolCall] = []
|
|
861
|
-
|
|
862
|
-
for block in response.content:
|
|
863
|
-
if block.type == "text":
|
|
864
|
-
content_parts.append(block.text)
|
|
865
|
-
elif block.type == "tool_use":
|
|
866
|
-
tool_calls.append(LLMToolCall(id=block.id, name=block.name, arguments=block.input or {}))
|
|
867
|
-
|
|
868
|
-
content = "".join(content_parts) if content_parts else None
|
|
869
|
-
finish_reason = "tool_calls" if tool_calls else "stop"
|
|
870
|
-
|
|
871
|
-
# Extract token usage
|
|
872
|
-
input_tokens = response.usage.input_tokens or 0
|
|
873
|
-
output_tokens = response.usage.output_tokens or 0
|
|
874
|
-
|
|
875
|
-
# Record metrics
|
|
876
|
-
metrics = get_metrics_collector()
|
|
877
|
-
metrics.record_llm_call(
|
|
878
|
-
provider=self.provider,
|
|
879
|
-
model=self.model,
|
|
880
|
-
scope=scope,
|
|
881
|
-
duration=time.time() - start_time,
|
|
882
|
-
input_tokens=input_tokens,
|
|
883
|
-
output_tokens=output_tokens,
|
|
884
|
-
success=True,
|
|
885
|
-
)
|
|
473
|
+
if isinstance(self._provider_impl, MockLLM):
|
|
474
|
+
self._provider_impl.clear_mock_calls()
|
|
886
475
|
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
finish_reason=finish_reason,
|
|
891
|
-
input_tokens=input_tokens,
|
|
892
|
-
output_tokens=output_tokens,
|
|
893
|
-
)
|
|
476
|
+
def _load_codex_auth(self) -> tuple[str, str]:
|
|
477
|
+
"""
|
|
478
|
+
Load OAuth credentials from ~/.codex/auth.json.
|
|
894
479
|
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
raise
|
|
898
|
-
last_exception = e
|
|
899
|
-
if attempt < max_retries:
|
|
900
|
-
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
|
|
901
|
-
continue
|
|
902
|
-
raise
|
|
480
|
+
Returns:
|
|
481
|
+
Tuple of (access_token, account_id).
|
|
903
482
|
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
483
|
+
Raises:
|
|
484
|
+
FileNotFoundError: If auth file doesn't exist.
|
|
485
|
+
ValueError: If auth file is invalid.
|
|
486
|
+
"""
|
|
487
|
+
auth_file = Path.home() / ".codex" / "auth.json"
|
|
907
488
|
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
tools: list[dict[str, Any]],
|
|
912
|
-
max_retries: int,
|
|
913
|
-
initial_backoff: float,
|
|
914
|
-
max_backoff: float,
|
|
915
|
-
start_time: float,
|
|
916
|
-
scope: str,
|
|
917
|
-
) -> "LLMToolCallResult":
|
|
918
|
-
"""Handle Gemini tool calling."""
|
|
919
|
-
from .response_models import LLMToolCall, LLMToolCallResult
|
|
920
|
-
|
|
921
|
-
# Convert tools to Gemini format
|
|
922
|
-
gemini_tools = []
|
|
923
|
-
for tool in tools:
|
|
924
|
-
func = tool.get("function", {})
|
|
925
|
-
gemini_tools.append(
|
|
926
|
-
genai_types.Tool(
|
|
927
|
-
function_declarations=[
|
|
928
|
-
genai_types.FunctionDeclaration(
|
|
929
|
-
name=func.get("name", ""),
|
|
930
|
-
description=func.get("description", ""),
|
|
931
|
-
parameters=func.get("parameters"),
|
|
932
|
-
)
|
|
933
|
-
]
|
|
934
|
-
)
|
|
489
|
+
if not auth_file.exists():
|
|
490
|
+
raise FileNotFoundError(
|
|
491
|
+
f"Codex auth file not found: {auth_file}\nRun 'codex auth login' to authenticate with ChatGPT Plus/Pro."
|
|
935
492
|
)
|
|
936
493
|
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
gemini_contents = []
|
|
940
|
-
for msg in messages:
|
|
941
|
-
role = msg.get("role", "user")
|
|
942
|
-
content = msg.get("content", "")
|
|
943
|
-
|
|
944
|
-
if role == "system":
|
|
945
|
-
system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
|
|
946
|
-
elif role == "tool":
|
|
947
|
-
# Gemini uses function_response
|
|
948
|
-
gemini_contents.append(
|
|
949
|
-
genai_types.Content(
|
|
950
|
-
role="user",
|
|
951
|
-
parts=[
|
|
952
|
-
genai_types.Part(
|
|
953
|
-
function_response=genai_types.FunctionResponse(
|
|
954
|
-
name=msg.get("name", ""),
|
|
955
|
-
response={"result": content},
|
|
956
|
-
)
|
|
957
|
-
)
|
|
958
|
-
],
|
|
959
|
-
)
|
|
960
|
-
)
|
|
961
|
-
elif role == "assistant":
|
|
962
|
-
gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
|
|
963
|
-
else:
|
|
964
|
-
gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
|
|
965
|
-
|
|
966
|
-
config = genai_types.GenerateContentConfig(
|
|
967
|
-
system_instruction=system_instruction,
|
|
968
|
-
tools=gemini_tools,
|
|
969
|
-
)
|
|
970
|
-
|
|
971
|
-
last_exception = None
|
|
972
|
-
for attempt in range(max_retries + 1):
|
|
973
|
-
try:
|
|
974
|
-
response = await self._gemini_client.aio.models.generate_content(
|
|
975
|
-
model=self.model,
|
|
976
|
-
contents=gemini_contents,
|
|
977
|
-
config=config,
|
|
978
|
-
)
|
|
979
|
-
|
|
980
|
-
# Extract content and tool calls
|
|
981
|
-
content = None
|
|
982
|
-
tool_calls: list[LLMToolCall] = []
|
|
983
|
-
|
|
984
|
-
if response.candidates and response.candidates[0].content:
|
|
985
|
-
parts = response.candidates[0].content.parts
|
|
986
|
-
if parts:
|
|
987
|
-
for part in parts:
|
|
988
|
-
if hasattr(part, "text") and part.text:
|
|
989
|
-
content = part.text
|
|
990
|
-
if hasattr(part, "function_call") and part.function_call:
|
|
991
|
-
fc = part.function_call
|
|
992
|
-
tool_calls.append(
|
|
993
|
-
LLMToolCall(
|
|
994
|
-
id=f"gemini_{len(tool_calls)}",
|
|
995
|
-
name=fc.name,
|
|
996
|
-
arguments=dict(fc.args) if fc.args else {},
|
|
997
|
-
)
|
|
998
|
-
)
|
|
999
|
-
|
|
1000
|
-
finish_reason = "tool_calls" if tool_calls else "stop"
|
|
1001
|
-
|
|
1002
|
-
# Record metrics
|
|
1003
|
-
metrics = get_metrics_collector()
|
|
1004
|
-
input_tokens = response.usage_metadata.prompt_token_count if response.usage_metadata else 0
|
|
1005
|
-
output_tokens = response.usage_metadata.candidates_token_count if response.usage_metadata else 0
|
|
1006
|
-
metrics.record_llm_call(
|
|
1007
|
-
provider=self.provider,
|
|
1008
|
-
model=self.model,
|
|
1009
|
-
scope=scope,
|
|
1010
|
-
duration=time.time() - start_time,
|
|
1011
|
-
input_tokens=input_tokens,
|
|
1012
|
-
output_tokens=output_tokens,
|
|
1013
|
-
success=True,
|
|
1014
|
-
)
|
|
1015
|
-
|
|
1016
|
-
return LLMToolCallResult(
|
|
1017
|
-
content=content,
|
|
1018
|
-
tool_calls=tool_calls,
|
|
1019
|
-
finish_reason=finish_reason,
|
|
1020
|
-
input_tokens=input_tokens,
|
|
1021
|
-
output_tokens=output_tokens,
|
|
1022
|
-
)
|
|
1023
|
-
|
|
1024
|
-
except genai_errors.APIError as e:
|
|
1025
|
-
if e.code in (401, 403):
|
|
1026
|
-
raise
|
|
1027
|
-
last_exception = e
|
|
1028
|
-
if attempt < max_retries:
|
|
1029
|
-
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
|
|
1030
|
-
continue
|
|
1031
|
-
raise
|
|
1032
|
-
|
|
1033
|
-
if last_exception:
|
|
1034
|
-
raise last_exception
|
|
1035
|
-
raise RuntimeError("Gemini tool call failed")
|
|
1036
|
-
|
|
1037
|
-
async def _call_anthropic(
|
|
1038
|
-
self,
|
|
1039
|
-
messages: list[dict[str, str]],
|
|
1040
|
-
response_format: Any | None,
|
|
1041
|
-
max_completion_tokens: int | None,
|
|
1042
|
-
max_retries: int,
|
|
1043
|
-
initial_backoff: float,
|
|
1044
|
-
max_backoff: float,
|
|
1045
|
-
skip_validation: bool,
|
|
1046
|
-
start_time: float,
|
|
1047
|
-
scope: str = "memory",
|
|
1048
|
-
return_usage: bool = False,
|
|
1049
|
-
semaphore_wait_time: float = 0.0,
|
|
1050
|
-
) -> Any:
|
|
1051
|
-
"""Handle Anthropic-specific API calls."""
|
|
1052
|
-
from anthropic import APIConnectionError, APIStatusError, RateLimitError
|
|
1053
|
-
|
|
1054
|
-
# Convert OpenAI-style messages to Anthropic format
|
|
1055
|
-
system_prompt = None
|
|
1056
|
-
anthropic_messages = []
|
|
1057
|
-
|
|
1058
|
-
for msg in messages:
|
|
1059
|
-
role = msg.get("role", "user")
|
|
1060
|
-
content = msg.get("content", "")
|
|
1061
|
-
|
|
1062
|
-
if role == "system":
|
|
1063
|
-
if system_prompt:
|
|
1064
|
-
system_prompt += "\n\n" + content
|
|
1065
|
-
else:
|
|
1066
|
-
system_prompt = content
|
|
1067
|
-
else:
|
|
1068
|
-
anthropic_messages.append({"role": role, "content": content})
|
|
1069
|
-
|
|
1070
|
-
# Add JSON schema instruction if response_format is provided
|
|
1071
|
-
if response_format is not None and hasattr(response_format, "model_json_schema"):
|
|
1072
|
-
schema = response_format.model_json_schema()
|
|
1073
|
-
schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
|
|
1074
|
-
if system_prompt:
|
|
1075
|
-
system_prompt += schema_msg
|
|
1076
|
-
else:
|
|
1077
|
-
system_prompt = schema_msg
|
|
1078
|
-
|
|
1079
|
-
# Prepare parameters
|
|
1080
|
-
call_params = {
|
|
1081
|
-
"model": self.model,
|
|
1082
|
-
"messages": anthropic_messages,
|
|
1083
|
-
"max_tokens": max_completion_tokens if max_completion_tokens is not None else 4096,
|
|
1084
|
-
}
|
|
1085
|
-
|
|
1086
|
-
if system_prompt:
|
|
1087
|
-
call_params["system"] = system_prompt
|
|
1088
|
-
|
|
1089
|
-
last_exception = None
|
|
1090
|
-
|
|
1091
|
-
for attempt in range(max_retries + 1):
|
|
1092
|
-
try:
|
|
1093
|
-
response = await self._anthropic_client.messages.create(**call_params)
|
|
1094
|
-
|
|
1095
|
-
# Anthropic response content is a list of blocks
|
|
1096
|
-
content = ""
|
|
1097
|
-
for block in response.content:
|
|
1098
|
-
if block.type == "text":
|
|
1099
|
-
content += block.text
|
|
1100
|
-
|
|
1101
|
-
if response_format is not None:
|
|
1102
|
-
# Models may wrap JSON in markdown code blocks
|
|
1103
|
-
clean_content = content
|
|
1104
|
-
if "```json" in content:
|
|
1105
|
-
clean_content = content.split("```json")[1].split("```")[0].strip()
|
|
1106
|
-
elif "```" in content:
|
|
1107
|
-
clean_content = content.split("```")[1].split("```")[0].strip()
|
|
1108
|
-
|
|
1109
|
-
try:
|
|
1110
|
-
json_data = json.loads(clean_content)
|
|
1111
|
-
except json.JSONDecodeError:
|
|
1112
|
-
# Fallback to parsing raw content if markdown stripping failed
|
|
1113
|
-
json_data = json.loads(content)
|
|
1114
|
-
|
|
1115
|
-
if skip_validation:
|
|
1116
|
-
result = json_data
|
|
1117
|
-
else:
|
|
1118
|
-
result = response_format.model_validate(json_data)
|
|
1119
|
-
else:
|
|
1120
|
-
result = content
|
|
1121
|
-
|
|
1122
|
-
# Record metrics and log slow calls
|
|
1123
|
-
duration = time.time() - start_time
|
|
1124
|
-
input_tokens = response.usage.input_tokens or 0 if response.usage else 0
|
|
1125
|
-
output_tokens = response.usage.output_tokens or 0 if response.usage else 0
|
|
1126
|
-
total_tokens = input_tokens + output_tokens
|
|
1127
|
-
|
|
1128
|
-
# Record LLM metrics
|
|
1129
|
-
metrics = get_metrics_collector()
|
|
1130
|
-
metrics.record_llm_call(
|
|
1131
|
-
provider=self.provider,
|
|
1132
|
-
model=self.model,
|
|
1133
|
-
scope=scope,
|
|
1134
|
-
duration=duration,
|
|
1135
|
-
input_tokens=input_tokens,
|
|
1136
|
-
output_tokens=output_tokens,
|
|
1137
|
-
success=True,
|
|
1138
|
-
)
|
|
1139
|
-
|
|
1140
|
-
# Log slow calls
|
|
1141
|
-
if duration > 10.0:
|
|
1142
|
-
wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
|
|
1143
|
-
logger.info(
|
|
1144
|
-
f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
|
|
1145
|
-
f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
|
|
1146
|
-
f"time={duration:.3f}s{wait_info}"
|
|
1147
|
-
)
|
|
1148
|
-
|
|
1149
|
-
if return_usage:
|
|
1150
|
-
token_usage = TokenUsage(
|
|
1151
|
-
input_tokens=input_tokens,
|
|
1152
|
-
output_tokens=output_tokens,
|
|
1153
|
-
total_tokens=total_tokens,
|
|
1154
|
-
)
|
|
1155
|
-
return result, token_usage
|
|
1156
|
-
return result
|
|
1157
|
-
|
|
1158
|
-
except json.JSONDecodeError as e:
|
|
1159
|
-
last_exception = e
|
|
1160
|
-
if attempt < max_retries:
|
|
1161
|
-
logger.warning("Anthropic returned invalid JSON, retrying...")
|
|
1162
|
-
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
1163
|
-
await asyncio.sleep(backoff)
|
|
1164
|
-
continue
|
|
1165
|
-
else:
|
|
1166
|
-
logger.error(f"Anthropic returned invalid JSON after {max_retries + 1} attempts")
|
|
1167
|
-
raise
|
|
1168
|
-
|
|
1169
|
-
except (APIConnectionError, RateLimitError, APIStatusError) as e:
|
|
1170
|
-
# Fast fail on 401/403
|
|
1171
|
-
if isinstance(e, APIStatusError) and e.status_code in (401, 403):
|
|
1172
|
-
logger.error(f"Anthropic auth error (HTTP {e.status_code}), not retrying: {str(e)}")
|
|
1173
|
-
raise
|
|
1174
|
-
|
|
1175
|
-
last_exception = e
|
|
1176
|
-
if attempt < max_retries:
|
|
1177
|
-
# Check if it's a rate limit or server error
|
|
1178
|
-
should_retry = isinstance(e, (APIConnectionError, RateLimitError)) or (
|
|
1179
|
-
isinstance(e, APIStatusError) and e.status_code >= 500
|
|
1180
|
-
)
|
|
1181
|
-
|
|
1182
|
-
if should_retry:
|
|
1183
|
-
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
1184
|
-
jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
|
|
1185
|
-
await asyncio.sleep(backoff + jitter)
|
|
1186
|
-
continue
|
|
494
|
+
with open(auth_file) as f:
|
|
495
|
+
data = json.load(f)
|
|
1187
496
|
|
|
1188
|
-
|
|
1189
|
-
|
|
497
|
+
# Validate auth structure
|
|
498
|
+
auth_mode = data.get("auth_mode")
|
|
499
|
+
if auth_mode != "chatgpt":
|
|
500
|
+
raise ValueError(f"Expected auth_mode='chatgpt', got: {auth_mode}")
|
|
1190
501
|
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
502
|
+
tokens = data.get("tokens", {})
|
|
503
|
+
access_token = tokens.get("access_token")
|
|
504
|
+
account_id = tokens.get("account_id")
|
|
1194
505
|
|
|
1195
|
-
if
|
|
1196
|
-
raise
|
|
1197
|
-
raise RuntimeError("Anthropic call failed after all retries")
|
|
506
|
+
if not access_token:
|
+            raise ValueError("No access_token found in Codex auth file. Run 'codex auth login' again.")

-
-        self,
-        messages: list[dict[str, str]],
-        response_format: Any,
-        max_completion_tokens: int | None,
-        temperature: float | None,
-        max_retries: int,
-        initial_backoff: float,
-        max_backoff: float,
-        skip_validation: bool,
-        start_time: float,
-        scope: str = "memory",
-        return_usage: bool = False,
-        semaphore_wait_time: float = 0.0,
-    ) -> Any:
-        """
-        Call Ollama using native API with JSON schema enforcement.
+        return access_token, account_id

-
-        which provides better structured output control than the OpenAI-compatible API.
+    def _verify_claude_code_available(self) -> None:
         """
-
-        schema = response_format.model_json_schema() if hasattr(response_format, "model_json_schema") else None
-
-        # Build the base URL for Ollama's native API
-        # Default OpenAI-compatible URL is http://localhost:11434/v1
-        # Native API is at http://localhost:11434/api/chat
-        base_url = self.base_url or "http://localhost:11434/v1"
-        if base_url.endswith("/v1"):
-            native_url = base_url[:-3] + "/api/chat"
-        else:
-            native_url = base_url.rstrip("/") + "/api/chat"
-
-        # Build request payload
-        payload = {
-            "model": self.model,
-            "messages": messages,
-            "stream": False,
-        }
-
-        # Add schema as format parameter for structured output
-        if schema:
-            payload["format"] = schema
-
-        # Add optional parameters with optimized defaults for Ollama
-        # Benchmarking shows num_ctx=16384 + num_batch=512 is optimal
-        options = {
-            "num_ctx": 16384,  # 16k context window for larger prompts
-            "num_batch": 512,  # Optimal batch size for prompt processing
-        }
-        if max_completion_tokens:
-            options["num_predict"] = max_completion_tokens
-        if temperature is not None:
-            options["temperature"] = temperature
-        payload["options"] = options
-
-        last_exception = None
-
-        async with httpx.AsyncClient(timeout=300.0) as client:
-            for attempt in range(max_retries + 1):
-                try:
-                    response = await client.post(native_url, json=payload)
-                    response.raise_for_status()
-
-                    result = response.json()
-                    content = result.get("message", {}).get("content", "")
-
-                    # Parse JSON response
-                    try:
-                        json_data = json.loads(content)
-                    except json.JSONDecodeError as json_err:
-                        content_preview = content[:500] if content else "<empty>"
-                        if content and len(content) > 700:
-                            content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
-                        logger.warning(
-                            f"Ollama JSON parse error (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
-                            f" Model: ollama/{self.model}\n"
-                            f" Content length: {len(content) if content else 0} chars\n"
-                            f" Content preview: {content_preview!r}"
-                        )
-                        if attempt < max_retries:
-                            backoff = min(initial_backoff * (2**attempt), max_backoff)
-                            await asyncio.sleep(backoff)
-                            last_exception = json_err
-                            continue
-                        else:
-                            raise
-
-                    # Extract token usage from Ollama response
-                    # Ollama returns prompt_eval_count (input) and eval_count (output)
-                    duration = time.time() - start_time
-                    input_tokens = result.get("prompt_eval_count", 0) or 0
-                    output_tokens = result.get("eval_count", 0) or 0
-                    total_tokens = input_tokens + output_tokens
-
-                    # Record LLM metrics
-                    metrics = get_metrics_collector()
-                    metrics.record_llm_call(
-                        provider=self.provider,
-                        model=self.model,
-                        scope=scope,
-                        duration=duration,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        success=True,
-                    )
-
-                    # Validate against Pydantic model or return raw JSON
-                    if skip_validation:
-                        validated_result = json_data
-                    else:
-                        validated_result = response_format.model_validate(json_data)
-
-                    if return_usage:
-                        token_usage = TokenUsage(
-                            input_tokens=input_tokens,
-                            output_tokens=output_tokens,
-                            total_tokens=total_tokens,
-                        )
-                        return validated_result, token_usage
-                    return validated_result
-
-                except httpx.HTTPStatusError as e:
-                    last_exception = e
-                    if attempt < max_retries:
-                        logger.warning(
-                            f"Ollama HTTP error (attempt {attempt + 1}/{max_retries + 1}): {e.response.status_code}"
-                        )
-                        backoff = min(initial_backoff * (2**attempt), max_backoff)
-                        await asyncio.sleep(backoff)
-                        continue
-                    else:
-                        logger.error(f"Ollama HTTP error after {max_retries + 1} attempts: {e}")
-                        raise
-
-                except httpx.RequestError as e:
-                    last_exception = e
-                    if attempt < max_retries:
-                        logger.warning(f"Ollama connection error (attempt {attempt + 1}/{max_retries + 1}): {e}")
-                        backoff = min(initial_backoff * (2**attempt), max_backoff)
-                        await asyncio.sleep(backoff)
-                        continue
-                    else:
-                        logger.error(f"Ollama connection error after {max_retries + 1} attempts: {e}")
-                        raise
-
-                except Exception as e:
-                    logger.error(f"Unexpected error during Ollama call: {type(e).__name__}: {e}")
-                    raise
-
-        if last_exception:
-            raise last_exception
-        raise RuntimeError("Ollama call failed after all retries")
-
-    async def _call_gemini(
-        self,
-        messages: list[dict[str, str]],
-        response_format: Any | None,
-        max_retries: int,
-        initial_backoff: float,
-        max_backoff: float,
-        skip_validation: bool,
-        start_time: float,
-        scope: str = "memory",
-        return_usage: bool = False,
-        semaphore_wait_time: float = 0.0,
-    ) -> Any:
-        """Handle Gemini-specific API calls."""
-        # Convert OpenAI-style messages to Gemini format
-        system_instruction = None
-        gemini_contents = []
-
-        for msg in messages:
-            role = msg.get("role", "user")
-            content = msg.get("content", "")
-
-            if role == "system":
-                if system_instruction:
-                    system_instruction += "\n\n" + content
-                else:
-                    system_instruction = content
-            elif role == "assistant":
-                gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
-            else:
-                gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
-
-        # Add JSON schema instruction if response_format is provided
-        if response_format is not None and hasattr(response_format, "model_json_schema"):
-            schema = response_format.model_json_schema()
-            schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
-            if system_instruction:
-                system_instruction += schema_msg
-            else:
-                system_instruction = schema_msg
-
-        # Build generation config
-        config_kwargs = {}
-        if system_instruction:
-            config_kwargs["system_instruction"] = system_instruction
-        if response_format is not None:
-            config_kwargs["response_mime_type"] = "application/json"
-            config_kwargs["response_schema"] = response_format
-
-        generation_config = genai_types.GenerateContentConfig(**config_kwargs) if config_kwargs else None
-
-        last_exception = None
-
-        for attempt in range(max_retries + 1):
-            try:
-                response = await self._gemini_client.aio.models.generate_content(
-                    model=self.model,
-                    contents=gemini_contents,
-                    config=generation_config,
-                )
-
-                content = response.text
-
-                # Handle empty response
-                if content is None:
-                    block_reason = None
-                    if hasattr(response, "candidates") and response.candidates:
-                        candidate = response.candidates[0]
-                        if hasattr(candidate, "finish_reason"):
-                            block_reason = candidate.finish_reason
-
-                    if attempt < max_retries:
-                        logger.warning(f"Gemini returned empty response (reason: {block_reason}), retrying...")
-                        backoff = min(initial_backoff * (2**attempt), max_backoff)
-                        await asyncio.sleep(backoff)
-                        continue
-                    else:
-                        raise RuntimeError(f"Gemini returned empty response after {max_retries + 1} attempts")
-
-                if response_format is not None:
-                    json_data = json.loads(content)
-                    if skip_validation:
-                        result = json_data
-                    else:
-                        result = response_format.model_validate(json_data)
-                else:
-                    result = content
-
-                # Record metrics and log slow calls
-                duration = time.time() - start_time
-                input_tokens = 0
-                output_tokens = 0
-                if hasattr(response, "usage_metadata") and response.usage_metadata:
-                    usage = response.usage_metadata
-                    input_tokens = usage.prompt_token_count or 0
-                    output_tokens = usage.candidates_token_count or 0
-
-                # Record LLM metrics
-                metrics = get_metrics_collector()
-                metrics.record_llm_call(
-                    provider=self.provider,
-                    model=self.model,
-                    scope=scope,
-                    duration=duration,
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    success=True,
-                )
-
-                # Log slow calls
-                if duration > 10.0 and input_tokens > 0:
-                    wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
-                    logger.info(
-                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
-                        f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
-                        f"time={duration:.3f}s{wait_info}"
-                    )
+        Verify that Claude Agent SDK can be imported and is properly configured.

-
-
-
-                        output_tokens=output_tokens,
-                        total_tokens=input_tokens + output_tokens,
-                    )
-                    return result, token_usage
-                return result
-
-            except json.JSONDecodeError as e:
-                last_exception = e
-                if attempt < max_retries:
-                    logger.warning("Gemini returned invalid JSON, retrying...")
-                    backoff = min(initial_backoff * (2**attempt), max_backoff)
-                    await asyncio.sleep(backoff)
-                    continue
-                else:
-                    logger.error(f"Gemini returned invalid JSON after {max_retries + 1} attempts")
-                    raise
-
-            except genai_errors.APIError as e:
-                # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
-                if e.code in (401, 403):
-                    logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
-                    raise
-
-                # Retry on retryable errors (rate limits, server errors, and other client errors like 400)
-                if e.code in (400, 429, 500, 502, 503, 504) or (e.code and e.code >= 500):
-                    last_exception = e
-                    if attempt < max_retries:
-                        backoff = min(initial_backoff * (2**attempt), max_backoff)
-                        jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
-                        await asyncio.sleep(backoff + jitter)
-                    else:
-                        logger.error(f"Gemini API error after {max_retries + 1} attempts: {str(e)}")
-                        raise
-                else:
-                    logger.error(f"Gemini API error: {type(e).__name__}: {str(e)}")
-                    raise
-
-            except Exception as e:
-                logger.error(f"Unexpected error during Gemini call: {type(e).__name__}: {str(e)}")
-                raise
-
-        if last_exception:
-            raise last_exception
-        raise RuntimeError("Gemini call failed after all retries")
-
-    async def _call_mock(
-        self,
-        messages: list[dict[str, str]],
-        response_format: Any | None,
-        scope: str,
-        return_usage: bool,
-    ) -> Any:
+        Raises:
+            ImportError: If Claude Agent SDK is not installed.
+            RuntimeError: If Claude Code is not authenticated.
         """
-
+        try:
+            # Import Claude Agent SDK
+            # Reduce Claude Agent SDK logging verbosity
+            import logging as sdk_logging

-
-        """
-        # Record the call for test verification
-        call_record = {
-            "provider": self.provider,
-            "model": self.model,
-            "messages": messages,
-            "response_format": response_format.__name__
-            if response_format and hasattr(response_format, "__name__")
-            else str(response_format),
-            "scope": scope,
-        }
-        self._mock_calls.append(call_record)
-        logger.debug(f"Mock LLM call recorded: scope={scope}, model={self.model}")
-
-        # Return mock response
-        if self._mock_response is not None:
-            result = self._mock_response
-        elif response_format is not None:
-            # Try to create a minimal valid instance of the response format
-            try:
-                # For Pydantic models, try to create with minimal valid data
-                result = {"mock": True}
-            except Exception:
-                result = {"mock": True}
-        else:
-            result = "mock response"
-
-        if return_usage:
-            token_usage = TokenUsage(input_tokens=10, output_tokens=5, total_tokens=15)
-            return result, token_usage
-        return result
+            from claude_agent_sdk import query  # noqa: F401

-
-
-        self._mock_response = response
+            sdk_logging.getLogger("claude_agent_sdk").setLevel(sdk_logging.WARNING)
+            sdk_logging.getLogger("claude_agent_sdk._internal").setLevel(sdk_logging.WARNING)

-
-
-
+            logger.debug("Claude Agent SDK imported successfully")
+        except ImportError as e:
+            raise ImportError(
+                "Claude Agent SDK not installed. Run: uv add claude-agent-sdk or pip install claude-agent-sdk"
+            ) from e

-
-
-        self._mock_calls = []
+        # SDK will automatically check for authentication when first used
+        # No need to verify here - let it fail gracefully on first call with helpful error

     async def cleanup(self) -> None:
         """Clean up resources."""
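The `_call_ollama_native` path removed above targets Ollama's native `/api/chat` endpoint and enforces structured output by passing a Pydantic JSON schema through the request's `format` field; in 0.4.8 this logic appears to move into the new `hindsight_api/engine/providers` package. A minimal standalone sketch of that request shape, assuming an illustrative `Fact` model, model name, and default localhost URL (none of these names come from the package):

```python
# Standalone sketch of Ollama's native structured-output call, mirroring the
# removed _call_ollama_native logic. Fact, the model name, and the URL are
# illustrative assumptions, not identifiers from hindsight-api.
import asyncio
import json

import httpx
from pydantic import BaseModel


class Fact(BaseModel):
    subject: str
    claim: str


async def call_ollama(messages: list[dict[str, str]]) -> Fact:
    payload = {
        "model": "llama3.1",                 # any locally pulled Ollama model
        "messages": messages,
        "stream": False,
        "format": Fact.model_json_schema(),  # JSON schema enforced by /api/chat
        "options": {"num_ctx": 16384, "num_batch": 512},
    }
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post("http://localhost:11434/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
    # The reply arrives as a JSON string under message.content;
    # prompt_eval_count / eval_count carry the token usage.
    return Fact.model_validate(json.loads(data["message"]["content"]))


if __name__ == "__main__":
    fact = asyncio.run(call_ollama([{"role": "user", "content": "State one fact about Paris."}]))
    print(fact)
```

Retry, backoff, and metrics recording around this call presumably now live in the provider classes listed in this release rather than in the wrapper shown here.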
@@ -1579,9 +543,14 @@ class LLMProvider:
     def for_memory(cls) -> "LLMProvider":
         """Create provider for memory operations from environment variables."""
         provider = os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq")
-        api_key = os.getenv("HINDSIGHT_API_LLM_API_KEY")
-
-
+        api_key = os.getenv("HINDSIGHT_API_LLM_API_KEY", "")
+
+        # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+        if not api_key and provider not in ("openai-codex", "claude-code"):
+            raise ValueError(
+                "HINDSIGHT_API_LLM_API_KEY environment variable is required (unless using openai-codex or claude-code)"
+            )
+
         base_url = os.getenv("HINDSIGHT_API_LLM_BASE_URL", "")
         model = os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b")

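The `for_memory()` hunk above changes the API-key requirement from unconditional to provider-aware. A small sketch of the resulting behaviour, reproducing only the guard (not the full factory) with illustrative environment values:

```python
# Sketch of the provider-aware API-key guard added in for_memory()
# (environment values below are illustrative).
import os

os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "claude-code"  # OAuth-based provider
os.environ.pop("HINDSIGHT_API_LLM_API_KEY", None)         # deliberately no key set

provider = os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq")
api_key = os.getenv("HINDSIGHT_API_LLM_API_KEY", "")

# openai-codex and claude-code authenticate via OAuth, so the key check is skipped for them
if not api_key and provider not in ("openai-codex", "claude-code"):
    raise ValueError(
        "HINDSIGHT_API_LLM_API_KEY environment variable is required (unless using openai-codex or claude-code)"
    )

print(f"provider={provider}, api_key set={bool(api_key)}")  # provider=claude-code, api_key set=False
```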
@@ -1591,11 +560,15 @@ class LLMProvider:
     def for_answer_generation(cls) -> "LLMProvider":
         """Create provider for answer generation. Falls back to memory config if not set."""
         provider = os.getenv("HINDSIGHT_API_ANSWER_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
-        api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY"))
-
+        api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY", ""))
+
+        # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+        if not api_key and provider not in ("openai-codex", "claude-code"):
             raise ValueError(
-                "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_ANSWER_LLM_API_KEY environment variable is required"
+                "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_ANSWER_LLM_API_KEY environment variable is required "
+                "(unless using openai-codex or claude-code)"
             )
+
         base_url = os.getenv("HINDSIGHT_API_ANSWER_LLM_BASE_URL", os.getenv("HINDSIGHT_API_LLM_BASE_URL", ""))
         model = os.getenv("HINDSIGHT_API_ANSWER_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))

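`for_answer_generation()` (and `for_judge()` below) keeps the existing fallback: each `HINDSIGHT_API_ANSWER_LLM_*` variable defers to the corresponding base `HINDSIGHT_API_LLM_*` value. A sketch of that resolution with illustrative values:

```python
# Sketch of the ANSWER_* -> base LLM_* fallback used by for_answer_generation()
# (values are illustrative; for_judge() follows the same pattern with JUDGE_*).
import os

os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "groq"
os.environ["HINDSIGHT_API_LLM_API_KEY"] = "example-key"         # placeholder, not a real key
os.environ["HINDSIGHT_API_ANSWER_LLM_MODEL"] = "llama-3.3-70b"  # override only the model

provider = os.getenv("HINDSIGHT_API_ANSWER_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY", ""))
model = os.getenv("HINDSIGHT_API_ANSWER_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))

print(provider, model, bool(api_key))  # -> groq llama-3.3-70b True
```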
@@ -1605,11 +578,15 @@ class LLMProvider:
     def for_judge(cls) -> "LLMProvider":
         """Create provider for judge/evaluator operations. Falls back to memory config if not set."""
         provider = os.getenv("HINDSIGHT_API_JUDGE_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
-        api_key = os.getenv("HINDSIGHT_API_JUDGE_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY"))
-
+        api_key = os.getenv("HINDSIGHT_API_JUDGE_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY", ""))
+
+        # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+        if not api_key and provider not in ("openai-codex", "claude-code"):
             raise ValueError(
-                "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_JUDGE_LLM_API_KEY environment variable is required"
+                "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_JUDGE_LLM_API_KEY environment variable is required "
+                "(unless using openai-codex or claude-code)"
             )
+
         base_url = os.getenv("HINDSIGHT_API_JUDGE_LLM_BASE_URL", os.getenv("HINDSIGHT_API_LLM_BASE_URL", ""))
         model = os.getenv("HINDSIGHT_API_JUDGE_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))
