hindsight-api 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,15 +8,14 @@ import logging
 import os
 import re
 import time
+import uuid
+from pathlib import Path
 from typing import Any
 
 import httpx
-from google import genai
-from google.genai import errors as genai_errors
-from google.genai import types as genai_types
 from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
 
-# Vertex AI imports (conditional)
+# Vertex AI imports (conditional - for LLMProvider to pass credentials to GeminiLLM)
 try:
     import google.auth
     from google.oauth2 import service_account
@@ -61,6 +60,108 @@ class OutputTooLongError(Exception):
     pass
 
 
+def create_llm_provider(
+    provider: str,
+    api_key: str,
+    base_url: str,
+    model: str,
+    reasoning_effort: str,
+    groq_service_tier: str | None = None,
+    vertexai_project_id: str | None = None,
+    vertexai_region: str | None = None,
+    vertexai_credentials: Any = None,
+) -> Any:  # Returns LLMInterface
+    """
+    Factory function to create the appropriate LLM provider implementation.
+
+    Args:
+        provider: Provider name ("openai", "groq", "ollama", "gemini", "anthropic", etc.).
+        api_key: API key (may be None for local providers or OAuth providers).
+        base_url: Base URL for the API.
+        model: Model name.
+        reasoning_effort: Reasoning effort level for supported providers.
+        groq_service_tier: Groq service tier (for Groq provider).
+        vertexai_project_id: Vertex AI project ID (for VertexAI provider).
+        vertexai_region: Vertex AI region (for VertexAI provider).
+        vertexai_credentials: Vertex AI credentials object (for VertexAI provider).
+
+    Returns:
+        LLMInterface implementation for the specified provider.
+    """
+    from .llm_interface import LLMInterface
+    from .providers import (
+        AnthropicLLM,
+        ClaudeCodeLLM,
+        CodexLLM,
+        GeminiLLM,
+        MockLLM,
+        OpenAICompatibleLLM,
+    )
+
+    provider_lower = provider.lower()
+
+    if provider_lower == "openai-codex":
+        return CodexLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+        )
+
+    elif provider_lower == "claude-code":
+        return ClaudeCodeLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+        )
+
+    elif provider_lower == "mock":
+        return MockLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+        )
+
+    elif provider_lower in ("gemini", "vertexai"):
+        return GeminiLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+            vertexai_project_id=vertexai_project_id,
+            vertexai_region=vertexai_region,
+            vertexai_credentials=vertexai_credentials,
+        )
+
+    elif provider_lower == "anthropic":
+        return AnthropicLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+        )
+
+    elif provider_lower in ("openai", "groq", "ollama", "lmstudio"):
+        return OpenAICompatibleLLM(
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+            model=model,
+            reasoning_effort=reasoning_effort,
+            groq_service_tier=groq_service_tier,
+        )
+
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
+
+
 class LLMProvider:
     """
     Unified LLM provider.
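
A minimal usage sketch of the new factory (not part of the package diff), based only on the signature above. The returned object's async surface is taken from the delegation code later in this diff; the empty api_key/base_url values for the mock provider are an assumption.

    # Sketch only - create_llm_provider is defined in the module being diffed.
    impl = create_llm_provider(
        provider="mock",            # any entry from the valid provider list
        api_key="",                 # assumed acceptable for local/OAuth/mock providers
        base_url="",
        model="test-model",
        reasoning_effort="low",
    )
    # The implementation exposes the async methods LLMProvider now delegates to:
    #     await impl.verify_connection()
    #     await impl.call(messages=[{"role": "user", "content": "Say 'ok'"}])
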
@@ -97,14 +198,21 @@ class LLMProvider:
         self.groq_service_tier = groq_service_tier or os.getenv(ENV_LLM_GROQ_SERVICE_TIER, "auto")
 
         # Validate provider
-        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio", "vertexai", "mock"]
+        valid_providers = [
+            "openai",
+            "groq",
+            "ollama",
+            "gemini",
+            "anthropic",
+            "lmstudio",
+            "vertexai",
+            "openai-codex",
+            "claude-code",
+            "mock",
+        ]
         if self.provider not in valid_providers:
             raise ValueError(f"Invalid LLM provider: {self.provider}. Must be one of: {', '.join(valid_providers)}")
 
-        # Mock provider tracking (for testing)
-        self._mock_calls: list[dict] = []
-        self._mock_response: Any = None
-
         # Set default base URLs
         if not self.base_url:
             if self.provider == "groq":
@@ -114,24 +222,24 @@
             elif self.provider == "lmstudio":
                 self.base_url = "http://localhost:1234/v1"
 
-        # Vertex AI config stored for client creation below
-        self._vertexai_project_id: str | None = None
-        self._vertexai_region: str | None = None
-        self._vertexai_credentials: Any = None
+        # Prepare Vertex AI config (if applicable)
+        vertexai_project_id = None
+        vertexai_region = None
+        vertexai_credentials = None
 
         if self.provider == "vertexai":
             from ..config import get_config
 
             config = get_config()
 
-            self._vertexai_project_id = config.llm_vertexai_project_id
-            if not self._vertexai_project_id:
+            vertexai_project_id = config.llm_vertexai_project_id
+            if not vertexai_project_id:
                 raise ValueError(
                     "HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID is required for Vertex AI provider. "
                     "Set it to your GCP project ID."
                 )
 
-            self._vertexai_region = config.llm_vertexai_region or "us-central1"
+            vertexai_region = config.llm_vertexai_region or "us-central1"
             service_account_key = config.llm_vertexai_service_account_key
 
             # Load explicit service account credentials if provided
@@ -141,75 +249,71 @@
                         "Vertex AI service account auth requires 'google-auth' package. "
                         "Install with: pip install google-auth"
                     )
-                self._vertexai_credentials = service_account.Credentials.from_service_account_file(
+                vertexai_credentials = service_account.Credentials.from_service_account_file(
                     service_account_key,
                     scopes=["https://www.googleapis.com/auth/cloud-platform"],
                 )
                 logger.info(f"Vertex AI: Using service account key: {service_account_key}")
 
             # Strip google/ prefix from model name — native SDK uses bare names
-            # e.g. "google/gemini-2.0-flash-lite-001" -> "gemini-2.0-flash-lite-001"
             if self.model.startswith("google/"):
                 self.model = self.model[len("google/") :]
 
             logger.info(
-                f"Vertex AI: project={self._vertexai_project_id}, region={self._vertexai_region}, "
+                f"Vertex AI: project={vertexai_project_id}, region={vertexai_region}, "
                 f"model={self.model}, auth={'service_account' if service_account_key else 'ADC'}"
             )
 
-        # Validate API key (not needed for ollama, lmstudio, vertexai, or mock)
-        if self.provider not in ("ollama", "lmstudio", "vertexai", "mock") and not self.api_key:
-            raise ValueError(f"API key not found for {self.provider}")
+        # Create provider implementation using factory
+        self._provider_impl = create_llm_provider(
+            provider=self.provider,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            model=self.model,
+            reasoning_effort=self.reasoning_effort,
+            groq_service_tier=self.groq_service_tier,
+            vertexai_project_id=vertexai_project_id,
+            vertexai_region=vertexai_region,
+            vertexai_credentials=vertexai_credentials,
+        )
+
+        # Backward compatibility: Keep mock provider properties
+        self._mock_calls: list[dict] = []
+        self._mock_response: Any = None
 
-        # Get timeout config (set HINDSIGHT_API_LLM_TIMEOUT for local LLMs that need longer timeouts)
-        self.timeout = float(os.getenv(ENV_LLM_TIMEOUT, str(DEFAULT_LLM_TIMEOUT)))
+    @property
+    def _client(self) -> Any:
+        """
+        Get the OpenAI client for OpenAI-compatible providers.
 
-        # Create client based on provider
-        self._client = None
-        self._gemini_client = None
-        self._anthropic_client = None
+        This property provides backward compatibility for code that directly accesses
+        the _client attribute (e.g., benchmarks, memory_engine).
 
-        if self.provider == "mock":
-            # Mock provider - no client needed
-            pass
-        elif self.provider == "gemini":
-            self._gemini_client = genai.Client(api_key=self.api_key)
-        elif self.provider == "anthropic":
-            from anthropic import AsyncAnthropic
-
-            # Only pass base_url if it's set (Anthropic uses default URL otherwise)
-            anthropic_kwargs = {"api_key": self.api_key}
-            if self.base_url:
-                anthropic_kwargs["base_url"] = self.base_url
-            if self.timeout:
-                anthropic_kwargs["timeout"] = self.timeout
-            self._anthropic_client = AsyncAnthropic(**anthropic_kwargs)
-        elif self.provider == "vertexai":
-            # Native genai SDK with Vertex AI — handles ADC automatically,
-            # or uses explicit service account credentials if provided
-            client_kwargs = {
-                "vertexai": True,
-                "project": self._vertexai_project_id,
-                "location": self._vertexai_region,
-            }
-            if self._vertexai_credentials is not None:
-                client_kwargs["credentials"] = self._vertexai_credentials
-            self._gemini_client = genai.Client(**client_kwargs)
-        elif self.provider in ("ollama", "lmstudio"):
-            # Use dummy key if not provided for local
-            api_key = self.api_key or "local"
-            client_kwargs = {"api_key": api_key, "base_url": self.base_url, "max_retries": 0}
-            if self.timeout:
-                client_kwargs["timeout"] = self.timeout
-            self._client = AsyncOpenAI(**client_kwargs)
-        else:
-            # Only pass base_url if it's set (OpenAI uses default URL otherwise)
-            client_kwargs = {"api_key": self.api_key, "max_retries": 0}
-            if self.base_url:
-                client_kwargs["base_url"] = self.base_url
-            if self.timeout:
-                client_kwargs["timeout"] = self.timeout
-            self._client = AsyncOpenAI(**client_kwargs)
+        Returns:
+            AsyncOpenAI client instance for OpenAI-compatible providers, or None for other providers.
+        """
+        from .providers.openai_compatible_llm import OpenAICompatibleLLM
+
+        if isinstance(self._provider_impl, OpenAICompatibleLLM):
+            return self._provider_impl._client
+        return None
+
+    @property
+    def _gemini_client(self) -> Any:
+        """
+        Get the Gemini client for Gemini/VertexAI providers.
+
+        This property provides backward compatibility for code that directly accesses
+        the _gemini_client attribute.
+
+        Returns:
+            genai.Client instance for Gemini/VertexAI providers, or None for other providers.
+        """
+        from .providers.gemini_llm import GeminiLLM
+
+        if isinstance(self._provider_impl, GeminiLLM):
+            return self._provider_impl._client
+        return None
 
     async def verify_connection(self) -> None:
         """
@@ -218,21 +322,7 @@ class LLMProvider:
         Raises:
             RuntimeError: If the connection test fails.
         """
-        try:
-            logger.info(
-                f"Verifying LLM: provider={self.provider}, model={self.model}, base_url={self.base_url or 'default'}..."
-            )
-            await self.call(
-                messages=[{"role": "user", "content": "Say 'ok'"}],
-                max_completion_tokens=100,
-                max_retries=2,
-                initial_backoff=0.5,
-                max_backoff=2.0,
-            )
-            # If we get here without exception, the connection is working
-            logger.info(f"LLM verified: {self.provider}/{self.model}")
-        except Exception as e:
-            raise RuntimeError(f"LLM connection verification failed for {self.provider}/{self.model}: {e}") from e
+        await self._provider_impl.verify_connection()
 
     async def call(
         self,
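
The two hunks above replace eagerly constructed per-provider clients with a single _provider_impl plus read-only compatibility properties, and verify_connection() keeps its documented contract (RuntimeError on failure) while becoming a one-line delegation. A small sketch (not part of the package diff) of what legacy call sites now observe; the helper name is hypothetical and the provider argument is assumed to be an already constructed LLMProvider.

    def describe_clients(provider: "LLMProvider") -> str:
        # _client is now a property: an AsyncOpenAI client for openai/groq/ollama/lmstudio,
        # None otherwise; _gemini_client behaves the same way for gemini/vertexai.
        if provider._client is not None:
            return "OpenAI-compatible client"
        if provider._gemini_client is not None:
            return "Gemini/Vertex AI client"
        return "provider-specific implementation only (e.g. anthropic, mock, openai-codex, claude-code)"
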
@@ -272,340 +362,32 @@ class LLMProvider:
272
362
  OutputTooLongError: If output exceeds token limits.
273
363
  Exception: Re-raises API errors after retries exhausted.
274
364
  """
275
- semaphore_start = time.time()
276
365
  async with _global_llm_semaphore:
277
- semaphore_wait_time = time.time() - semaphore_start
278
- start_time = time.time()
366
+ # Delegate to provider implementation
367
+ result = await self._provider_impl.call(
368
+ messages=messages,
369
+ response_format=response_format,
370
+ max_completion_tokens=max_completion_tokens,
371
+ temperature=temperature,
372
+ scope=scope,
373
+ max_retries=max_retries,
374
+ initial_backoff=initial_backoff,
375
+ max_backoff=max_backoff,
376
+ skip_validation=skip_validation,
377
+ strict_schema=strict_schema,
378
+ return_usage=return_usage,
379
+ )
279
380
 
280
- # Handle Mock provider (for testing)
381
+ # Backward compatibility: Update mock call tracking for mock provider
382
+ # This allows existing tests using LLMProvider._mock_calls to continue working
281
383
  if self.provider == "mock":
282
- return await self._call_mock(
283
- messages,
284
- response_format,
285
- scope,
286
- return_usage,
287
- )
288
-
289
- # Handle Gemini and Vertex AI providers (both use native genai SDK)
290
- if self.provider in ("gemini", "vertexai"):
291
- return await self._call_gemini(
292
- messages,
293
- response_format,
294
- max_retries,
295
- initial_backoff,
296
- max_backoff,
297
- skip_validation,
298
- start_time,
299
- scope,
300
- return_usage,
301
- semaphore_wait_time,
302
- )
303
-
304
- # Handle Anthropic provider separately
305
- if self.provider == "anthropic":
306
- return await self._call_anthropic(
307
- messages,
308
- response_format,
309
- max_completion_tokens,
310
- max_retries,
311
- initial_backoff,
312
- max_backoff,
313
- skip_validation,
314
- start_time,
315
- scope,
316
- return_usage,
317
- semaphore_wait_time,
318
- )
384
+ from .providers.mock_llm import MockLLM
319
385
 
320
- # Handle Ollama with native API for structured output (better schema enforcement)
321
- if self.provider == "ollama" and response_format is not None:
322
- return await self._call_ollama_native(
323
- messages,
324
- response_format,
325
- max_completion_tokens,
326
- temperature,
327
- max_retries,
328
- initial_backoff,
329
- max_backoff,
330
- skip_validation,
331
- start_time,
332
- scope,
333
- return_usage,
334
- semaphore_wait_time,
335
- )
386
+ if isinstance(self._provider_impl, MockLLM):
387
+ # Sync the mock calls from provider implementation to wrapper
388
+ self._mock_calls = self._provider_impl.get_mock_calls()
336
389
 
337
- call_params = {
338
- "model": self.model,
339
- "messages": messages,
340
- }
341
-
342
- # Check if model supports reasoning parameter (o1, o3, gpt-5 families)
343
- model_lower = self.model.lower()
344
- is_reasoning_model = any(x in model_lower for x in ["gpt-5", "o1", "o3", "deepseek"])
345
-
346
- # For GPT-4 and GPT-4.1 models, cap max_completion_tokens to 32000
347
- # For GPT-4o models, cap to 16384
348
- is_gpt4_model = any(x in model_lower for x in ["gpt-4.1", "gpt-4-"])
349
- is_gpt4o_model = "gpt-4o" in model_lower
350
- if max_completion_tokens is not None:
351
- if is_gpt4o_model and max_completion_tokens > 16384:
352
- max_completion_tokens = 16384
353
- elif is_gpt4_model and max_completion_tokens > 32000:
354
- max_completion_tokens = 32000
355
- # For reasoning models, max_completion_tokens includes reasoning + output tokens
356
- # Enforce minimum of 16000 to ensure enough space for both
357
- if is_reasoning_model and max_completion_tokens < 16000:
358
- max_completion_tokens = 16000
359
- call_params["max_completion_tokens"] = max_completion_tokens
360
-
361
- # GPT-5/o1/o3 family doesn't support custom temperature (only default 1)
362
- if temperature is not None and not is_reasoning_model:
363
- call_params["temperature"] = temperature
364
-
365
- # Set reasoning_effort for reasoning models (OpenAI gpt-5, o1, o3)
366
- if is_reasoning_model:
367
- call_params["reasoning_effort"] = self.reasoning_effort
368
-
369
- # Provider-specific parameters
370
- if self.provider == "groq":
371
- call_params["seed"] = DEFAULT_LLM_SEED
372
- extra_body: dict[str, Any] = {}
373
- # Add service_tier if configured (requires paid plan for flex/auto)
374
- if self.groq_service_tier:
375
- extra_body["service_tier"] = self.groq_service_tier
376
- # Add reasoning parameters for reasoning models
377
- if is_reasoning_model:
378
- extra_body["include_reasoning"] = False
379
- if extra_body:
380
- call_params["extra_body"] = extra_body
381
-
382
- last_exception = None
383
-
384
- # Prepare response format ONCE before the retry loop
385
- # (to avoid appending schema to messages on every retry)
386
- if response_format is not None:
387
- schema = None
388
- if hasattr(response_format, "model_json_schema"):
389
- schema = response_format.model_json_schema()
390
-
391
- if strict_schema and schema is not None:
392
- # Use OpenAI's strict JSON schema enforcement
393
- # This guarantees all required fields are returned
394
- call_params["response_format"] = {
395
- "type": "json_schema",
396
- "json_schema": {
397
- "name": "response",
398
- "strict": True,
399
- "schema": schema,
400
- },
401
- }
402
- else:
403
- # Soft enforcement: add schema to prompt and use json_object mode
404
- if schema is not None:
405
- schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
406
-
407
- if call_params["messages"] and call_params["messages"][0].get("role") == "system":
408
- first_msg = call_params["messages"][0]
409
- if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
410
- first_msg["content"] += schema_msg
411
- elif call_params["messages"]:
412
- first_msg = call_params["messages"][0]
413
- if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
414
- first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]
415
- if self.provider not in ("lmstudio", "ollama"):
416
- # LM Studio and Ollama don't support json_object response format reliably
417
- # We rely on the schema in the system message instead
418
- call_params["response_format"] = {"type": "json_object"}
419
-
420
- for attempt in range(max_retries + 1):
421
- try:
422
- if response_format is not None:
423
- response = await self._client.chat.completions.create(**call_params)
424
-
425
- content = response.choices[0].message.content
426
-
427
- # Strip reasoning model thinking tags
428
- # Supports: <think>, <thinking>, <reasoning>, |startthink|/|endthink|
429
- # for reasoning models that embed thinking in their output (e.g., Qwen3, DeepSeek)
430
- if content:
431
- original_len = len(content)
432
- content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
433
- content = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL)
434
- content = re.sub(r"<reasoning>.*?</reasoning>", "", content, flags=re.DOTALL)
435
- content = re.sub(r"\|startthink\|.*?\|endthink\|", "", content, flags=re.DOTALL)
436
- content = content.strip()
437
- if len(content) < original_len:
438
- logger.debug(f"Stripped {original_len - len(content)} chars of reasoning tokens")
439
-
440
- # For local models, they may wrap JSON in markdown code blocks
441
- if self.provider in ("lmstudio", "ollama"):
442
- clean_content = content
443
- if "```json" in content:
444
- clean_content = content.split("```json")[1].split("```")[0].strip()
445
- elif "```" in content:
446
- clean_content = content.split("```")[1].split("```")[0].strip()
447
- try:
448
- json_data = json.loads(clean_content)
449
- except json.JSONDecodeError:
450
- # Fallback to parsing raw content
451
- json_data = json.loads(content)
452
- else:
453
- # Log raw LLM response for debugging JSON parse issues
454
- try:
455
- json_data = json.loads(content)
456
- except json.JSONDecodeError as json_err:
457
- # Truncate content for logging (first 500 and last 200 chars)
458
- content_preview = content[:500] if content else "<empty>"
459
- if content and len(content) > 700:
460
- content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
461
- logger.warning(
462
- f"JSON parse error from LLM response (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
463
- f" Model: {self.provider}/{self.model}\n"
464
- f" Content length: {len(content) if content else 0} chars\n"
465
- f" Content preview: {content_preview!r}\n"
466
- f" Finish reason: {response.choices[0].finish_reason if response.choices else 'unknown'}"
467
- )
468
- # Retry on JSON parse errors - LLM may return valid JSON on next attempt
469
- if attempt < max_retries:
470
- backoff = min(initial_backoff * (2**attempt), max_backoff)
471
- await asyncio.sleep(backoff)
472
- last_exception = json_err
473
- continue
474
- else:
475
- logger.error(f"JSON parse error after {max_retries + 1} attempts, giving up")
476
- raise
477
-
478
- if skip_validation:
479
- result = json_data
480
- else:
481
- result = response_format.model_validate(json_data)
482
- else:
483
- response = await self._client.chat.completions.create(**call_params)
484
- result = response.choices[0].message.content
485
-
486
- # Record token usage metrics
487
- duration = time.time() - start_time
488
- usage = response.usage
489
- input_tokens = usage.prompt_tokens or 0 if usage else 0
490
- output_tokens = usage.completion_tokens or 0 if usage else 0
491
- total_tokens = usage.total_tokens or 0 if usage else 0
492
-
493
- # Record LLM metrics
494
- metrics = get_metrics_collector()
495
- metrics.record_llm_call(
496
- provider=self.provider,
497
- model=self.model,
498
- scope=scope,
499
- duration=duration,
500
- input_tokens=input_tokens,
501
- output_tokens=output_tokens,
502
- success=True,
503
- )
504
-
505
- # Log slow calls
506
- if duration > 10.0 and usage:
507
- ratio = max(1, output_tokens) / max(1, input_tokens)
508
- cached_tokens = 0
509
- if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
510
- cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
511
- cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
512
- wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
513
- logger.info(
514
- f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
515
- f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
516
- f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s{wait_info}, ratio out/in={ratio:.2f}"
517
- )
518
-
519
- if return_usage:
520
- token_usage = TokenUsage(
521
- input_tokens=input_tokens,
522
- output_tokens=output_tokens,
523
- total_tokens=total_tokens,
524
- )
525
- return result, token_usage
526
- return result
527
-
528
- except LengthFinishReasonError as e:
529
- logger.warning(f"LLM output exceeded token limits: {str(e)}")
530
- raise OutputTooLongError(
531
- "LLM output exceeded token limits. Input may need to be split into smaller chunks."
532
- ) from e
533
-
534
- except APIConnectionError as e:
535
- last_exception = e
536
- status_code = getattr(e, "status_code", None) or getattr(
537
- getattr(e, "response", None), "status_code", None
538
- )
539
- logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
540
- if attempt < max_retries:
541
- backoff = min(initial_backoff * (2**attempt), max_backoff)
542
- await asyncio.sleep(backoff)
543
- continue
544
- else:
545
- logger.error(f"Connection error after {max_retries + 1} attempts: {str(e)}")
546
- raise
547
-
548
- except APIStatusError as e:
549
- # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
550
- if e.status_code in (401, 403):
551
- logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
552
- raise
553
-
554
- # Handle tool_use_failed error - model outputted in tool call format
555
- # Convert to expected JSON format and continue
556
- if e.status_code == 400 and response_format is not None:
557
- try:
558
- error_body = e.body if hasattr(e, "body") else {}
559
- if isinstance(error_body, dict):
560
- error_info: dict[str, Any] = error_body.get("error") or {}
561
- if error_info.get("code") == "tool_use_failed":
562
- failed_gen = error_info.get("failed_generation", "")
563
- if failed_gen:
564
- # Parse the tool call format and convert to actions format
565
- tool_call = json.loads(failed_gen)
566
- tool_name = tool_call.get("name", "")
567
- tool_args = tool_call.get("arguments", {})
568
- # Convert to actions format: {"actions": [{"tool": "name", ...args}]}
569
- converted = {"actions": [{"tool": tool_name, **tool_args}]}
570
- if skip_validation:
571
- result = converted
572
- else:
573
- result = response_format.model_validate(converted)
574
-
575
- # Record metrics for this successful recovery
576
- duration = time.time() - start_time
577
- metrics = get_metrics_collector()
578
- metrics.record_llm_call(
579
- provider=self.provider,
580
- model=self.model,
581
- scope=scope,
582
- duration=duration,
583
- input_tokens=0,
584
- output_tokens=0,
585
- success=True,
586
- )
587
- if return_usage:
588
- return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
589
- return result
590
- except (json.JSONDecodeError, KeyError, TypeError):
591
- pass # Failed to parse tool_use_failed, continue with normal retry
592
-
593
- last_exception = e
594
- if attempt < max_retries:
595
- backoff = min(initial_backoff * (2**attempt), max_backoff)
596
- jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
597
- sleep_time = backoff + jitter
598
- await asyncio.sleep(sleep_time)
599
- else:
600
- logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
601
- raise
602
-
603
- except Exception:
604
- raise
605
-
606
- if last_exception:
607
- raise last_exception
608
- raise RuntimeError("LLM call failed after all retries with no exception captured")
390
+ return result
609
391
 
610
392
  async def call_with_tools(
611
393
  self,
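
The public call() contract shown above (messages, optional Pydantic response_format, token limits, retries, return_usage) is unchanged by the delegation; only the transport moved into the provider implementations. A hedged usage sketch, not part of the package diff; the Answer model, its field, and the already-constructed provider instance are illustrative assumptions.

    from pydantic import BaseModel

    class Answer(BaseModel):
        text: str

    async def ask(provider: "LLMProvider") -> Answer:
        # With response_format set, call() returns a validated Answer instance
        # (or, with return_usage=True, a (result, TokenUsage) tuple).
        return await provider.call(
            messages=[{"role": "user", "content": "Say 'ok'"}],
            response_format=Answer,
            max_completion_tokens=100,
        )
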
@@ -636,940 +418,122 @@ class LLMProvider:
636
418
  Returns:
637
419
  LLMToolCallResult with content and/or tool_calls.
638
420
  """
639
- from .response_models import LLMToolCall, LLMToolCallResult
640
-
641
421
  async with _global_llm_semaphore:
642
- start_time = time.time()
422
+ # Delegate to provider implementation
423
+ result = await self._provider_impl.call_with_tools(
424
+ messages=messages,
425
+ tools=tools,
426
+ max_completion_tokens=max_completion_tokens,
427
+ temperature=temperature,
428
+ scope=scope,
429
+ max_retries=max_retries,
430
+ initial_backoff=initial_backoff,
431
+ max_backoff=max_backoff,
432
+ tool_choice=tool_choice,
433
+ )
643
434
 
644
- # Handle Mock provider
435
+ # Backward compatibility: Update mock call tracking for mock provider
436
+ # This allows existing tests using LLMProvider._mock_calls to continue working
645
437
  if self.provider == "mock":
646
- return await self._call_with_tools_mock(messages, tools, scope)
647
-
648
- # Handle Anthropic separately (uses different tool format)
649
- if self.provider == "anthropic":
650
- return await self._call_with_tools_anthropic(
651
- messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
652
- )
653
-
654
- # Handle Gemini and Vertex AI (convert to Gemini tool format)
655
- if self.provider in ("gemini", "vertexai"):
656
- return await self._call_with_tools_gemini(
657
- messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
658
- )
438
+ from .providers.mock_llm import MockLLM
659
439
 
660
- # OpenAI-compatible providers (OpenAI, Groq, Ollama, LMStudio)
661
- call_params: dict[str, Any] = {
662
- "model": self.model,
663
- "messages": messages,
664
- "tools": tools,
665
- "tool_choice": tool_choice,
666
- }
440
+ if isinstance(self._provider_impl, MockLLM):
441
+ # Sync the mock calls from provider implementation to wrapper
442
+ self._mock_calls = self._provider_impl.get_mock_calls()
667
443
 
668
- if max_completion_tokens is not None:
669
- call_params["max_completion_tokens"] = max_completion_tokens
670
- if temperature is not None:
671
- call_params["temperature"] = temperature
444
+ return result
672
445
 
673
- # Provider-specific parameters
674
- if self.provider == "groq":
675
- call_params["seed"] = DEFAULT_LLM_SEED
676
-
677
- last_exception = None
678
-
679
- for attempt in range(max_retries + 1):
680
- try:
681
- response = await self._client.chat.completions.create(**call_params)
682
-
683
- message = response.choices[0].message
684
- finish_reason = response.choices[0].finish_reason
685
-
686
- # Extract tool calls if present
687
- tool_calls: list[LLMToolCall] = []
688
- if message.tool_calls:
689
- for tc in message.tool_calls:
690
- try:
691
- args = json.loads(tc.function.arguments) if tc.function.arguments else {}
692
- except json.JSONDecodeError:
693
- args = {"_raw": tc.function.arguments}
694
- tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
695
-
696
- content = message.content
697
-
698
- # Record metrics
699
- duration = time.time() - start_time
700
- usage = response.usage
701
- input_tokens = usage.prompt_tokens or 0 if usage else 0
702
- output_tokens = usage.completion_tokens or 0 if usage else 0
703
-
704
- metrics = get_metrics_collector()
705
- metrics.record_llm_call(
706
- provider=self.provider,
707
- model=self.model,
708
- scope=scope,
709
- duration=duration,
710
- input_tokens=input_tokens,
711
- output_tokens=output_tokens,
712
- success=True,
713
- )
446
+ def set_mock_response(self, response: Any) -> None:
447
+ """Set the response to return from mock calls."""
448
+ # Backward compatibility: Store in both wrapper and provider implementation
449
+ self._mock_response = response
450
+ if self.provider == "mock":
451
+ from .providers.mock_llm import MockLLM
714
452
 
715
- return LLMToolCallResult(
716
- content=content,
717
- tool_calls=tool_calls,
718
- finish_reason=finish_reason,
719
- input_tokens=input_tokens,
720
- output_tokens=output_tokens,
721
- )
453
+ if isinstance(self._provider_impl, MockLLM):
454
+ self._provider_impl.set_mock_response(response)
722
455
 
723
- except APIConnectionError as e:
724
- last_exception = e
725
- if attempt < max_retries:
726
- await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
727
- continue
728
- raise
729
-
730
- except APIStatusError as e:
731
- if e.status_code in (401, 403):
732
- raise
733
- last_exception = e
734
- if attempt < max_retries:
735
- await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
736
- continue
737
- raise
738
-
739
- except Exception:
740
- raise
741
-
742
- if last_exception:
743
- raise last_exception
744
- raise RuntimeError("Tool call failed after all retries")
745
-
746
- async def _call_with_tools_mock(
747
- self,
748
- messages: list[dict[str, Any]],
749
- tools: list[dict[str, Any]],
750
- scope: str,
751
- ) -> "LLMToolCallResult":
752
- """Handle mock tool calls for testing."""
753
- from .response_models import LLMToolCallResult
754
-
755
- call_record = {
756
- "provider": self.provider,
757
- "model": self.model,
758
- "messages": messages,
759
- "tools": [t.get("function", {}).get("name") for t in tools],
760
- "scope": scope,
761
- }
762
- self._mock_calls.append(call_record)
763
-
764
- if self._mock_response is not None:
765
- if isinstance(self._mock_response, LLMToolCallResult):
766
- return self._mock_response
767
- # Allow setting just tool calls as a list
768
- if isinstance(self._mock_response, list):
769
- from .response_models import LLMToolCall
770
-
771
- return LLMToolCallResult(
772
- tool_calls=[
773
- LLMToolCall(id=f"mock_{i}", name=tc["name"], arguments=tc.get("arguments", {}))
774
- for i, tc in enumerate(self._mock_response)
775
- ],
776
- finish_reason="tool_calls",
777
- )
456
+ def get_mock_calls(self) -> list[dict]:
457
+ """Get the list of recorded mock calls."""
458
+ # Backward compatibility: Read from provider implementation if mock provider
459
+ if self.provider == "mock":
460
+ from .providers.mock_llm import MockLLM
778
461
 
779
- return LLMToolCallResult(content="mock response", finish_reason="stop")
462
+ if isinstance(self._provider_impl, MockLLM):
463
+ return self._provider_impl.get_mock_calls()
464
+ return self._mock_calls
780
465
 
781
- async def _call_with_tools_anthropic(
782
- self,
783
- messages: list[dict[str, Any]],
784
- tools: list[dict[str, Any]],
785
- max_completion_tokens: int | None,
786
- max_retries: int,
787
- initial_backoff: float,
788
- max_backoff: float,
789
- start_time: float,
790
- scope: str,
791
- ) -> "LLMToolCallResult":
792
- """Handle Anthropic tool calling."""
793
- from anthropic import APIConnectionError, APIStatusError
794
-
795
- from .response_models import LLMToolCall, LLMToolCallResult
796
-
797
- # Convert OpenAI tool format to Anthropic format
798
- anthropic_tools = []
799
- for tool in tools:
800
- func = tool.get("function", {})
801
- anthropic_tools.append(
802
- {
803
- "name": func.get("name", ""),
804
- "description": func.get("description", ""),
805
- "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
806
- }
807
- )
466
+ def clear_mock_calls(self) -> None:
467
+ """Clear the recorded mock calls."""
468
+ # Backward compatibility: Clear in both wrapper and provider implementation
469
+ self._mock_calls = []
470
+ if self.provider == "mock":
471
+ from .providers.mock_llm import MockLLM
808
472
 
809
- # Convert messages - handle tool results
810
- system_prompt = None
811
- anthropic_messages = []
812
- for msg in messages:
813
- role = msg.get("role", "user")
814
- content = msg.get("content", "")
815
-
816
- if role == "system":
817
- system_prompt = (system_prompt + "\n\n" + content) if system_prompt else content
818
- elif role == "tool":
819
- # Anthropic uses tool_result blocks
820
- anthropic_messages.append(
821
- {
822
- "role": "user",
823
- "content": [
824
- {"type": "tool_result", "tool_use_id": msg.get("tool_call_id", ""), "content": content}
825
- ],
826
- }
827
- )
828
- elif role == "assistant" and msg.get("tool_calls"):
829
- # Convert assistant tool calls
830
- tool_use_blocks = []
831
- for tc in msg["tool_calls"]:
832
- tool_use_blocks.append(
833
- {
834
- "type": "tool_use",
835
- "id": tc.get("id", ""),
836
- "name": tc.get("function", {}).get("name", ""),
837
- "input": json.loads(tc.get("function", {}).get("arguments", "{}")),
838
- }
839
- )
840
- anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
841
- else:
842
- anthropic_messages.append({"role": role, "content": content})
843
-
844
- call_params: dict[str, Any] = {
845
- "model": self.model,
846
- "messages": anthropic_messages,
847
- "tools": anthropic_tools,
848
- "max_tokens": max_completion_tokens or 4096,
849
- }
850
- if system_prompt:
851
- call_params["system"] = system_prompt
852
-
853
- last_exception = None
854
- for attempt in range(max_retries + 1):
855
- try:
856
- response = await self._anthropic_client.messages.create(**call_params)
857
-
858
- # Extract content and tool calls
859
- content_parts = []
860
- tool_calls: list[LLMToolCall] = []
861
-
862
- for block in response.content:
863
- if block.type == "text":
864
- content_parts.append(block.text)
865
- elif block.type == "tool_use":
866
- tool_calls.append(LLMToolCall(id=block.id, name=block.name, arguments=block.input or {}))
867
-
868
- content = "".join(content_parts) if content_parts else None
869
- finish_reason = "tool_calls" if tool_calls else "stop"
870
-
871
- # Extract token usage
872
- input_tokens = response.usage.input_tokens or 0
873
- output_tokens = response.usage.output_tokens or 0
874
-
875
- # Record metrics
876
- metrics = get_metrics_collector()
877
- metrics.record_llm_call(
878
- provider=self.provider,
879
- model=self.model,
880
- scope=scope,
881
- duration=time.time() - start_time,
882
- input_tokens=input_tokens,
883
- output_tokens=output_tokens,
884
- success=True,
885
- )
473
+ if isinstance(self._provider_impl, MockLLM):
474
+ self._provider_impl.clear_mock_calls()
886
475
 
887
- return LLMToolCallResult(
888
- content=content,
889
- tool_calls=tool_calls,
890
- finish_reason=finish_reason,
891
- input_tokens=input_tokens,
892
- output_tokens=output_tokens,
893
- )
476
+ def _load_codex_auth(self) -> tuple[str, str]:
477
+ """
478
+ Load OAuth credentials from ~/.codex/auth.json.
894
479
 
895
- except (APIConnectionError, APIStatusError) as e:
896
- if isinstance(e, APIStatusError) and e.status_code in (401, 403):
897
- raise
898
- last_exception = e
899
- if attempt < max_retries:
900
- await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
901
- continue
902
- raise
480
+ Returns:
481
+ Tuple of (access_token, account_id).
903
482
 
904
- if last_exception:
905
- raise last_exception
906
- raise RuntimeError("Anthropic tool call failed")
483
+ Raises:
484
+ FileNotFoundError: If auth file doesn't exist.
485
+ ValueError: If auth file is invalid.
486
+ """
487
+ auth_file = Path.home() / ".codex" / "auth.json"
907
488
 
908
- async def _call_with_tools_gemini(
909
- self,
910
- messages: list[dict[str, Any]],
911
- tools: list[dict[str, Any]],
912
- max_retries: int,
913
- initial_backoff: float,
914
- max_backoff: float,
915
- start_time: float,
916
- scope: str,
917
- ) -> "LLMToolCallResult":
918
- """Handle Gemini tool calling."""
919
- from .response_models import LLMToolCall, LLMToolCallResult
920
-
921
- # Convert tools to Gemini format
922
- gemini_tools = []
923
- for tool in tools:
924
- func = tool.get("function", {})
925
- gemini_tools.append(
926
- genai_types.Tool(
927
- function_declarations=[
928
- genai_types.FunctionDeclaration(
929
- name=func.get("name", ""),
930
- description=func.get("description", ""),
931
- parameters=func.get("parameters"),
932
- )
933
- ]
934
- )
489
+ if not auth_file.exists():
490
+ raise FileNotFoundError(
491
+ f"Codex auth file not found: {auth_file}\nRun 'codex auth login' to authenticate with ChatGPT Plus/Pro."
935
492
  )
936
493
 
937
- # Convert messages
938
- system_instruction = None
939
- gemini_contents = []
940
- for msg in messages:
941
- role = msg.get("role", "user")
942
- content = msg.get("content", "")
943
-
944
- if role == "system":
945
- system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
946
- elif role == "tool":
947
- # Gemini uses function_response
948
- gemini_contents.append(
949
- genai_types.Content(
950
- role="user",
951
- parts=[
952
- genai_types.Part(
953
- function_response=genai_types.FunctionResponse(
954
- name=msg.get("name", ""),
955
- response={"result": content},
956
- )
957
- )
958
- ],
959
- )
960
- )
961
- elif role == "assistant":
962
- gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
963
- else:
964
- gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
965
-
966
- config = genai_types.GenerateContentConfig(
967
- system_instruction=system_instruction,
968
- tools=gemini_tools,
969
- )
970
-
971
- last_exception = None
972
- for attempt in range(max_retries + 1):
973
- try:
974
- response = await self._gemini_client.aio.models.generate_content(
975
- model=self.model,
976
- contents=gemini_contents,
977
- config=config,
978
- )
979
-
980
- # Extract content and tool calls
981
- content = None
982
- tool_calls: list[LLMToolCall] = []
983
-
984
- if response.candidates and response.candidates[0].content:
985
- parts = response.candidates[0].content.parts
986
- if parts:
987
- for part in parts:
988
- if hasattr(part, "text") and part.text:
989
- content = part.text
990
- if hasattr(part, "function_call") and part.function_call:
991
- fc = part.function_call
992
- tool_calls.append(
993
- LLMToolCall(
994
- id=f"gemini_{len(tool_calls)}",
995
- name=fc.name,
996
- arguments=dict(fc.args) if fc.args else {},
997
- )
998
- )
999
-
1000
- finish_reason = "tool_calls" if tool_calls else "stop"
1001
-
1002
- # Record metrics
1003
- metrics = get_metrics_collector()
1004
- input_tokens = response.usage_metadata.prompt_token_count if response.usage_metadata else 0
1005
- output_tokens = response.usage_metadata.candidates_token_count if response.usage_metadata else 0
1006
- metrics.record_llm_call(
1007
- provider=self.provider,
1008
- model=self.model,
1009
- scope=scope,
1010
- duration=time.time() - start_time,
1011
- input_tokens=input_tokens,
1012
- output_tokens=output_tokens,
1013
- success=True,
1014
- )
1015
-
1016
- return LLMToolCallResult(
1017
- content=content,
1018
- tool_calls=tool_calls,
1019
- finish_reason=finish_reason,
1020
- input_tokens=input_tokens,
1021
- output_tokens=output_tokens,
1022
- )
1023
-
1024
- except genai_errors.APIError as e:
1025
- if e.code in (401, 403):
1026
- raise
1027
- last_exception = e
1028
- if attempt < max_retries:
1029
- await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
1030
- continue
1031
- raise
1032
-
1033
- if last_exception:
1034
- raise last_exception
1035
- raise RuntimeError("Gemini tool call failed")
1036
-
1037
- async def _call_anthropic(
1038
- self,
1039
- messages: list[dict[str, str]],
1040
- response_format: Any | None,
1041
- max_completion_tokens: int | None,
1042
- max_retries: int,
1043
- initial_backoff: float,
1044
- max_backoff: float,
1045
- skip_validation: bool,
1046
- start_time: float,
1047
- scope: str = "memory",
1048
- return_usage: bool = False,
1049
- semaphore_wait_time: float = 0.0,
1050
- ) -> Any:
1051
- """Handle Anthropic-specific API calls."""
1052
- from anthropic import APIConnectionError, APIStatusError, RateLimitError
1053
-
1054
- # Convert OpenAI-style messages to Anthropic format
1055
- system_prompt = None
1056
- anthropic_messages = []
1057
-
1058
- for msg in messages:
1059
- role = msg.get("role", "user")
1060
- content = msg.get("content", "")
1061
-
1062
- if role == "system":
1063
- if system_prompt:
1064
- system_prompt += "\n\n" + content
1065
- else:
1066
- system_prompt = content
1067
- else:
1068
- anthropic_messages.append({"role": role, "content": content})
1069
-
1070
- # Add JSON schema instruction if response_format is provided
1071
- if response_format is not None and hasattr(response_format, "model_json_schema"):
1072
- schema = response_format.model_json_schema()
1073
- schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
1074
- if system_prompt:
1075
- system_prompt += schema_msg
1076
- else:
1077
- system_prompt = schema_msg
1078
-
1079
- # Prepare parameters
1080
- call_params = {
1081
- "model": self.model,
1082
- "messages": anthropic_messages,
1083
- "max_tokens": max_completion_tokens if max_completion_tokens is not None else 4096,
1084
- }
1085
-
1086
- if system_prompt:
1087
- call_params["system"] = system_prompt
1088
-
1089
- last_exception = None
1090
-
1091
- for attempt in range(max_retries + 1):
1092
- try:
1093
- response = await self._anthropic_client.messages.create(**call_params)
1094
-
1095
- # Anthropic response content is a list of blocks
1096
- content = ""
1097
- for block in response.content:
1098
- if block.type == "text":
1099
- content += block.text
1100
-
1101
- if response_format is not None:
1102
- # Models may wrap JSON in markdown code blocks
1103
- clean_content = content
1104
- if "```json" in content:
1105
- clean_content = content.split("```json")[1].split("```")[0].strip()
1106
- elif "```" in content:
1107
- clean_content = content.split("```")[1].split("```")[0].strip()
1108
-
1109
- try:
1110
- json_data = json.loads(clean_content)
1111
- except json.JSONDecodeError:
1112
- # Fallback to parsing raw content if markdown stripping failed
1113
- json_data = json.loads(content)
1114
-
1115
- if skip_validation:
1116
- result = json_data
1117
- else:
1118
- result = response_format.model_validate(json_data)
1119
- else:
1120
- result = content
1121
-
1122
- # Record metrics and log slow calls
1123
- duration = time.time() - start_time
1124
- input_tokens = response.usage.input_tokens or 0 if response.usage else 0
1125
- output_tokens = response.usage.output_tokens or 0 if response.usage else 0
1126
- total_tokens = input_tokens + output_tokens
1127
-
1128
- # Record LLM metrics
1129
- metrics = get_metrics_collector()
1130
- metrics.record_llm_call(
1131
- provider=self.provider,
1132
- model=self.model,
1133
- scope=scope,
1134
- duration=duration,
1135
- input_tokens=input_tokens,
1136
- output_tokens=output_tokens,
1137
- success=True,
1138
- )
1139
-
1140
- # Log slow calls
1141
- if duration > 10.0:
1142
- wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
1143
- logger.info(
1144
- f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
1145
- f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
1146
- f"time={duration:.3f}s{wait_info}"
1147
- )
1148
-
1149
- if return_usage:
1150
- token_usage = TokenUsage(
1151
- input_tokens=input_tokens,
1152
- output_tokens=output_tokens,
1153
- total_tokens=total_tokens,
1154
- )
1155
- return result, token_usage
1156
- return result
1157
-
1158
- except json.JSONDecodeError as e:
1159
- last_exception = e
1160
- if attempt < max_retries:
1161
- logger.warning("Anthropic returned invalid JSON, retrying...")
1162
- backoff = min(initial_backoff * (2**attempt), max_backoff)
1163
- await asyncio.sleep(backoff)
1164
- continue
1165
- else:
1166
- logger.error(f"Anthropic returned invalid JSON after {max_retries + 1} attempts")
1167
- raise
1168
-
1169
- except (APIConnectionError, RateLimitError, APIStatusError) as e:
1170
- # Fast fail on 401/403
1171
- if isinstance(e, APIStatusError) and e.status_code in (401, 403):
1172
- logger.error(f"Anthropic auth error (HTTP {e.status_code}), not retrying: {str(e)}")
1173
- raise
1174
-
1175
- last_exception = e
1176
- if attempt < max_retries:
1177
- # Check if it's a rate limit or server error
1178
- should_retry = isinstance(e, (APIConnectionError, RateLimitError)) or (
1179
- isinstance(e, APIStatusError) and e.status_code >= 500
1180
- )
1181
-
1182
- if should_retry:
1183
- backoff = min(initial_backoff * (2**attempt), max_backoff)
1184
- jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
1185
- await asyncio.sleep(backoff + jitter)
1186
- continue
494
+ with open(auth_file) as f:
495
+ data = json.load(f)
1187
496
 
1188
- logger.error(f"Anthropic API error after {max_retries + 1} attempts: {str(e)}")
1189
- raise
497
+ # Validate auth structure
498
+ auth_mode = data.get("auth_mode")
499
+ if auth_mode != "chatgpt":
500
+ raise ValueError(f"Expected auth_mode='chatgpt', got: {auth_mode}")
1190
501
 
1191
- except Exception as e:
1192
- logger.error(f"Unexpected error during Anthropic call: {type(e).__name__}: {str(e)}")
1193
- raise
502
+ tokens = data.get("tokens", {})
503
+ access_token = tokens.get("access_token")
504
+ account_id = tokens.get("account_id")
1194
505
 
1195
- if last_exception:
1196
- raise last_exception
1197
- raise RuntimeError("Anthropic call failed after all retries")
506
+ if not access_token:
507
+ raise ValueError("No access_token found in Codex auth file. Run 'codex auth login' again.")
1198
508
 
1199
- async def _call_ollama_native(
1200
- self,
1201
- messages: list[dict[str, str]],
1202
- response_format: Any,
1203
- max_completion_tokens: int | None,
1204
- temperature: float | None,
1205
- max_retries: int,
1206
- initial_backoff: float,
1207
- max_backoff: float,
1208
- skip_validation: bool,
1209
- start_time: float,
1210
- scope: str = "memory",
1211
- return_usage: bool = False,
1212
- semaphore_wait_time: float = 0.0,
1213
- ) -> Any:
1214
- """
1215
- Call Ollama using native API with JSON schema enforcement.
509
+ return access_token, account_id
1216
510
 
1217
- Ollama's native API supports passing a full JSON schema in the 'format' parameter,
1218
- which provides better structured output control than the OpenAI-compatible API.
511
+ def _verify_claude_code_available(self) -> None:
1219
512
  """
1220
- # Get the JSON schema from the Pydantic model
1221
- schema = response_format.model_json_schema() if hasattr(response_format, "model_json_schema") else None
1222
-
1223
- # Build the base URL for Ollama's native API
1224
- # Default OpenAI-compatible URL is http://localhost:11434/v1
1225
- # Native API is at http://localhost:11434/api/chat
1226
- base_url = self.base_url or "http://localhost:11434/v1"
1227
- if base_url.endswith("/v1"):
1228
- native_url = base_url[:-3] + "/api/chat"
1229
- else:
1230
- native_url = base_url.rstrip("/") + "/api/chat"
1231
-
1232
- # Build request payload
1233
- payload = {
1234
- "model": self.model,
1235
- "messages": messages,
1236
- "stream": False,
1237
- }
1238
-
1239
- # Add schema as format parameter for structured output
1240
- if schema:
1241
- payload["format"] = schema
1242
-
1243
- # Add optional parameters with optimized defaults for Ollama
1244
- # Benchmarking shows num_ctx=16384 + num_batch=512 is optimal
1245
- options = {
1246
- "num_ctx": 16384, # 16k context window for larger prompts
1247
- "num_batch": 512, # Optimal batch size for prompt processing
1248
- }
1249
- if max_completion_tokens:
1250
- options["num_predict"] = max_completion_tokens
1251
- if temperature is not None:
1252
- options["temperature"] = temperature
1253
- payload["options"] = options
1254
-
1255
- last_exception = None
1256
-
1257
- async with httpx.AsyncClient(timeout=300.0) as client:
1258
- for attempt in range(max_retries + 1):
1259
- try:
1260
- response = await client.post(native_url, json=payload)
1261
- response.raise_for_status()
1262
-
1263
- result = response.json()
1264
- content = result.get("message", {}).get("content", "")
1265
-
1266
- # Parse JSON response
1267
- try:
1268
- json_data = json.loads(content)
1269
- except json.JSONDecodeError as json_err:
1270
- content_preview = content[:500] if content else "<empty>"
1271
- if content and len(content) > 700:
1272
- content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
1273
- logger.warning(
1274
- f"Ollama JSON parse error (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
1275
- f" Model: ollama/{self.model}\n"
1276
- f" Content length: {len(content) if content else 0} chars\n"
1277
- f" Content preview: {content_preview!r}"
1278
- )
1279
- if attempt < max_retries:
1280
- backoff = min(initial_backoff * (2**attempt), max_backoff)
1281
- await asyncio.sleep(backoff)
1282
- last_exception = json_err
1283
- continue
1284
- else:
1285
- raise
1286
-
1287
- # Extract token usage from Ollama response
1288
- # Ollama returns prompt_eval_count (input) and eval_count (output)
1289
- duration = time.time() - start_time
1290
- input_tokens = result.get("prompt_eval_count", 0) or 0
1291
- output_tokens = result.get("eval_count", 0) or 0
1292
- total_tokens = input_tokens + output_tokens
1293
-
1294
- # Record LLM metrics
1295
- metrics = get_metrics_collector()
1296
- metrics.record_llm_call(
1297
- provider=self.provider,
1298
- model=self.model,
1299
- scope=scope,
1300
- duration=duration,
1301
- input_tokens=input_tokens,
1302
- output_tokens=output_tokens,
1303
- success=True,
1304
- )
1305
-
1306
- # Validate against Pydantic model or return raw JSON
1307
- if skip_validation:
1308
- validated_result = json_data
1309
- else:
1310
- validated_result = response_format.model_validate(json_data)
1311
-
1312
- if return_usage:
1313
- token_usage = TokenUsage(
1314
- input_tokens=input_tokens,
1315
- output_tokens=output_tokens,
1316
- total_tokens=total_tokens,
1317
- )
1318
- return validated_result, token_usage
1319
- return validated_result
1320
-
1321
- except httpx.HTTPStatusError as e:
1322
- last_exception = e
1323
- if attempt < max_retries:
1324
- logger.warning(
1325
- f"Ollama HTTP error (attempt {attempt + 1}/{max_retries + 1}): {e.response.status_code}"
1326
- )
1327
- backoff = min(initial_backoff * (2**attempt), max_backoff)
1328
- await asyncio.sleep(backoff)
1329
- continue
1330
- else:
1331
- logger.error(f"Ollama HTTP error after {max_retries + 1} attempts: {e}")
1332
- raise
1333
-
1334
- except httpx.RequestError as e:
1335
- last_exception = e
1336
- if attempt < max_retries:
1337
- logger.warning(f"Ollama connection error (attempt {attempt + 1}/{max_retries + 1}): {e}")
1338
- backoff = min(initial_backoff * (2**attempt), max_backoff)
1339
- await asyncio.sleep(backoff)
1340
- continue
1341
- else:
1342
- logger.error(f"Ollama connection error after {max_retries + 1} attempts: {e}")
1343
- raise
1344
-
1345
- except Exception as e:
1346
- logger.error(f"Unexpected error during Ollama call: {type(e).__name__}: {e}")
- raise
-
- if last_exception:
- raise last_exception
- raise RuntimeError("Ollama call failed after all retries")
-
- async def _call_gemini(
- self,
- messages: list[dict[str, str]],
- response_format: Any | None,
- max_retries: int,
- initial_backoff: float,
- max_backoff: float,
- skip_validation: bool,
- start_time: float,
- scope: str = "memory",
- return_usage: bool = False,
- semaphore_wait_time: float = 0.0,
- ) -> Any:
- """Handle Gemini-specific API calls."""
- # Convert OpenAI-style messages to Gemini format
- system_instruction = None
- gemini_contents = []
-
- for msg in messages:
- role = msg.get("role", "user")
- content = msg.get("content", "")
-
- if role == "system":
- if system_instruction:
- system_instruction += "\n\n" + content
- else:
- system_instruction = content
- elif role == "assistant":
- gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
- else:
- gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
-
- # Add JSON schema instruction if response_format is provided
- if response_format is not None and hasattr(response_format, "model_json_schema"):
- schema = response_format.model_json_schema()
- schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
- if system_instruction:
- system_instruction += schema_msg
- else:
- system_instruction = schema_msg
-
- # Build generation config
- config_kwargs = {}
- if system_instruction:
- config_kwargs["system_instruction"] = system_instruction
- if response_format is not None:
- config_kwargs["response_mime_type"] = "application/json"
- config_kwargs["response_schema"] = response_format
-
- generation_config = genai_types.GenerateContentConfig(**config_kwargs) if config_kwargs else None
-
- last_exception = None
-
- for attempt in range(max_retries + 1):
- try:
- response = await self._gemini_client.aio.models.generate_content(
- model=self.model,
- contents=gemini_contents,
- config=generation_config,
- )
-
- content = response.text
-
- # Handle empty response
- if content is None:
- block_reason = None
- if hasattr(response, "candidates") and response.candidates:
- candidate = response.candidates[0]
- if hasattr(candidate, "finish_reason"):
- block_reason = candidate.finish_reason
-
- if attempt < max_retries:
- logger.warning(f"Gemini returned empty response (reason: {block_reason}), retrying...")
- backoff = min(initial_backoff * (2**attempt), max_backoff)
- await asyncio.sleep(backoff)
- continue
- else:
- raise RuntimeError(f"Gemini returned empty response after {max_retries + 1} attempts")
-
- if response_format is not None:
- json_data = json.loads(content)
- if skip_validation:
- result = json_data
- else:
- result = response_format.model_validate(json_data)
- else:
- result = content
-
- # Record metrics and log slow calls
- duration = time.time() - start_time
- input_tokens = 0
- output_tokens = 0
- if hasattr(response, "usage_metadata") and response.usage_metadata:
- usage = response.usage_metadata
- input_tokens = usage.prompt_token_count or 0
- output_tokens = usage.candidates_token_count or 0
-
- # Record LLM metrics
- metrics = get_metrics_collector()
- metrics.record_llm_call(
- provider=self.provider,
- model=self.model,
- scope=scope,
- duration=duration,
- input_tokens=input_tokens,
- output_tokens=output_tokens,
- success=True,
- )
-
- # Log slow calls
- if duration > 10.0 and input_tokens > 0:
- wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
- logger.info(
- f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
- f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
- f"time={duration:.3f}s{wait_info}"
- )
+ Verify that Claude Agent SDK can be imported and is properly configured.

- if return_usage:
- token_usage = TokenUsage(
- input_tokens=input_tokens,
- output_tokens=output_tokens,
- total_tokens=input_tokens + output_tokens,
- )
- return result, token_usage
- return result
-
- except json.JSONDecodeError as e:
- last_exception = e
- if attempt < max_retries:
- logger.warning("Gemini returned invalid JSON, retrying...")
- backoff = min(initial_backoff * (2**attempt), max_backoff)
- await asyncio.sleep(backoff)
- continue
- else:
- logger.error(f"Gemini returned invalid JSON after {max_retries + 1} attempts")
- raise
-
- except genai_errors.APIError as e:
- # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
- if e.code in (401, 403):
- logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
- raise
-
- # Retry on retryable errors (rate limits, server errors, and other client errors like 400)
- if e.code in (400, 429, 500, 502, 503, 504) or (e.code and e.code >= 500):
- last_exception = e
- if attempt < max_retries:
- backoff = min(initial_backoff * (2**attempt), max_backoff)
- jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
- await asyncio.sleep(backoff + jitter)
- else:
- logger.error(f"Gemini API error after {max_retries + 1} attempts: {str(e)}")
- raise
- else:
- logger.error(f"Gemini API error: {type(e).__name__}: {str(e)}")
- raise
-
- except Exception as e:
- logger.error(f"Unexpected error during Gemini call: {type(e).__name__}: {str(e)}")
- raise
-
- if last_exception:
- raise last_exception
- raise RuntimeError("Gemini call failed after all retries")
-
- async def _call_mock(
- self,
- messages: list[dict[str, str]],
- response_format: Any | None,
- scope: str,
- return_usage: bool,
- ) -> Any:
+ Raises:
+ ImportError: If Claude Agent SDK is not installed.
+ RuntimeError: If Claude Code is not authenticated.
  """
- Handle mock provider calls for testing.
+ try:
+ # Import Claude Agent SDK
+ # Reduce Claude Agent SDK logging verbosity
+ import logging as sdk_logging

- Records the call and returns a configurable mock response.
- """
- # Record the call for test verification
- call_record = {
- "provider": self.provider,
- "model": self.model,
- "messages": messages,
- "response_format": response_format.__name__
- if response_format and hasattr(response_format, "__name__")
- else str(response_format),
- "scope": scope,
- }
- self._mock_calls.append(call_record)
- logger.debug(f"Mock LLM call recorded: scope={scope}, model={self.model}")
-
- # Return mock response
- if self._mock_response is not None:
- result = self._mock_response
- elif response_format is not None:
- # Try to create a minimal valid instance of the response format
- try:
- # For Pydantic models, try to create with minimal valid data
- result = {"mock": True}
- except Exception:
- result = {"mock": True}
- else:
- result = "mock response"
-
- if return_usage:
- token_usage = TokenUsage(input_tokens=10, output_tokens=5, total_tokens=15)
- return result, token_usage
- return result
+ from claude_agent_sdk import query # noqa: F401

- def set_mock_response(self, response: Any) -> None:
- """Set the response to return from mock calls."""
- self._mock_response = response
+ sdk_logging.getLogger("claude_agent_sdk").setLevel(sdk_logging.WARNING)
+ sdk_logging.getLogger("claude_agent_sdk._internal").setLevel(sdk_logging.WARNING)

- def get_mock_calls(self) -> list[dict]:
- """Get the list of recorded mock calls."""
- return self._mock_calls
+ logger.debug("Claude Agent SDK imported successfully")
+ except ImportError as e:
+ raise ImportError(
+ "Claude Agent SDK not installed. Run: uv add claude-agent-sdk or pip install claude-agent-sdk"
+ ) from e

- def clear_mock_calls(self) -> None:
- """Clear the recorded mock calls."""
- self._mock_calls = []
+ # SDK will automatically check for authentication when first used
+ # No need to verify here - let it fail gracefully on first call with helpful error

  async def cleanup(self) -> None:
  """Clean up resources."""
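For readers skimming the removed _call_gemini path above: its retry loop used capped exponential backoff with roughly plus or minus 20% jitter derived from the wall clock. The standalone sketch below restates that schedule outside the class; retry_with_backoff and its call_once argument are illustrative names written for this description, not part of the package.

    import asyncio
    import time

    async def retry_with_backoff(call_once, max_retries=3, initial_backoff=1.0, max_backoff=30.0):
        """Sketch of the backoff schedule used by the removed Gemini retry loop."""
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                return await call_once()
            except Exception as e:  # the removed code narrowed this to specific API errors
                last_exception = e
                if attempt < max_retries:
                    # Cap the exponential term, then add jitter of up to +/-20% of the backoff.
                    backoff = min(initial_backoff * (2 ** attempt), max_backoff)
                    jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
                    await asyncio.sleep(backoff + jitter)
        raise last_exception

The jitter term spreads concurrent retries over time so they do not all hit the API at the same instant.
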
@@ -1579,9 +543,14 @@ class LLMProvider:
  def for_memory(cls) -> "LLMProvider":
  """Create provider for memory operations from environment variables."""
  provider = os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq")
- api_key = os.getenv("HINDSIGHT_API_LLM_API_KEY")
- if not api_key:
- raise ValueError("HINDSIGHT_API_LLM_API_KEY environment variable is required")
+ api_key = os.getenv("HINDSIGHT_API_LLM_API_KEY", "")
+
+ # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+ if not api_key and provider not in ("openai-codex", "claude-code"):
+ raise ValueError(
+ "HINDSIGHT_API_LLM_API_KEY environment variable is required (unless using openai-codex or claude-code)"
+ )
+
  base_url = os.getenv("HINDSIGHT_API_LLM_BASE_URL", "")
  model = os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b")
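After this hunk, for_memory() treats the API key as optional for OAuth-backed providers and mandatory for everything else. A minimal configuration sketch, assuming LLMProvider has been imported from this package (exact import path not shown here) and using placeholder values:

    import os

    # Key-based provider (e.g. the default "groq"): the key is still required.
    os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "groq"
    os.environ["HINDSIGHT_API_LLM_API_KEY"] = "sk-placeholder"  # illustrative value
    os.environ["HINDSIGHT_API_LLM_MODEL"] = "openai/gpt-oss-120b"
    memory_llm = LLMProvider.for_memory()

    # OAuth-based provider: no key variable needed as of this change
    # (assumes the Claude Agent SDK is installed and authenticated).
    os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "claude-code"
    os.environ.pop("HINDSIGHT_API_LLM_API_KEY", None)
    memory_llm = LLMProvider.for_memory()

With a key-based provider and no key set, the factory raises the ValueError shown above.
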
 
@@ -1591,11 +560,15 @@ class LLMProvider:
  def for_answer_generation(cls) -> "LLMProvider":
  """Create provider for answer generation. Falls back to memory config if not set."""
  provider = os.getenv("HINDSIGHT_API_ANSWER_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
- api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY"))
- if not api_key:
+ api_key = os.getenv("HINDSIGHT_API_ANSWER_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY", ""))
+
+ # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+ if not api_key and provider not in ("openai-codex", "claude-code"):
  raise ValueError(
- "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_ANSWER_LLM_API_KEY environment variable is required"
+ "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_ANSWER_LLM_API_KEY environment variable is required "
+ "(unless using openai-codex or claude-code)"
  )
+
  base_url = os.getenv("HINDSIGHT_API_ANSWER_LLM_BASE_URL", os.getenv("HINDSIGHT_API_LLM_BASE_URL", ""))
  model = os.getenv("HINDSIGHT_API_ANSWER_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))
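Because each answer-generation variable falls back to its memory-scoped counterpart, a deployment can override a single setting, for example the model, and inherit the rest. A sketch under the same assumptions as above (model names are illustrative):

    import os

    # Base configuration shared with for_memory().
    os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "groq"
    os.environ["HINDSIGHT_API_LLM_API_KEY"] = "sk-placeholder"  # illustrative value
    os.environ["HINDSIGHT_API_LLM_MODEL"] = "openai/gpt-oss-120b"

    # Only the answer model is overridden; provider, key, and base URL fall back to the values above.
    os.environ["HINDSIGHT_API_ANSWER_LLM_MODEL"] = "openai/gpt-oss-20b"

    answer_llm = LLMProvider.for_answer_generation()
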
 
@@ -1605,11 +578,15 @@ class LLMProvider:
  def for_judge(cls) -> "LLMProvider":
  """Create provider for judge/evaluator operations. Falls back to memory config if not set."""
  provider = os.getenv("HINDSIGHT_API_JUDGE_LLM_PROVIDER", os.getenv("HINDSIGHT_API_LLM_PROVIDER", "groq"))
- api_key = os.getenv("HINDSIGHT_API_JUDGE_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY"))
- if not api_key:
+ api_key = os.getenv("HINDSIGHT_API_JUDGE_LLM_API_KEY", os.getenv("HINDSIGHT_API_LLM_API_KEY", ""))
+
+ # API key not needed for openai-codex (uses OAuth) or claude-code (uses Keychain OAuth)
+ if not api_key and provider not in ("openai-codex", "claude-code"):
  raise ValueError(
- "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_JUDGE_LLM_API_KEY environment variable is required"
+ "HINDSIGHT_API_LLM_API_KEY or HINDSIGHT_API_JUDGE_LLM_API_KEY environment variable is required "
+ "(unless using openai-codex or claude-code)"
  )
+
  base_url = os.getenv("HINDSIGHT_API_JUDGE_LLM_BASE_URL", os.getenv("HINDSIGHT_API_LLM_BASE_URL", ""))
  model = os.getenv("HINDSIGHT_API_JUDGE_LLM_MODEL", os.getenv("HINDSIGHT_API_LLM_MODEL", "openai/gpt-oss-120b"))
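The same key-requirement rule now appears in for_memory, for_answer_generation, and for_judge. Restated as a standalone predicate (a hypothetical helper written for this description, not code from the package):

    OAUTH_PROVIDERS = ("openai-codex", "claude-code")

    def require_api_key(provider: str, api_key: str, var_hint: str) -> None:
        """Raise if a key-based provider is configured without an API key."""
        if not api_key and provider not in OAUTH_PROVIDERS:
            raise ValueError(
                f"{var_hint} environment variable is required (unless using openai-codex or claude-code)"
            )

OAuth-backed providers authenticate out of band (Codex OAuth, Claude Code Keychain OAuth), which is why an empty key is acceptable for them.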