hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +311 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
- hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
- hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
- hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
- hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
- hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
- hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
- hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
- hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
- hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
- hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
- hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
- hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
- hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
- hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
- hindsight_api/api/http.py +1406 -118
- hindsight_api/api/mcp.py +11 -196
- hindsight_api/config.py +359 -27
- hindsight_api/engine/consolidation/__init__.py +5 -0
- hindsight_api/engine/consolidation/consolidator.py +859 -0
- hindsight_api/engine/consolidation/prompts.py +69 -0
- hindsight_api/engine/cross_encoder.py +706 -88
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/directives/__init__.py +5 -0
- hindsight_api/engine/directives/models.py +37 -0
- hindsight_api/engine/embeddings.py +553 -29
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +40 -17
- hindsight_api/engine/llm_wrapper.py +744 -68
- hindsight_api/engine/memory_engine.py +2505 -1017
- hindsight_api/engine/mental_models/__init__.py +14 -0
- hindsight_api/engine/mental_models/models.py +53 -0
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/reflect/__init__.py +18 -0
- hindsight_api/engine/reflect/agent.py +933 -0
- hindsight_api/engine/reflect/models.py +109 -0
- hindsight_api/engine/reflect/observations.py +186 -0
- hindsight_api/engine/reflect/prompts.py +483 -0
- hindsight_api/engine/reflect/tools.py +437 -0
- hindsight_api/engine/reflect/tools_schema.py +250 -0
- hindsight_api/engine/response_models.py +168 -4
- hindsight_api/engine/retain/bank_utils.py +79 -201
- hindsight_api/engine/retain/fact_extraction.py +424 -195
- hindsight_api/engine/retain/fact_storage.py +35 -12
- hindsight_api/engine/retain/link_utils.py +29 -24
- hindsight_api/engine/retain/orchestrator.py +24 -43
- hindsight_api/engine/retain/types.py +11 -2
- hindsight_api/engine/search/graph_retrieval.py +43 -14
- hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +848 -201
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +42 -141
- hindsight_api/engine/search/trace.py +12 -1
- hindsight_api/engine/search/tracer.py +26 -6
- hindsight_api/engine/search/types.py +21 -3
- hindsight_api/engine/task_backend.py +113 -106
- hindsight_api/engine/utils.py +1 -152
- hindsight_api/extensions/__init__.py +10 -1
- hindsight_api/extensions/builtin/tenant.py +5 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/extensions/operation_validator.py +81 -4
- hindsight_api/extensions/tenant.py +26 -0
- hindsight_api/main.py +69 -6
- hindsight_api/mcp_local.py +12 -53
- hindsight_api/mcp_tools.py +494 -0
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -3
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- hindsight_api/worker/__init__.py +11 -0
- hindsight_api/worker/main.py +296 -0
- hindsight_api/worker/poller.py +486 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
- hindsight_api-0.4.0.dist-info/RECORD +112 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
- hindsight_api/engine/retain/observation_regeneration.py +0 -254
- hindsight_api/engine/search/observation_utils.py +0 -125
- hindsight_api/engine/search/scoring.py +0 -159
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/llm_wrapper.py
@@ -19,9 +19,12 @@ from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
 from ..config import (
     DEFAULT_LLM_MAX_CONCURRENT,
     DEFAULT_LLM_TIMEOUT,
+    ENV_LLM_GROQ_SERVICE_TIER,
     ENV_LLM_MAX_CONCURRENT,
     ENV_LLM_TIMEOUT,
 )
+from ..metrics import get_metrics_collector
+from .response_models import TokenUsage

 # Seed applied to every Groq request for deterministic behavior.
 DEFAULT_LLM_SEED = 4242
@@ -63,6 +66,7 @@ class LLMProvider:
         base_url: str,
         model: str,
         reasoning_effort: str = "low",
+        groq_service_tier: str | None = None,
     ):
         """
         Initialize LLM provider.
@@ -73,18 +77,25 @@
             base_url: Base URL for the API.
             model: Model name.
             reasoning_effort: Reasoning effort level for supported providers.
+            groq_service_tier: Groq service tier ("on_demand", "flex", "auto"). Default: None (uses Groq's default).
         """
         self.provider = provider.lower()
         self.api_key = api_key
         self.base_url = base_url
         self.model = model
         self.reasoning_effort = reasoning_effort
+        # Default to 'auto' for best performance, users can override to 'on_demand' for free tier
+        self.groq_service_tier = groq_service_tier or os.getenv(ENV_LLM_GROQ_SERVICE_TIER, "auto")

         # Validate provider
-        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio"]
+        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio", "mock"]
         if self.provider not in valid_providers:
             raise ValueError(f"Invalid LLM provider: {self.provider}. Must be one of: {', '.join(valid_providers)}")

+        # Mock provider tracking (for testing)
+        self._mock_calls: list[dict] = []
+        self._mock_response: Any = None
+
         # Set default base URLs
         if not self.base_url:
             if self.provider == "groq":
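For reference, a minimal sketch of how the new `groq_service_tier` knob could be set at construction time. The keyword arguments mirror the `__init__` signature in this hunk; the model name and API key are placeholders, and the env-var fallback behaviour is as described by the added assignment above.

```python
# Hypothetical construction; model name and API key are placeholders.
provider = LLMProvider(
    provider="groq",
    api_key="gsk_...",               # required for groq
    base_url="",                     # empty -> provider default is filled in later
    model="example-model",           # placeholder model name
    reasoning_effort="low",
    groq_service_tier="on_demand",   # or "flex" / "auto"; None falls back to env, then "auto"
)
```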
@@ -94,8 +105,8 @@
             elif self.provider == "lmstudio":
                 self.base_url = "http://localhost:1234/v1"

-        # Validate API key (not needed for ollama or lmstudio)
-        if self.provider not in ("ollama", "lmstudio") and not self.api_key:
+        # Validate API key (not needed for ollama, lmstudio, or mock)
+        if self.provider not in ("ollama", "lmstudio", "mock") and not self.api_key:
             raise ValueError(f"API key not found for {self.provider}")

         # Get timeout config (set HINDSIGHT_API_LLM_TIMEOUT for local LLMs that need longer timeouts)
@@ -106,7 +117,10 @@
         self._gemini_client = None
         self._anthropic_client = None

-        if self.provider == "gemini":
+        if self.provider == "mock":
+            # Mock provider - no client needed
+            pass
+        elif self.provider == "gemini":
             self._gemini_client = genai.Client(api_key=self.api_key)
         elif self.provider == "anthropic":
             from anthropic import AsyncAnthropic
@@ -169,6 +183,7 @@ class LLMProvider:
         max_backoff: float = 60.0,
         skip_validation: bool = False,
         strict_schema: bool = False,
+        return_usage: bool = False,
     ) -> Any:
         """
         Make an LLM API call with retry logic.
@@ -184,21 +199,43 @@
             max_backoff: Maximum backoff time in seconds.
             skip_validation: Return raw JSON without Pydantic validation.
             strict_schema: Use strict JSON schema enforcement (OpenAI only). Guarantees all required fields.
+            return_usage: If True, return tuple (result, TokenUsage) instead of just result.

         Returns:
-            Parsed response if response_format is provided, otherwise text content.
+            If return_usage=False: Parsed response if response_format is provided, otherwise text content.
+            If return_usage=True: Tuple of (result, TokenUsage) with token counts from the LLM call.

         Raises:
             OutputTooLongError: If output exceeds token limits.
             Exception: Re-raises API errors after retries exhausted.
         """
+        semaphore_start = time.time()
         async with _global_llm_semaphore:
+            semaphore_wait_time = time.time() - semaphore_start
             start_time = time.time()

+            # Handle Mock provider (for testing)
+            if self.provider == "mock":
+                return await self._call_mock(
+                    messages,
+                    response_format,
+                    scope,
+                    return_usage,
+                )
+
             # Handle Gemini provider separately
             if self.provider == "gemini":
                 return await self._call_gemini(
-                    messages, response_format, max_retries, initial_backoff, max_backoff, skip_validation, start_time
+                    messages,
+                    response_format,
+                    max_retries,
+                    initial_backoff,
+                    max_backoff,
+                    skip_validation,
+                    start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )

             # Handle Anthropic provider separately
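A minimal sketch of a call site that opts into the new token accounting. The method name `call` is an assumption (the enclosing `def` is outside this hunk), and the Pydantic schema is a placeholder; only the `return_usage` behaviour and the `TokenUsage` fields are taken from the diff.

```python
from pydantic import BaseModel

class Summary(BaseModel):  # placeholder response schema
    text: str

# Hypothetical call site: `call` is assumed to be the method this hunk modifies.
result, usage = await provider.call(
    messages=[{"role": "user", "content": "Summarize the ticket."}],
    response_format=Summary,
    return_usage=True,
)
print(usage.input_tokens, usage.output_tokens, usage.total_tokens)  # TokenUsage fields
```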
@@ -212,6 +249,9 @@
                     max_backoff,
                     skip_validation,
                     start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )

             # Handle Ollama with native API for structured output (better schema enforcement)
@@ -226,6 +266,9 @@
                     max_backoff,
                     skip_validation,
                     start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )

             call_params = {
@@ -263,51 +306,56 @@
             # Provider-specific parameters
             if self.provider == "groq":
                 call_params["seed"] = DEFAULT_LLM_SEED
-                extra_body = {
-                    #
+                extra_body: dict[str, Any] = {}
+                # Add service_tier if configured (requires paid plan for flex/auto)
+                if self.groq_service_tier:
+                    extra_body["service_tier"] = self.groq_service_tier
+                # Add reasoning parameters for reasoning models
                 if is_reasoning_model:
                     extra_body["include_reasoning"] = False
-
+                if extra_body:
+                    call_params["extra_body"] = extra_body

             last_exception = None

+            # Prepare response format ONCE before the retry loop
+            # (to avoid appending schema to messages on every retry)
+            if response_format is not None:
+                schema = None
+                if hasattr(response_format, "model_json_schema"):
+                    schema = response_format.model_json_schema()
+
+                if strict_schema and schema is not None:
+                    # Use OpenAI's strict JSON schema enforcement
+                    # This guarantees all required fields are returned
+                    call_params["response_format"] = {
+                        "type": "json_schema",
+                        "json_schema": {
+                            "name": "response",
+                            "strict": True,
+                            "schema": schema,
+                        },
+                    }
+                else:
+                    # Soft enforcement: add schema to prompt and use json_object mode
+                    if schema is not None:
+                        schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
+
+                        if call_params["messages"] and call_params["messages"][0].get("role") == "system":
+                            call_params["messages"][0]["content"] += schema_msg
+                        elif call_params["messages"]:
+                            call_params["messages"][0]["content"] = (
+                                schema_msg + "\n\n" + call_params["messages"][0]["content"]
+                            )
+                    if self.provider not in ("lmstudio", "ollama"):
+                        # LM Studio and Ollama don't support json_object response format reliably
+                        # We rely on the schema in the system message instead
+                        call_params["response_format"] = {"type": "json_object"}
+
             for attempt in range(max_retries + 1):
                 try:
                     if response_format is not None:
-                        schema = None
-                        if hasattr(response_format, "model_json_schema"):
-                            schema = response_format.model_json_schema()
-
-                        if strict_schema and schema is not None:
-                            # Use OpenAI's strict JSON schema enforcement
-                            # This guarantees all required fields are returned
-                            call_params["response_format"] = {
-                                "type": "json_schema",
-                                "json_schema": {
-                                    "name": "response",
-                                    "strict": True,
-                                    "schema": schema,
-                                },
-                            }
-                        else:
-                            # Soft enforcement: add schema to prompt and use json_object mode
-                            if schema is not None:
-                                schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
-
-                                if call_params["messages"] and call_params["messages"][0].get("role") == "system":
-                                    call_params["messages"][0]["content"] += schema_msg
-                                elif call_params["messages"]:
-                                    call_params["messages"][0]["content"] = (
-                                        schema_msg + "\n\n" + call_params["messages"][0]["content"]
-                                    )
-                            if self.provider not in ("lmstudio", "ollama"):
-                                # LM Studio and Ollama don't support json_object response format reliably
-                                # We rely on the schema in the system message instead
-                                call_params["response_format"] = {"type": "json_object"}
-
-                        logger.debug(f"Sending request to {self.provider}/{self.model} (timeout={self.timeout})")
                         response = await self._client.chat.completions.create(**call_params)
-                        logger.debug(f"Received response from {self.provider}/{self.model}")

                         content = response.choices[0].message.content
@@ -370,21 +418,46 @@
                     response = await self._client.chat.completions.create(**call_params)
                     result = response.choices[0].message.content

-                    #
+                    # Record token usage metrics
                     duration = time.time() - start_time
                     usage = response.usage
-                    if
-
+                    input_tokens = usage.prompt_tokens or 0 if usage else 0
+                    output_tokens = usage.completion_tokens or 0 if usage else 0
+                    total_tokens = usage.total_tokens or 0 if usage else 0
+
+                    # Record LLM metrics
+                    metrics = get_metrics_collector()
+                    metrics.record_llm_call(
+                        provider=self.provider,
+                        model=self.model,
+                        scope=scope,
+                        duration=duration,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        success=True,
+                    )
+
+                    # Log slow calls
+                    if duration > 10.0 and usage:
+                        ratio = max(1, output_tokens) / max(1, input_tokens)
                         cached_tokens = 0
                         if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
                             cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
                         cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
+                        wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                         logger.info(
-                            f"slow llm call: model={self.provider}/{self.model}, "
-                            f"input_tokens={
-                            f"total_tokens={
+                            f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
+                            f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
+                            f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s{wait_info}, ratio out/in={ratio:.2f}"
                         )

+                    if return_usage:
+                        token_usage = TokenUsage(
+                            input_tokens=input_tokens,
+                            output_tokens=output_tokens,
+                            total_tokens=total_tokens,
+                        )
+                        return result, token_usage
                     return result

                 except LengthFinishReasonError as e:
@@ -395,13 +468,11 @@

                 except APIConnectionError as e:
                     last_exception = e
+                    status_code = getattr(e, "status_code", None) or getattr(
+                        getattr(e, "response", None), "status_code", None
+                    )
+                    logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
                     if attempt < max_retries:
-                        status_code = getattr(e, "status_code", None) or getattr(
-                            getattr(e, "response", None), "status_code", None
-                        )
-                        logger.warning(
-                            f"Connection error, retrying... (attempt {attempt + 1}/{max_retries + 1}) - status_code={status_code}, message={e}"
-                        )
                         backoff = min(initial_backoff * (2**attempt), max_backoff)
                         await asyncio.sleep(backoff)
                         continue
@@ -415,6 +486,45 @@
                         logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                         raise

+                    # Handle tool_use_failed error - model outputted in tool call format
+                    # Convert to expected JSON format and continue
+                    if e.status_code == 400 and response_format is not None:
+                        try:
+                            error_body = e.body if hasattr(e, "body") else {}
+                            if isinstance(error_body, dict):
+                                error_info: dict[str, Any] = error_body.get("error") or {}
+                                if error_info.get("code") == "tool_use_failed":
+                                    failed_gen = error_info.get("failed_generation", "")
+                                    if failed_gen:
+                                        # Parse the tool call format and convert to actions format
+                                        tool_call = json.loads(failed_gen)
+                                        tool_name = tool_call.get("name", "")
+                                        tool_args = tool_call.get("arguments", {})
+                                        # Convert to actions format: {"actions": [{"tool": "name", ...args}]}
+                                        converted = {"actions": [{"tool": tool_name, **tool_args}]}
+                                        if skip_validation:
+                                            result = converted
+                                        else:
+                                            result = response_format.model_validate(converted)
+
+                                        # Record metrics for this successful recovery
+                                        duration = time.time() - start_time
+                                        metrics = get_metrics_collector()
+                                        metrics.record_llm_call(
+                                            provider=self.provider,
+                                            model=self.model,
+                                            scope=scope,
+                                            duration=duration,
+                                            input_tokens=0,
+                                            output_tokens=0,
+                                            success=True,
+                                        )
+                                        if return_usage:
+                                            return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
+                                        return result
+                        except (json.JSONDecodeError, KeyError, TypeError):
+                            pass  # Failed to parse tool_use_failed, continue with normal retry
+
                     last_exception = e
                     if attempt < max_retries:
                         backoff = min(initial_backoff * (2**attempt), max_backoff)
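The recovery path above converts Groq's `tool_use_failed` error payload into the `actions` shape the callers expect. A standalone sketch of that transformation, with a made-up `failed_generation` value:

```python
import json

# Illustrative failed_generation string of the form described by the error body.
failed_gen = '{"name": "search_memory", "arguments": {"query": "deploy failures", "limit": 5}}'

tool_call = json.loads(failed_gen)
converted = {"actions": [{"tool": tool_call.get("name", ""), **tool_call.get("arguments", {})}]}
# -> {"actions": [{"tool": "search_memory", "query": "deploy failures", "limit": 5}]}
```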
@@ -425,14 +535,438 @@
                         logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
                         raise

-                except Exception as e:
-                    logger.error(f"Unexpected error during LLM call: {type(e).__name__}: {str(e)}")
+                except Exception:
                     raise

            if last_exception:
                raise last_exception
            raise RuntimeError("LLM call failed after all retries with no exception captured")

+    async def call_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None = None,
+        temperature: float | None = None,
+        scope: str = "tools",
+        max_retries: int = 5,
+        initial_backoff: float = 1.0,
+        max_backoff: float = 30.0,
+        tool_choice: str | dict[str, Any] = "auto",
+    ) -> "LLMToolCallResult":
+        """
+        Make an LLM API call with tool/function calling support.
+
+        Args:
+            messages: List of message dicts. Can include tool results with role='tool'.
+            tools: List of tool definitions in OpenAI format.
+            max_completion_tokens: Maximum tokens in response.
+            temperature: Sampling temperature (0.0-2.0).
+            scope: Scope identifier for tracking.
+            max_retries: Maximum retry attempts.
+            initial_backoff: Initial backoff time in seconds.
+            max_backoff: Maximum backoff time in seconds.
+            tool_choice: How to choose tools - "auto", "none", "required", or {"type": "function", "function": {"name": "..."}}
+
+        Returns:
+            LLMToolCallResult with content and/or tool_calls.
+        """
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        async with _global_llm_semaphore:
+            start_time = time.time()
+
+            # Handle Mock provider
+            if self.provider == "mock":
+                return await self._call_with_tools_mock(messages, tools, scope)
+
+            # Handle Anthropic separately (uses different tool format)
+            if self.provider == "anthropic":
+                return await self._call_with_tools_anthropic(
+                    messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # Handle Gemini (convert to Gemini tool format)
+            if self.provider == "gemini":
+                return await self._call_with_tools_gemini(
+                    messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # OpenAI-compatible providers (OpenAI, Groq, Ollama, LMStudio)
+            call_params: dict[str, Any] = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            }
+
+            if max_completion_tokens is not None:
+                call_params["max_completion_tokens"] = max_completion_tokens
+            if temperature is not None:
+                call_params["temperature"] = temperature
+
+            # Provider-specific parameters
+            if self.provider == "groq":
+                call_params["seed"] = DEFAULT_LLM_SEED
+
+            last_exception = None
+
+            for attempt in range(max_retries + 1):
+                try:
+                    response = await self._client.chat.completions.create(**call_params)
+
+                    message = response.choices[0].message
+                    finish_reason = response.choices[0].finish_reason
+
+                    # Extract tool calls if present
+                    tool_calls: list[LLMToolCall] = []
+                    if message.tool_calls:
+                        for tc in message.tool_calls:
+                            try:
+                                args = json.loads(tc.function.arguments) if tc.function.arguments else {}
+                            except json.JSONDecodeError:
+                                args = {"_raw": tc.function.arguments}
+                            tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
+
+                    content = message.content
+
+                    # Record metrics
+                    duration = time.time() - start_time
+                    usage = response.usage
+                    input_tokens = usage.prompt_tokens or 0 if usage else 0
+                    output_tokens = usage.completion_tokens or 0 if usage else 0
+
+                    metrics = get_metrics_collector()
+                    metrics.record_llm_call(
+                        provider=self.provider,
+                        model=self.model,
+                        scope=scope,
+                        duration=duration,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        success=True,
+                    )
+
+                    return LLMToolCallResult(
+                        content=content,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                    )
+
+                except APIConnectionError as e:
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except APIStatusError as e:
+                    if e.status_code in (401, 403):
+                        raise
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except Exception:
+                    raise
+
+            if last_exception:
+                raise last_exception
+            raise RuntimeError("Tool call failed after all retries")
+
+    async def _call_with_tools_mock(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle mock tool calls for testing."""
+        from .response_models import LLMToolCallResult
+
+        call_record = {
+            "provider": self.provider,
+            "model": self.model,
+            "messages": messages,
+            "tools": [t.get("function", {}).get("name") for t in tools],
+            "scope": scope,
+        }
+        self._mock_calls.append(call_record)
+
+        if self._mock_response is not None:
+            if isinstance(self._mock_response, LLMToolCallResult):
+                return self._mock_response
+            # Allow setting just tool calls as a list
+            if isinstance(self._mock_response, list):
+                from .response_models import LLMToolCall
+
+                return LLMToolCallResult(
+                    tool_calls=[
+                        LLMToolCall(id=f"mock_{i}", name=tc["name"], arguments=tc.get("arguments", {}))
+                        for i, tc in enumerate(self._mock_response)
+                    ],
+                    finish_reason="tool_calls",
+                )
+
+        return LLMToolCallResult(content="mock response", finish_reason="stop")
+
+    async def _call_with_tools_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None,
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Anthropic tool calling."""
+        from anthropic import APIConnectionError, APIStatusError
+
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert OpenAI tool format to Anthropic format
+        anthropic_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            anthropic_tools.append(
+                {
+                    "name": func.get("name", ""),
+                    "description": func.get("description", ""),
+                    "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
+                }
+            )
+
+        # Convert messages - handle tool results
+        system_prompt = None
+        anthropic_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_prompt = (system_prompt + "\n\n" + content) if system_prompt else content
+            elif role == "tool":
+                # Anthropic uses tool_result blocks
+                anthropic_messages.append(
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "tool_result", "tool_use_id": msg.get("tool_call_id", ""), "content": content}
+                        ],
+                    }
+                )
+            elif role == "assistant" and msg.get("tool_calls"):
+                # Convert assistant tool calls
+                tool_use_blocks = []
+                for tc in msg["tool_calls"]:
+                    tool_use_blocks.append(
+                        {
+                            "type": "tool_use",
+                            "id": tc.get("id", ""),
+                            "name": tc.get("function", {}).get("name", ""),
+                            "input": json.loads(tc.get("function", {}).get("arguments", "{}")),
+                        }
+                    )
+                anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
+            else:
+                anthropic_messages.append({"role": role, "content": content})
+
+        call_params: dict[str, Any] = {
+            "model": self.model,
+            "messages": anthropic_messages,
+            "tools": anthropic_tools,
+            "max_tokens": max_completion_tokens or 4096,
+        }
+        if system_prompt:
+            call_params["system"] = system_prompt
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._anthropic_client.messages.create(**call_params)
+
+                # Extract content and tool calls
+                content_parts = []
+                tool_calls: list[LLMToolCall] = []
+
+                for block in response.content:
+                    if block.type == "text":
+                        content_parts.append(block.text)
+                    elif block.type == "tool_use":
+                        tool_calls.append(LLMToolCall(id=block.id, name=block.name, arguments=block.input or {}))
+
+                content = "".join(content_parts) if content_parts else None
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Extract token usage
+                input_tokens = response.usage.input_tokens or 0
+                output_tokens = response.usage.output_tokens or 0
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except (APIConnectionError, APIStatusError) as e:
+                if isinstance(e, APIStatusError) and e.status_code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Anthropic tool call failed")
+
+    async def _call_with_tools_gemini(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Gemini tool calling."""
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert tools to Gemini format
+        gemini_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            gemini_tools.append(
+                genai_types.Tool(
+                    function_declarations=[
+                        genai_types.FunctionDeclaration(
+                            name=func.get("name", ""),
+                            description=func.get("description", ""),
+                            parameters=func.get("parameters"),
+                        )
+                    ]
+                )
+            )
+
+        # Convert messages
+        system_instruction = None
+        gemini_contents = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
+            elif role == "tool":
+                # Gemini uses function_response
+                gemini_contents.append(
+                    genai_types.Content(
+                        role="user",
+                        parts=[
+                            genai_types.Part(
+                                function_response=genai_types.FunctionResponse(
+                                    name=msg.get("name", ""),
+                                    response={"result": content},
+                                )
+                            )
+                        ],
+                    )
+                )
+            elif role == "assistant":
+                gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
+            else:
+                gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
+
+        config = genai_types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            tools=gemini_tools,
+        )
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._gemini_client.aio.models.generate_content(
+                    model=self.model,
+                    contents=gemini_contents,
+                    config=config,
+                )
+
+                # Extract content and tool calls
+                content = None
+                tool_calls: list[LLMToolCall] = []
+
+                if response.candidates and response.candidates[0].content:
+                    for part in response.candidates[0].content.parts:
+                        if hasattr(part, "text") and part.text:
+                            content = part.text
+                        if hasattr(part, "function_call") and part.function_call:
+                            fc = part.function_call
+                            tool_calls.append(
+                                LLMToolCall(
+                                    id=f"gemini_{len(tool_calls)}",
+                                    name=fc.name,
+                                    arguments=dict(fc.args) if fc.args else {},
+                                )
+                            )
+
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                input_tokens = response.usage_metadata.prompt_token_count if response.usage_metadata else 0
+                output_tokens = response.usage_metadata.candidates_token_count if response.usage_metadata else 0
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except genai_errors.APIError as e:
+                if e.code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Gemini tool call failed")
+
     async def _call_anthropic(
         self,
         messages: list[dict[str, str]],
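A minimal sketch of driving the new `call_with_tools` API from the OpenAI-compatible path. The tool definition, provider instance, and scope label are illustrative, but the argument names and the `LLMToolCallResult` / `LLMToolCall` fields match the code above.

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "lookup_fact",  # illustrative tool
            "description": "Look up a stored fact by subject.",
            "parameters": {
                "type": "object",
                "properties": {"subject": {"type": "string"}},
                "required": ["subject"],
            },
        },
    }
]

result = await provider.call_with_tools(
    messages=[{"role": "user", "content": "What do we know about the billing outage?"}],
    tools=tools,
    tool_choice="auto",
    scope="reflect",  # scope label is illustrative
)

if result.finish_reason == "tool_calls":
    for tc in result.tool_calls:  # each tc is an LLMToolCall
        print(tc.id, tc.name, tc.arguments)
else:
    print(result.content)
```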
@@ -443,6 +977,9 @@ class LLMProvider:
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """Handle Anthropic-specific API calls."""
         from anthropic import APIConnectionError, APIStatusError, RateLimitError
@@ -515,17 +1052,40 @@
                 else:
                     result = content

-                #
+                # Record metrics and log slow calls
                 duration = time.time() - start_time
+                input_tokens = response.usage.input_tokens or 0 if response.usage else 0
+                output_tokens = response.usage.output_tokens or 0 if response.usage else 0
+                total_tokens = input_tokens + output_tokens
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                # Log slow calls
                 if duration > 10.0:
-                    input_tokens = response.usage.input_tokens
-                    output_tokens = response.usage.output_tokens
+                    wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                     logger.info(
-                        f"slow llm call: model={self.provider}/{self.model}, "
+                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
                         f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
-                        f"time={duration:.3f}s"
+                        f"time={duration:.3f}s{wait_info}"
                     )

+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    return result, token_usage
                 return result

             except json.JSONDecodeError as e:
@@ -580,6 +1140,9 @@
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """
         Call Ollama using native API with JSON schema enforcement.
@@ -654,11 +1217,39 @@
                     else:
                         raise

+                # Extract token usage from Ollama response
+                # Ollama returns prompt_eval_count (input) and eval_count (output)
+                duration = time.time() - start_time
+                input_tokens = result.get("prompt_eval_count", 0) or 0
+                output_tokens = result.get("eval_count", 0) or 0
+                total_tokens = input_tokens + output_tokens
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
                 # Validate against Pydantic model or return raw JSON
                 if skip_validation:
-                    return json_data
+                    validated_result = json_data
                 else:
-                    return response_format.model_validate(json_data)
+                    validated_result = response_format.model_validate(json_data)
+
+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    return validated_result, token_usage
+                return validated_result

             except httpx.HTTPStatusError as e:
                 last_exception = e
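As the comment above notes, Ollama's native API reports `prompt_eval_count` and `eval_count` rather than OpenAI-style usage fields. A tiny sketch of mapping those counts onto `TokenUsage`, with an illustrative response dict:

```python
# Illustrative Ollama /api/chat response fragment.
result = {"message": {"content": "{}"}, "prompt_eval_count": 412, "eval_count": 96}

input_tokens = result.get("prompt_eval_count", 0) or 0
output_tokens = result.get("eval_count", 0) or 0
usage = TokenUsage(
    input_tokens=input_tokens,
    output_tokens=output_tokens,
    total_tokens=input_tokens + output_tokens,
)
```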
@@ -701,6 +1292,9 @@
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """Handle Gemini-specific API calls."""
         # Convert OpenAI-style messages to Gemini format
@@ -777,16 +1371,43 @@
                 else:
                     result = content

-                #
+                # Record metrics and log slow calls
                 duration = time.time() - start_time
-
+                input_tokens = 0
+                output_tokens = 0
+                if hasattr(response, "usage_metadata") and response.usage_metadata:
                     usage = response.usage_metadata
+                    input_tokens = usage.prompt_token_count or 0
+                    output_tokens = usage.candidates_token_count or 0
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                # Log slow calls
+                if duration > 10.0 and input_tokens > 0:
+                    wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                     logger.info(
-                        f"slow llm call: model={self.provider}/{self.model}, "
-                        f"input_tokens={
-                        f"time={duration:.3f}s"
+                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
+                        f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
+                        f"time={duration:.3f}s{wait_info}"
                     )

+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=input_tokens + output_tokens,
+                    )
+                    return result, token_usage
                 return result

             except json.JSONDecodeError as e:
@@ -828,6 +1449,61 @@
            raise last_exception
        raise RuntimeError("Gemini call failed after all retries")

+    async def _call_mock(
+        self,
+        messages: list[dict[str, str]],
+        response_format: Any | None,
+        scope: str,
+        return_usage: bool,
+    ) -> Any:
+        """
+        Handle mock provider calls for testing.
+
+        Records the call and returns a configurable mock response.
+        """
+        # Record the call for test verification
+        call_record = {
+            "provider": self.provider,
+            "model": self.model,
+            "messages": messages,
+            "response_format": response_format.__name__
+            if response_format and hasattr(response_format, "__name__")
+            else str(response_format),
+            "scope": scope,
+        }
+        self._mock_calls.append(call_record)
+        logger.debug(f"Mock LLM call recorded: scope={scope}, model={self.model}")
+
+        # Return mock response
+        if self._mock_response is not None:
+            result = self._mock_response
+        elif response_format is not None:
+            # Try to create a minimal valid instance of the response format
+            try:
+                # For Pydantic models, try to create with minimal valid data
+                result = {"mock": True}
+            except Exception:
+                result = {"mock": True}
+        else:
+            result = "mock response"
+
+        if return_usage:
+            token_usage = TokenUsage(input_tokens=10, output_tokens=5, total_tokens=15)
+            return result, token_usage
+        return result
+
+    def set_mock_response(self, response: Any) -> None:
+        """Set the response to return from mock calls."""
+        self._mock_response = response
+
+    def get_mock_calls(self) -> list[dict]:
+        """Get the list of recorded mock calls."""
+        return self._mock_calls
+
+    def clear_mock_calls(self) -> None:
+        """Clear the recorded mock calls."""
+        self._mock_calls = []
+
     @classmethod
     def for_memory(cls) -> "LLMProvider":
         """Create provider for memory operations from environment variables."""
|