hindsight-api 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +1 -1
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +16 -2
- hindsight_api/api/http.py +83 -1
- hindsight_api/banner.py +3 -0
- hindsight_api/config.py +44 -6
- hindsight_api/daemon.py +18 -112
- hindsight_api/engine/llm_interface.py +146 -0
- hindsight_api/engine/llm_wrapper.py +304 -1327
- hindsight_api/engine/memory_engine.py +125 -41
- hindsight_api/engine/providers/__init__.py +14 -0
- hindsight_api/engine/providers/anthropic_llm.py +434 -0
- hindsight_api/engine/providers/claude_code_llm.py +352 -0
- hindsight_api/engine/providers/codex_llm.py +527 -0
- hindsight_api/engine/providers/gemini_llm.py +502 -0
- hindsight_api/engine/providers/mock_llm.py +234 -0
- hindsight_api/engine/providers/openai_compatible_llm.py +745 -0
- hindsight_api/engine/retain/fact_extraction.py +13 -9
- hindsight_api/engine/retain/fact_storage.py +5 -3
- hindsight_api/extensions/__init__.py +10 -0
- hindsight_api/extensions/builtin/tenant.py +36 -0
- hindsight_api/extensions/operation_validator.py +129 -0
- hindsight_api/main.py +6 -21
- hindsight_api/migrations.py +75 -0
- hindsight_api/worker/main.py +41 -11
- hindsight_api/worker/poller.py +26 -14
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/METADATA +2 -1
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/RECORD +29 -21
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/WHEEL +0 -0
- {hindsight_api-0.4.6.dist-info → hindsight_api-0.4.8.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,745 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI-compatible LLM provider supporting OpenAI, Groq, Ollama, and LMStudio.
|
|
3
|
+
|
|
4
|
+
This provider handles all OpenAI API-compatible models including:
|
|
5
|
+
- OpenAI: GPT-4, GPT-4o, GPT-5, o1, o3 (reasoning models)
|
|
6
|
+
- Groq: Fast inference with seed control and service tiers
|
|
7
|
+
- Ollama: Local models with native streaming API support
|
|
8
|
+
- LMStudio: Local models with OpenAI-compatible API
|
|
9
|
+
|
|
10
|
+
Features:
|
|
11
|
+
- Reasoning models with extended thinking (o1, o3, GPT-5 families)
|
|
12
|
+
- Strict JSON schema enforcement (OpenAI)
|
|
13
|
+
- Provider-specific parameters (Groq seed, service tier)
|
|
14
|
+
- Native Ollama streaming for better structured output
|
|
15
|
+
- Automatic token limit handling per model family
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
import os
|
|
22
|
+
import re
|
|
23
|
+
import time
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
import httpx
|
|
27
|
+
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
|
|
28
|
+
|
|
29
|
+
from hindsight_api.config import DEFAULT_LLM_TIMEOUT, ENV_LLM_TIMEOUT
|
|
30
|
+
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
|
|
31
|
+
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
|
|
32
|
+
from hindsight_api.metrics import get_metrics_collector
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
# Seed applied to every Groq request for deterministic behavior
|
|
37
|
+
DEFAULT_LLM_SEED = 4242
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class OpenAICompatibleLLM(LLMInterface):
    """
    LLM provider for OpenAI-compatible APIs.

    Supports:
    - OpenAI: Standard models (GPT-4, GPT-4o) and reasoning models (o1, o3, GPT-5)
    - Groq: Fast inference with seed control and service tiers
    - Ollama: Local models with native streaming API for better structured output
    - LMStudio: Local models with OpenAI-compatible API

    All requests are sent through a single AsyncOpenAI client created in
    ``__init__``; the one exception is Ollama structured-output calls, which
    are routed to Ollama's native ``/api/chat`` endpoint for stricter JSON
    schema enforcement (see ``_call_ollama_native``).
    """
|
|
50
|
+
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
provider: str,
|
|
54
|
+
api_key: str,
|
|
55
|
+
base_url: str,
|
|
56
|
+
model: str,
|
|
57
|
+
reasoning_effort: str = "low",
|
|
58
|
+
timeout: float | None = None,
|
|
59
|
+
groq_service_tier: str | None = None,
|
|
60
|
+
**kwargs: Any,
|
|
61
|
+
):
|
|
62
|
+
"""
|
|
63
|
+
Initialize OpenAI-compatible LLM provider.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
provider: Provider name ("openai", "groq", "ollama", "lmstudio").
|
|
67
|
+
api_key: API key (optional for ollama/lmstudio).
|
|
68
|
+
base_url: Base URL for the API (uses defaults for groq/ollama/lmstudio if empty).
|
|
69
|
+
model: Model name.
|
|
70
|
+
reasoning_effort: Reasoning effort level for supported models ("low", "medium", "high").
|
|
71
|
+
timeout: Request timeout in seconds (uses env var or 300s default).
|
|
72
|
+
groq_service_tier: Groq service tier ("on_demand", "flex", "auto").
|
|
73
|
+
**kwargs: Additional provider-specific parameters.
|
|
74
|
+
"""
|
|
75
|
+
super().__init__(provider, api_key, base_url, model, reasoning_effort, **kwargs)
|
|
76
|
+
|
|
77
|
+
# Validate provider
|
|
78
|
+
valid_providers = ["openai", "groq", "ollama", "lmstudio"]
|
|
79
|
+
if self.provider not in valid_providers:
|
|
80
|
+
raise ValueError(f"OpenAICompatibleLLM only supports: {', '.join(valid_providers)}. Got: {self.provider}")
|
|
81
|
+
|
|
82
|
+
# Set default base URLs
|
|
83
|
+
if not self.base_url:
|
|
84
|
+
if self.provider == "groq":
|
|
85
|
+
self.base_url = "https://api.groq.com/openai/v1"
|
|
86
|
+
elif self.provider == "ollama":
|
|
87
|
+
self.base_url = "http://localhost:11434/v1"
|
|
88
|
+
elif self.provider == "lmstudio":
|
|
89
|
+
self.base_url = "http://localhost:1234/v1"
|
|
90
|
+
|
|
91
|
+
# For ollama/lmstudio, use dummy key if not provided
|
|
92
|
+
if self.provider in ("ollama", "lmstudio") and not self.api_key:
|
|
93
|
+
self.api_key = "local"
|
|
94
|
+
|
|
95
|
+
# Validate API key for cloud providers
|
|
96
|
+
if self.provider in ("openai", "groq") and not self.api_key:
|
|
97
|
+
raise ValueError(f"API key is required for {self.provider}")
|
|
98
|
+
|
|
99
|
+
# Groq service tier configuration
|
|
100
|
+
self.groq_service_tier = groq_service_tier or os.getenv("HINDSIGHT_API_LLM_GROQ_SERVICE_TIER", "auto")
|
|
101
|
+
|
|
102
|
+
# Get timeout config
|
|
103
|
+
self.timeout = timeout or float(os.getenv(ENV_LLM_TIMEOUT, str(DEFAULT_LLM_TIMEOUT)))
|
|
104
|
+
|
|
105
|
+
# Create OpenAI client
|
|
106
|
+
client_kwargs: dict[str, Any] = {"api_key": self.api_key, "max_retries": 0}
|
|
107
|
+
if self.base_url:
|
|
108
|
+
client_kwargs["base_url"] = self.base_url
|
|
109
|
+
if self.timeout:
|
|
110
|
+
client_kwargs["timeout"] = self.timeout
|
|
111
|
+
|
|
112
|
+
self._client = AsyncOpenAI(**client_kwargs)
|
|
113
|
+
logger.info(
|
|
114
|
+
f"OpenAI-compatible client initialized: provider={self.provider}, model={self.model}, "
|
|
115
|
+
f"base_url={self.base_url or 'default'}"
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
async def verify_connection(self) -> None:
|
|
119
|
+
"""
|
|
120
|
+
Verify that the provider is configured correctly by making a simple test call.
|
|
121
|
+
|
|
122
|
+
Raises:
|
|
123
|
+
RuntimeError: If the connection test fails.
|
|
124
|
+
"""
|
|
125
|
+
try:
|
|
126
|
+
logger.info(f"Verifying connection: {self.provider}/{self.model}")
|
|
127
|
+
await self.call(
|
|
128
|
+
messages=[{"role": "user", "content": "Say 'ok'"}],
|
|
129
|
+
max_completion_tokens=100,
|
|
130
|
+
max_retries=2,
|
|
131
|
+
initial_backoff=0.5,
|
|
132
|
+
max_backoff=2.0,
|
|
133
|
+
)
|
|
134
|
+
logger.info(f"Connection verified: {self.provider}/{self.model}")
|
|
135
|
+
except Exception as e:
|
|
136
|
+
raise RuntimeError(f"Connection verification failed for {self.provider}/{self.model}: {e}") from e
|
|
137
|
+
|
|
138
|
+
def _supports_reasoning_model(self) -> bool:
|
|
139
|
+
"""Check if the current model is a reasoning model (o1, o3, GPT-5, DeepSeek)."""
|
|
140
|
+
model_lower = self.model.lower()
|
|
141
|
+
return any(x in model_lower for x in ["gpt-5", "o1", "o3", "deepseek"])
|
|
142
|
+
|
|
143
|
+
def _get_max_reasoning_tokens(self) -> int | None:
|
|
144
|
+
"""Get max reasoning tokens for reasoning models."""
|
|
145
|
+
model_lower = self.model.lower()
|
|
146
|
+
|
|
147
|
+
# GPT-4 and GPT-4.1 models have different caps
|
|
148
|
+
if any(x in model_lower for x in ["gpt-4.1", "gpt-4-"]):
|
|
149
|
+
return 32000
|
|
150
|
+
elif "gpt-4o" in model_lower:
|
|
151
|
+
return 16384
|
|
152
|
+
|
|
153
|
+
return None
|
|
154
|
+
|
|
155
|
+
    async def call(
        self,
        messages: list[dict[str, str]],
        response_format: Any | None = None,
        max_completion_tokens: int | None = None,
        temperature: float | None = None,
        scope: str = "memory",
        max_retries: int = 10,
        initial_backoff: float = 1.0,
        max_backoff: float = 60.0,
        skip_validation: bool = False,
        strict_schema: bool = False,
        return_usage: bool = False,
    ) -> Any:
        """
        Make an LLM API call with retry logic.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            response_format: Optional Pydantic model for structured output.
            max_completion_tokens: Maximum tokens in response.
            temperature: Sampling temperature (0.0-2.0).
            scope: Scope identifier for tracking.
            max_retries: Maximum retry attempts.
            initial_backoff: Initial backoff time in seconds.
            max_backoff: Maximum backoff time in seconds.
            skip_validation: Return raw JSON without Pydantic validation.
            strict_schema: Use strict JSON schema enforcement (OpenAI only).
            return_usage: If True, return tuple (result, TokenUsage) instead of just result.

        Returns:
            If return_usage=False: Parsed response if response_format is provided, otherwise text content.
            If return_usage=True: Tuple of (result, TokenUsage) with token counts.

        Raises:
            OutputTooLongError: If output exceeds token limits.
            Exception: Re-raises API errors after retries exhausted.
        """
        # Handle Ollama with native API for structured output (better schema enforcement)
        if self.provider == "ollama" and response_format is not None:
            return await self._call_ollama_native(
                messages=messages,
                response_format=response_format,
                max_completion_tokens=max_completion_tokens,
                temperature=temperature,
                max_retries=max_retries,
                initial_backoff=initial_backoff,
                max_backoff=max_backoff,
                skip_validation=skip_validation,
                scope=scope,
                return_usage=return_usage,
            )

        start_time = time.time()

        # Build call parameters
        call_params: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }

        # Check if model supports reasoning parameter
        is_reasoning_model = self._supports_reasoning_model()

        # Apply model-specific token limits
        # (cap from _get_max_reasoning_tokens applies to GPT-4-family models)
        if max_completion_tokens is not None:
            max_tokens_cap = self._get_max_reasoning_tokens()
            if max_tokens_cap and max_completion_tokens > max_tokens_cap:
                max_completion_tokens = max_tokens_cap
            # For reasoning models, enforce minimum to ensure space for reasoning + output
            if is_reasoning_model and max_completion_tokens < 16000:
                max_completion_tokens = 16000
            call_params["max_completion_tokens"] = max_completion_tokens

        # Temperature - reasoning models don't support custom temperature
        if temperature is not None and not is_reasoning_model:
            call_params["temperature"] = temperature

        # Set reasoning_effort for reasoning models
        if is_reasoning_model:
            call_params["reasoning_effort"] = self.reasoning_effort

        # Provider-specific parameters
        if self.provider == "groq":
            call_params["seed"] = DEFAULT_LLM_SEED
            extra_body: dict[str, Any] = {}
            # Add service_tier if configured
            if self.groq_service_tier:
                extra_body["service_tier"] = self.groq_service_tier
            # Add reasoning parameters for reasoning models
            if is_reasoning_model:
                extra_body["include_reasoning"] = False
            if extra_body:
                call_params["extra_body"] = extra_body

        # Prepare response format ONCE before retry loop
        # (note: the soft-enforcement branch mutates the first message in
        # call_params["messages"] in place to append the schema)
        if response_format is not None:
            schema = None
            if hasattr(response_format, "model_json_schema"):
                schema = response_format.model_json_schema()

            if strict_schema and schema is not None:
                # Use OpenAI's strict JSON schema enforcement
                call_params["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "response",
                        "strict": True,
                        "schema": schema,
                    },
                }
            else:
                # Soft enforcement: add schema to prompt and use json_object mode
                if schema is not None:
                    schema_msg = (
                        f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
                    )

                    if call_params["messages"] and call_params["messages"][0].get("role") == "system":
                        first_msg = call_params["messages"][0]
                        if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
                            first_msg["content"] += schema_msg
                    elif call_params["messages"]:
                        first_msg = call_params["messages"][0]
                        if isinstance(first_msg, dict) and isinstance(first_msg.get("content"), str):
                            first_msg["content"] = schema_msg + "\n\n" + first_msg["content"]
                if self.provider not in ("lmstudio", "ollama"):
                    # LM Studio and Ollama don't support json_object response format reliably
                    call_params["response_format"] = {"type": "json_object"}

        last_exception = None

        for attempt in range(max_retries + 1):
            try:
                if response_format is not None:
                    response = await self._client.chat.completions.create(**call_params)

                    content = response.choices[0].message.content

                    # Strip reasoning model thinking tags
                    # Supports: <think>, <thinking>, <reasoning>, |startthink|/|endthink|
                    if content:
                        original_len = len(content)
                        content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
                        content = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL)
                        content = re.sub(r"<reasoning>.*?</reasoning>", "", content, flags=re.DOTALL)
                        content = re.sub(r"\|startthink\|.*?\|endthink\|", "", content, flags=re.DOTALL)
                        content = content.strip()
                        if len(content) < original_len:
                            logger.debug(f"Stripped {original_len - len(content)} chars of reasoning tokens")

                    # For local models, they may wrap JSON in markdown code blocks
                    if self.provider in ("lmstudio", "ollama"):
                        clean_content = content
                        if "```json" in content:
                            clean_content = content.split("```json")[1].split("```")[0].strip()
                        elif "```" in content:
                            clean_content = content.split("```")[1].split("```")[0].strip()
                        try:
                            json_data = json.loads(clean_content)
                        except json.JSONDecodeError:
                            # Fallback to parsing raw content
                            # NOTE(review): if this fallback parse also fails, the
                            # JSONDecodeError propagates without the retry/backoff
                            # the cloud-provider branch below gets — confirm intended.
                            json_data = json.loads(content)
                    else:
                        # Log raw LLM response for debugging JSON parse issues
                        try:
                            json_data = json.loads(content)
                        except json.JSONDecodeError as json_err:
                            # Truncate content for logging
                            content_preview = content[:500] if content else "<empty>"
                            if content and len(content) > 700:
                                content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
                            logger.warning(
                                f"JSON parse error from LLM response (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
                                f" Model: {self.provider}/{self.model}\n"
                                f" Content length: {len(content) if content else 0} chars\n"
                                f" Content preview: {content_preview!r}\n"
                                f" Finish reason: {response.choices[0].finish_reason if response.choices else 'unknown'}"
                            )
                            # Retry on JSON parse errors
                            if attempt < max_retries:
                                backoff = min(initial_backoff * (2**attempt), max_backoff)
                                await asyncio.sleep(backoff)
                                last_exception = json_err
                                continue
                            else:
                                logger.error(f"JSON parse error after {max_retries + 1} attempts, giving up")
                                raise

                    if skip_validation:
                        result = json_data
                    else:
                        result = response_format.model_validate(json_data)
                else:
                    # Plain text completion: no parsing, return content as-is.
                    response = await self._client.chat.completions.create(**call_params)
                    result = response.choices[0].message.content

                # Record token usage metrics
                duration = time.time() - start_time
                usage = response.usage
                input_tokens = usage.prompt_tokens or 0 if usage else 0
                output_tokens = usage.completion_tokens or 0 if usage else 0
                total_tokens = usage.total_tokens or 0 if usage else 0

                # Record LLM metrics
                metrics = get_metrics_collector()
                metrics.record_llm_call(
                    provider=self.provider,
                    model=self.model,
                    scope=scope,
                    duration=duration,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    success=True,
                )

                # Log slow calls
                if duration > 10.0 and usage:
                    ratio = max(1, output_tokens) / max(1, input_tokens)
                    cached_tokens = 0
                    if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
                        cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
                    cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
                    logger.info(
                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
                        f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
                        f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s, ratio out/in={ratio:.2f}"
                    )

                if return_usage:
                    token_usage = TokenUsage(
                        input_tokens=input_tokens,
                        output_tokens=output_tokens,
                        total_tokens=total_tokens,
                    )
                    return result, token_usage
                return result

            except LengthFinishReasonError as e:
                # Output hit the token ceiling: not retryable, surface a typed error.
                logger.warning(f"LLM output exceeded token limits: {str(e)}")
                raise OutputTooLongError(
                    "LLM output exceeded token limits. Input may need to be split into smaller chunks."
                ) from e

            except APIConnectionError as e:
                last_exception = e
                status_code = getattr(e, "status_code", None) or getattr(
                    getattr(e, "response", None), "status_code", None
                )
                logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
                if attempt < max_retries:
                    backoff = min(initial_backoff * (2**attempt), max_backoff)
                    await asyncio.sleep(backoff)
                    continue
                else:
                    logger.error(f"Connection error after {max_retries + 1} attempts: {str(e)}")
                    raise

            except APIStatusError as e:
                # Fast fail only on 401 (unauthorized) and 403 (forbidden)
                if e.status_code in (401, 403):
                    logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                    raise

                # Handle tool_use_failed error - model outputted in tool call format
                if e.status_code == 400 and response_format is not None:
                    try:
                        error_body = e.body if hasattr(e, "body") else {}
                        if isinstance(error_body, dict):
                            error_info: dict[str, Any] = error_body.get("error") or {}
                            if error_info.get("code") == "tool_use_failed":
                                failed_gen = error_info.get("failed_generation", "")
                                if failed_gen:
                                    # Parse tool call format and convert to expected format
                                    tool_call = json.loads(failed_gen)
                                    tool_name = tool_call.get("name", "")
                                    tool_args = tool_call.get("arguments", {})
                                    converted = {"actions": [{"tool": tool_name, **tool_args}]}
                                    if skip_validation:
                                        result = converted
                                    else:
                                        result = response_format.model_validate(converted)

                                    # Record metrics
                                    # (token counts are unknown in this recovery
                                    # path, so zeros are recorded)
                                    duration = time.time() - start_time
                                    metrics = get_metrics_collector()
                                    metrics.record_llm_call(
                                        provider=self.provider,
                                        model=self.model,
                                        scope=scope,
                                        duration=duration,
                                        input_tokens=0,
                                        output_tokens=0,
                                        success=True,
                                    )
                                    if return_usage:
                                        return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
                                    return result
                    except (json.JSONDecodeError, KeyError, TypeError):
                        pass  # Failed to parse tool_use_failed, continue with normal retry

                last_exception = e
                if attempt < max_retries:
                    backoff = min(initial_backoff * (2**attempt), max_backoff)
                    # Jitter in [-20%, +20%] derived from the wall clock
                    jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
                    sleep_time = backoff + jitter
                    await asyncio.sleep(sleep_time)
                else:
                    logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
                    raise

            except Exception:
                # NOTE(review): bare re-raise — any non-API error (including
                # Pydantic validation failures) propagates immediately with no retry.
                raise

        if last_exception:
            raise last_exception
        raise RuntimeError("LLM call failed after all retries with no exception captured")
|
|
472
|
+
|
|
473
|
+
async def call_with_tools(
|
|
474
|
+
self,
|
|
475
|
+
messages: list[dict[str, Any]],
|
|
476
|
+
tools: list[dict[str, Any]],
|
|
477
|
+
max_completion_tokens: int | None = None,
|
|
478
|
+
temperature: float | None = None,
|
|
479
|
+
scope: str = "tools",
|
|
480
|
+
max_retries: int = 5,
|
|
481
|
+
initial_backoff: float = 1.0,
|
|
482
|
+
max_backoff: float = 30.0,
|
|
483
|
+
tool_choice: str | dict[str, Any] = "auto",
|
|
484
|
+
) -> LLMToolCallResult:
|
|
485
|
+
"""
|
|
486
|
+
Make an LLM API call with tool/function calling support.
|
|
487
|
+
|
|
488
|
+
Args:
|
|
489
|
+
messages: List of message dicts. Can include tool results with role='tool'.
|
|
490
|
+
tools: List of tool definitions in OpenAI format.
|
|
491
|
+
max_completion_tokens: Maximum tokens in response.
|
|
492
|
+
temperature: Sampling temperature (0.0-2.0).
|
|
493
|
+
scope: Scope identifier for tracking.
|
|
494
|
+
max_retries: Maximum retry attempts.
|
|
495
|
+
initial_backoff: Initial backoff time in seconds.
|
|
496
|
+
max_backoff: Maximum backoff time in seconds.
|
|
497
|
+
tool_choice: How to choose tools - "auto", "none", "required", or specific function.
|
|
498
|
+
|
|
499
|
+
Returns:
|
|
500
|
+
LLMToolCallResult with content and/or tool_calls.
|
|
501
|
+
"""
|
|
502
|
+
start_time = time.time()
|
|
503
|
+
|
|
504
|
+
# Build call parameters
|
|
505
|
+
call_params: dict[str, Any] = {
|
|
506
|
+
"model": self.model,
|
|
507
|
+
"messages": messages,
|
|
508
|
+
"tools": tools,
|
|
509
|
+
"tool_choice": tool_choice,
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
if max_completion_tokens is not None:
|
|
513
|
+
call_params["max_completion_tokens"] = max_completion_tokens
|
|
514
|
+
if temperature is not None:
|
|
515
|
+
call_params["temperature"] = temperature
|
|
516
|
+
|
|
517
|
+
# Provider-specific parameters
|
|
518
|
+
if self.provider == "groq":
|
|
519
|
+
call_params["seed"] = DEFAULT_LLM_SEED
|
|
520
|
+
|
|
521
|
+
last_exception = None
|
|
522
|
+
|
|
523
|
+
for attempt in range(max_retries + 1):
|
|
524
|
+
try:
|
|
525
|
+
response = await self._client.chat.completions.create(**call_params)
|
|
526
|
+
|
|
527
|
+
message = response.choices[0].message
|
|
528
|
+
finish_reason = response.choices[0].finish_reason
|
|
529
|
+
|
|
530
|
+
# Extract tool calls if present
|
|
531
|
+
tool_calls: list[LLMToolCall] = []
|
|
532
|
+
if message.tool_calls:
|
|
533
|
+
for tc in message.tool_calls:
|
|
534
|
+
try:
|
|
535
|
+
args = json.loads(tc.function.arguments) if tc.function.arguments else {}
|
|
536
|
+
except json.JSONDecodeError:
|
|
537
|
+
args = {"_raw": tc.function.arguments}
|
|
538
|
+
tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
|
|
539
|
+
|
|
540
|
+
content = message.content
|
|
541
|
+
|
|
542
|
+
# Record metrics
|
|
543
|
+
duration = time.time() - start_time
|
|
544
|
+
usage = response.usage
|
|
545
|
+
input_tokens = usage.prompt_tokens or 0 if usage else 0
|
|
546
|
+
output_tokens = usage.completion_tokens or 0 if usage else 0
|
|
547
|
+
|
|
548
|
+
metrics = get_metrics_collector()
|
|
549
|
+
metrics.record_llm_call(
|
|
550
|
+
provider=self.provider,
|
|
551
|
+
model=self.model,
|
|
552
|
+
scope=scope,
|
|
553
|
+
duration=duration,
|
|
554
|
+
input_tokens=input_tokens,
|
|
555
|
+
output_tokens=output_tokens,
|
|
556
|
+
success=True,
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
return LLMToolCallResult(
|
|
560
|
+
content=content,
|
|
561
|
+
tool_calls=tool_calls,
|
|
562
|
+
finish_reason=finish_reason,
|
|
563
|
+
input_tokens=input_tokens,
|
|
564
|
+
output_tokens=output_tokens,
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
except APIConnectionError as e:
|
|
568
|
+
last_exception = e
|
|
569
|
+
if attempt < max_retries:
|
|
570
|
+
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
|
|
571
|
+
continue
|
|
572
|
+
raise
|
|
573
|
+
|
|
574
|
+
except APIStatusError as e:
|
|
575
|
+
if e.status_code in (401, 403):
|
|
576
|
+
raise
|
|
577
|
+
last_exception = e
|
|
578
|
+
if attempt < max_retries:
|
|
579
|
+
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
|
|
580
|
+
continue
|
|
581
|
+
raise
|
|
582
|
+
|
|
583
|
+
except Exception:
|
|
584
|
+
raise
|
|
585
|
+
|
|
586
|
+
if last_exception:
|
|
587
|
+
raise last_exception
|
|
588
|
+
raise RuntimeError("Tool call failed after all retries")
|
|
589
|
+
|
|
590
|
+
async def _call_ollama_native(
|
|
591
|
+
self,
|
|
592
|
+
messages: list[dict[str, str]],
|
|
593
|
+
response_format: Any,
|
|
594
|
+
max_completion_tokens: int | None,
|
|
595
|
+
temperature: float | None,
|
|
596
|
+
max_retries: int,
|
|
597
|
+
initial_backoff: float,
|
|
598
|
+
max_backoff: float,
|
|
599
|
+
skip_validation: bool,
|
|
600
|
+
scope: str = "memory",
|
|
601
|
+
return_usage: bool = False,
|
|
602
|
+
) -> Any:
|
|
603
|
+
"""
|
|
604
|
+
Call Ollama using native API with JSON schema enforcement.
|
|
605
|
+
|
|
606
|
+
Ollama's native API supports passing a full JSON schema in the 'format' parameter,
|
|
607
|
+
which provides better structured output control than the OpenAI-compatible API.
|
|
608
|
+
"""
|
|
609
|
+
start_time = time.time()
|
|
610
|
+
|
|
611
|
+
# Get the JSON schema from the Pydantic model
|
|
612
|
+
schema = response_format.model_json_schema() if hasattr(response_format, "model_json_schema") else None
|
|
613
|
+
|
|
614
|
+
# Build the base URL for Ollama's native API
|
|
615
|
+
# Default OpenAI-compatible URL is http://localhost:11434/v1
|
|
616
|
+
# Native API is at http://localhost:11434/api/chat
|
|
617
|
+
base_url = self.base_url or "http://localhost:11434/v1"
|
|
618
|
+
if base_url.endswith("/v1"):
|
|
619
|
+
native_url = base_url[:-3] + "/api/chat"
|
|
620
|
+
else:
|
|
621
|
+
native_url = base_url.rstrip("/") + "/api/chat"
|
|
622
|
+
|
|
623
|
+
# Build request payload
|
|
624
|
+
payload: dict[str, Any] = {
|
|
625
|
+
"model": self.model,
|
|
626
|
+
"messages": messages,
|
|
627
|
+
"stream": False,
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
# Add schema as format parameter for structured output
|
|
631
|
+
if schema:
|
|
632
|
+
payload["format"] = schema
|
|
633
|
+
|
|
634
|
+
# Add optional parameters with optimized defaults for Ollama
|
|
635
|
+
options: dict[str, Any] = {
|
|
636
|
+
"num_ctx": 16384, # 16k context window for larger prompts
|
|
637
|
+
"num_batch": 512, # Optimal batch size for prompt processing
|
|
638
|
+
}
|
|
639
|
+
if max_completion_tokens:
|
|
640
|
+
options["num_predict"] = max_completion_tokens
|
|
641
|
+
if temperature is not None:
|
|
642
|
+
options["temperature"] = temperature
|
|
643
|
+
payload["options"] = options
|
|
644
|
+
|
|
645
|
+
last_exception = None
|
|
646
|
+
|
|
647
|
+
async with httpx.AsyncClient(timeout=300.0) as client:
|
|
648
|
+
for attempt in range(max_retries + 1):
|
|
649
|
+
try:
|
|
650
|
+
response = await client.post(native_url, json=payload)
|
|
651
|
+
response.raise_for_status()
|
|
652
|
+
|
|
653
|
+
result = response.json()
|
|
654
|
+
content = result.get("message", {}).get("content", "")
|
|
655
|
+
|
|
656
|
+
# Parse JSON response
|
|
657
|
+
try:
|
|
658
|
+
json_data = json.loads(content)
|
|
659
|
+
except json.JSONDecodeError as json_err:
|
|
660
|
+
content_preview = content[:500] if content else "<empty>"
|
|
661
|
+
if content and len(content) > 700:
|
|
662
|
+
content_preview = f"{content[:500]}...TRUNCATED...{content[-200:]}"
|
|
663
|
+
logger.warning(
|
|
664
|
+
f"Ollama JSON parse error (attempt {attempt + 1}/{max_retries + 1}): {json_err}\n"
|
|
665
|
+
f" Model: ollama/{self.model}\n"
|
|
666
|
+
f" Content length: {len(content) if content else 0} chars\n"
|
|
667
|
+
f" Content preview: {content_preview!r}"
|
|
668
|
+
)
|
|
669
|
+
if attempt < max_retries:
|
|
670
|
+
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
671
|
+
await asyncio.sleep(backoff)
|
|
672
|
+
last_exception = json_err
|
|
673
|
+
continue
|
|
674
|
+
else:
|
|
675
|
+
raise
|
|
676
|
+
|
|
677
|
+
# Extract token usage from Ollama response
|
|
678
|
+
duration = time.time() - start_time
|
|
679
|
+
input_tokens = result.get("prompt_eval_count", 0) or 0
|
|
680
|
+
output_tokens = result.get("eval_count", 0) or 0
|
|
681
|
+
total_tokens = input_tokens + output_tokens
|
|
682
|
+
|
|
683
|
+
# Record LLM metrics
|
|
684
|
+
metrics = get_metrics_collector()
|
|
685
|
+
metrics.record_llm_call(
|
|
686
|
+
provider=self.provider,
|
|
687
|
+
model=self.model,
|
|
688
|
+
scope=scope,
|
|
689
|
+
duration=duration,
|
|
690
|
+
input_tokens=input_tokens,
|
|
691
|
+
output_tokens=output_tokens,
|
|
692
|
+
success=True,
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
# Validate against Pydantic model or return raw JSON
|
|
696
|
+
if skip_validation:
|
|
697
|
+
validated_result = json_data
|
|
698
|
+
else:
|
|
699
|
+
validated_result = response_format.model_validate(json_data)
|
|
700
|
+
|
|
701
|
+
if return_usage:
|
|
702
|
+
token_usage = TokenUsage(
|
|
703
|
+
input_tokens=input_tokens,
|
|
704
|
+
output_tokens=output_tokens,
|
|
705
|
+
total_tokens=total_tokens,
|
|
706
|
+
)
|
|
707
|
+
return validated_result, token_usage
|
|
708
|
+
return validated_result
|
|
709
|
+
|
|
710
|
+
except httpx.HTTPStatusError as e:
|
|
711
|
+
last_exception = e
|
|
712
|
+
if attempt < max_retries:
|
|
713
|
+
logger.warning(
|
|
714
|
+
f"Ollama HTTP error (attempt {attempt + 1}/{max_retries + 1}): {e.response.status_code}"
|
|
715
|
+
)
|
|
716
|
+
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
717
|
+
await asyncio.sleep(backoff)
|
|
718
|
+
continue
|
|
719
|
+
else:
|
|
720
|
+
logger.error(f"Ollama HTTP error after {max_retries + 1} attempts: {e}")
|
|
721
|
+
raise
|
|
722
|
+
|
|
723
|
+
except httpx.RequestError as e:
|
|
724
|
+
last_exception = e
|
|
725
|
+
if attempt < max_retries:
|
|
726
|
+
logger.warning(f"Ollama connection error (attempt {attempt + 1}/{max_retries + 1}): {e}")
|
|
727
|
+
backoff = min(initial_backoff * (2**attempt), max_backoff)
|
|
728
|
+
await asyncio.sleep(backoff)
|
|
729
|
+
continue
|
|
730
|
+
else:
|
|
731
|
+
logger.error(f"Ollama connection error after {max_retries + 1} attempts: {e}")
|
|
732
|
+
raise
|
|
733
|
+
|
|
734
|
+
except Exception as e:
|
|
735
|
+
logger.error(f"Unexpected error during Ollama call: {type(e).__name__}: {e}")
|
|
736
|
+
raise
|
|
737
|
+
|
|
738
|
+
if last_exception:
|
|
739
|
+
raise last_exception
|
|
740
|
+
raise RuntimeError("Ollama call failed after all retries")
|
|
741
|
+
|
|
742
|
+
async def cleanup(self) -> None:
|
|
743
|
+
"""Clean up resources (close OpenAI client connections)."""
|
|
744
|
+
if hasattr(self, "_client") and self._client:
|
|
745
|
+
await self._client.close()
|