hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +311 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
  6. hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
  7. hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
  8. hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
  9. hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
  10. hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
  11. hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
  12. hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
  13. hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
  14. hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
  15. hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
  16. hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
  17. hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
  18. hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
  19. hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
  20. hindsight_api/api/http.py +1406 -118
  21. hindsight_api/api/mcp.py +11 -196
  22. hindsight_api/config.py +359 -27
  23. hindsight_api/engine/consolidation/__init__.py +5 -0
  24. hindsight_api/engine/consolidation/consolidator.py +859 -0
  25. hindsight_api/engine/consolidation/prompts.py +69 -0
  26. hindsight_api/engine/cross_encoder.py +706 -88
  27. hindsight_api/engine/db_budget.py +284 -0
  28. hindsight_api/engine/db_utils.py +11 -0
  29. hindsight_api/engine/directives/__init__.py +5 -0
  30. hindsight_api/engine/directives/models.py +37 -0
  31. hindsight_api/engine/embeddings.py +553 -29
  32. hindsight_api/engine/entity_resolver.py +8 -5
  33. hindsight_api/engine/interface.py +40 -17
  34. hindsight_api/engine/llm_wrapper.py +744 -68
  35. hindsight_api/engine/memory_engine.py +2505 -1017
  36. hindsight_api/engine/mental_models/__init__.py +14 -0
  37. hindsight_api/engine/mental_models/models.py +53 -0
  38. hindsight_api/engine/query_analyzer.py +4 -3
  39. hindsight_api/engine/reflect/__init__.py +18 -0
  40. hindsight_api/engine/reflect/agent.py +933 -0
  41. hindsight_api/engine/reflect/models.py +109 -0
  42. hindsight_api/engine/reflect/observations.py +186 -0
  43. hindsight_api/engine/reflect/prompts.py +483 -0
  44. hindsight_api/engine/reflect/tools.py +437 -0
  45. hindsight_api/engine/reflect/tools_schema.py +250 -0
  46. hindsight_api/engine/response_models.py +168 -4
  47. hindsight_api/engine/retain/bank_utils.py +79 -201
  48. hindsight_api/engine/retain/fact_extraction.py +424 -195
  49. hindsight_api/engine/retain/fact_storage.py +35 -12
  50. hindsight_api/engine/retain/link_utils.py +29 -24
  51. hindsight_api/engine/retain/orchestrator.py +24 -43
  52. hindsight_api/engine/retain/types.py +11 -2
  53. hindsight_api/engine/search/graph_retrieval.py +43 -14
  54. hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
  55. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  56. hindsight_api/engine/search/reranking.py +2 -2
  57. hindsight_api/engine/search/retrieval.py +848 -201
  58. hindsight_api/engine/search/tags.py +172 -0
  59. hindsight_api/engine/search/think_utils.py +42 -141
  60. hindsight_api/engine/search/trace.py +12 -1
  61. hindsight_api/engine/search/tracer.py +26 -6
  62. hindsight_api/engine/search/types.py +21 -3
  63. hindsight_api/engine/task_backend.py +113 -106
  64. hindsight_api/engine/utils.py +1 -152
  65. hindsight_api/extensions/__init__.py +10 -1
  66. hindsight_api/extensions/builtin/tenant.py +5 -1
  67. hindsight_api/extensions/context.py +10 -1
  68. hindsight_api/extensions/operation_validator.py +81 -4
  69. hindsight_api/extensions/tenant.py +26 -0
  70. hindsight_api/main.py +69 -6
  71. hindsight_api/mcp_local.py +12 -53
  72. hindsight_api/mcp_tools.py +494 -0
  73. hindsight_api/metrics.py +433 -48
  74. hindsight_api/migrations.py +141 -1
  75. hindsight_api/models.py +3 -3
  76. hindsight_api/pg0.py +53 -0
  77. hindsight_api/server.py +39 -2
  78. hindsight_api/worker/__init__.py +11 -0
  79. hindsight_api/worker/main.py +296 -0
  80. hindsight_api/worker/poller.py +486 -0
  81. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
  82. hindsight_api-0.4.0.dist-info/RECORD +112 -0
  83. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
  84. hindsight_api/engine/retain/observation_regeneration.py +0 -254
  85. hindsight_api/engine/search/observation_utils.py +0 -125
  86. hindsight_api/engine/search/scoring.py +0 -159
  87. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  88. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/llm_wrapper.py
@@ -19,9 +19,12 @@ from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinish
 from ..config import (
     DEFAULT_LLM_MAX_CONCURRENT,
     DEFAULT_LLM_TIMEOUT,
+    ENV_LLM_GROQ_SERVICE_TIER,
     ENV_LLM_MAX_CONCURRENT,
     ENV_LLM_TIMEOUT,
 )
+from ..metrics import get_metrics_collector
+from .response_models import TokenUsage
 
 # Seed applied to every Groq request for deterministic behavior.
 DEFAULT_LLM_SEED = 4242
@@ -63,6 +66,7 @@ class LLMProvider:
         base_url: str,
         model: str,
         reasoning_effort: str = "low",
+        groq_service_tier: str | None = None,
     ):
         """
         Initialize LLM provider.
@@ -73,18 +77,25 @@ class LLMProvider:
             base_url: Base URL for the API.
             model: Model name.
             reasoning_effort: Reasoning effort level for supported providers.
+            groq_service_tier: Groq service tier ("on_demand", "flex", "auto"). Default: None (uses Groq's default).
         """
         self.provider = provider.lower()
         self.api_key = api_key
         self.base_url = base_url
         self.model = model
         self.reasoning_effort = reasoning_effort
+        # Default to 'auto' for best performance, users can override to 'on_demand' for free tier
+        self.groq_service_tier = groq_service_tier or os.getenv(ENV_LLM_GROQ_SERVICE_TIER, "auto")
 
         # Validate provider
-        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio"]
+        valid_providers = ["openai", "groq", "ollama", "gemini", "anthropic", "lmstudio", "mock"]
         if self.provider not in valid_providers:
             raise ValueError(f"Invalid LLM provider: {self.provider}. Must be one of: {', '.join(valid_providers)}")
 
+        # Mock provider tracking (for testing)
+        self._mock_calls: list[dict] = []
+        self._mock_response: Any = None
+
         # Set default base URLs
         if not self.base_url:
             if self.provider == "groq":
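The service tier can also be pinned per instance when constructing the provider. A minimal sketch based on the signature above; the API key and model name are placeholders, and the keyword names of the first two arguments are inferred from the assignments shown:

    provider = LLMProvider(
        provider="groq",
        api_key="YOUR_GROQ_API_KEY",      # placeholder
        base_url="",                      # empty string falls through to the default Groq base URL
        model="llama-3.3-70b-versatile",  # placeholder model name
        groq_service_tier="on_demand",    # "on_demand", "flex", or "auto"; when omitted, the
                                          # ENV_LLM_GROQ_SERVICE_TIER setting applies, then "auto"
    )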
@@ -94,8 +105,8 @@ class LLMProvider:
             elif self.provider == "lmstudio":
                 self.base_url = "http://localhost:1234/v1"
 
-        # Validate API key (not needed for ollama or lmstudio)
-        if self.provider not in ("ollama", "lmstudio") and not self.api_key:
+        # Validate API key (not needed for ollama, lmstudio, or mock)
+        if self.provider not in ("ollama", "lmstudio", "mock") and not self.api_key:
             raise ValueError(f"API key not found for {self.provider}")
 
         # Get timeout config (set HINDSIGHT_API_LLM_TIMEOUT for local LLMs that need longer timeouts)
@@ -106,7 +117,10 @@ class LLMProvider:
         self._gemini_client = None
         self._anthropic_client = None
 
-        if self.provider == "gemini":
+        if self.provider == "mock":
+            # Mock provider - no client needed
+            pass
+        elif self.provider == "gemini":
             self._gemini_client = genai.Client(api_key=self.api_key)
         elif self.provider == "anthropic":
             from anthropic import AsyncAnthropic
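The new "mock" provider skips both client construction and API-key validation, so tests can drive it directly. A rough sketch using the set_mock_response/get_mock_calls helpers added further down in this diff:

    mock = LLMProvider(provider="mock", api_key="", base_url="", model="mock-model")
    mock.set_mock_response({"facts": []})     # whatever shape the code under test expects
    # ... exercise code that calls this provider ...
    calls = mock.get_mock_calls()             # each record holds provider, model, messages, response_format, scope
    assert calls and calls[0]["model"] == "mock-model"
    mock.clear_mock_calls()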
@@ -169,6 +183,7 @@ class LLMProvider:
         max_backoff: float = 60.0,
         skip_validation: bool = False,
         strict_schema: bool = False,
+        return_usage: bool = False,
     ) -> Any:
         """
         Make an LLM API call with retry logic.
@@ -184,21 +199,43 @@ class LLMProvider:
             max_backoff: Maximum backoff time in seconds.
             skip_validation: Return raw JSON without Pydantic validation.
             strict_schema: Use strict JSON schema enforcement (OpenAI only). Guarantees all required fields.
+            return_usage: If True, return tuple (result, TokenUsage) instead of just result.
 
         Returns:
-            Parsed response if response_format is provided, otherwise text content.
+            If return_usage=False: Parsed response if response_format is provided, otherwise text content.
+            If return_usage=True: Tuple of (result, TokenUsage) with token counts from the LLM call.
 
         Raises:
             OutputTooLongError: If output exceeds token limits.
             Exception: Re-raises API errors after retries exhausted.
         """
+        semaphore_start = time.time()
         async with _global_llm_semaphore:
+            semaphore_wait_time = time.time() - semaphore_start
             start_time = time.time()
 
+            # Handle Mock provider (for testing)
+            if self.provider == "mock":
+                return await self._call_mock(
+                    messages,
+                    response_format,
+                    scope,
+                    return_usage,
+                )
+
             # Handle Gemini provider separately
             if self.provider == "gemini":
                 return await self._call_gemini(
-                    messages, response_format, max_retries, initial_backoff, max_backoff, skip_validation, start_time
+                    messages,
+                    response_format,
+                    max_retries,
+                    initial_backoff,
+                    max_backoff,
+                    skip_validation,
+                    start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )
 
             # Handle Anthropic provider separately
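With return_usage=True the structured call now yields the parsed result plus a TokenUsage record. A sketch of the calling convention implied by the docstring above; the method name `call` and the schema are illustrative, only the keyword arguments and the TokenUsage fields are taken from the diff:

    result, usage = await provider.call(
        messages=[{"role": "user", "content": "Summarize the last session."}],
        response_format=SessionSummary,   # any Pydantic model (illustrative)
        return_usage=True,
    )
    print(usage.input_tokens, usage.output_tokens, usage.total_tokens)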
@@ -212,6 +249,9 @@ class LLMProvider:
                     max_backoff,
                     skip_validation,
                     start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )
 
             # Handle Ollama with native API for structured output (better schema enforcement)
@@ -226,6 +266,9 @@ class LLMProvider:
                     max_backoff,
                     skip_validation,
                     start_time,
+                    scope,
+                    return_usage,
+                    semaphore_wait_time,
                 )
 
             call_params = {
@@ -263,51 +306,56 @@ class LLMProvider:
             # Provider-specific parameters
             if self.provider == "groq":
                 call_params["seed"] = DEFAULT_LLM_SEED
-                extra_body = {"service_tier": "auto"}
-                # Only add reasoning parameters for reasoning models
+                extra_body: dict[str, Any] = {}
+                # Add service_tier if configured (requires paid plan for flex/auto)
+                if self.groq_service_tier:
+                    extra_body["service_tier"] = self.groq_service_tier
+                # Add reasoning parameters for reasoning models
                 if is_reasoning_model:
                     extra_body["include_reasoning"] = False
-                call_params["extra_body"] = extra_body
+                if extra_body:
+                    call_params["extra_body"] = extra_body
 
             last_exception = None
 
+            # Prepare response format ONCE before the retry loop
+            # (to avoid appending schema to messages on every retry)
+            if response_format is not None:
+                schema = None
+                if hasattr(response_format, "model_json_schema"):
+                    schema = response_format.model_json_schema()
+
+                if strict_schema and schema is not None:
+                    # Use OpenAI's strict JSON schema enforcement
+                    # This guarantees all required fields are returned
+                    call_params["response_format"] = {
+                        "type": "json_schema",
+                        "json_schema": {
+                            "name": "response",
+                            "strict": True,
+                            "schema": schema,
+                        },
+                    }
+                else:
+                    # Soft enforcement: add schema to prompt and use json_object mode
+                    if schema is not None:
+                        schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
+
+                        if call_params["messages"] and call_params["messages"][0].get("role") == "system":
+                            call_params["messages"][0]["content"] += schema_msg
+                        elif call_params["messages"]:
+                            call_params["messages"][0]["content"] = (
+                                schema_msg + "\n\n" + call_params["messages"][0]["content"]
+                            )
+                    if self.provider not in ("lmstudio", "ollama"):
+                        # LM Studio and Ollama don't support json_object response format reliably
+                        # We rely on the schema in the system message instead
+                        call_params["response_format"] = {"type": "json_object"}
+
             for attempt in range(max_retries + 1):
                 try:
                     if response_format is not None:
-                        schema = None
-                        if hasattr(response_format, "model_json_schema"):
-                            schema = response_format.model_json_schema()
-
-                        if strict_schema and schema is not None:
-                            # Use OpenAI's strict JSON schema enforcement
-                            # This guarantees all required fields are returned
-                            call_params["response_format"] = {
-                                "type": "json_schema",
-                                "json_schema": {
-                                    "name": "response",
-                                    "strict": True,
-                                    "schema": schema,
-                                },
-                            }
-                        else:
-                            # Soft enforcement: add schema to prompt and use json_object mode
-                            if schema is not None:
-                                schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
-
-                                if call_params["messages"] and call_params["messages"][0].get("role") == "system":
-                                    call_params["messages"][0]["content"] += schema_msg
-                                elif call_params["messages"]:
-                                    call_params["messages"][0]["content"] = (
-                                        schema_msg + "\n\n" + call_params["messages"][0]["content"]
-                                    )
-                            if self.provider not in ("lmstudio", "ollama"):
-                                # LM Studio and Ollama don't support json_object response format reliably
-                                # We rely on the schema in the system message instead
-                                call_params["response_format"] = {"type": "json_object"}
-
-                        logger.debug(f"Sending request to {self.provider}/{self.model} (timeout={self.timeout})")
                         response = await self._client.chat.completions.create(**call_params)
-                        logger.debug(f"Received response from {self.provider}/{self.model}")
 
                         content = response.choices[0].message.content
 
@@ -370,21 +418,46 @@ class LLMProvider:
                         response = await self._client.chat.completions.create(**call_params)
                         result = response.choices[0].message.content
 
-                        # Log slow calls
+                        # Record token usage metrics
                         duration = time.time() - start_time
                         usage = response.usage
-                        if duration > 10.0:
-                            ratio = max(1, usage.completion_tokens) / usage.prompt_tokens
+                        input_tokens = usage.prompt_tokens or 0 if usage else 0
+                        output_tokens = usage.completion_tokens or 0 if usage else 0
+                        total_tokens = usage.total_tokens or 0 if usage else 0
+
+                        # Record LLM metrics
+                        metrics = get_metrics_collector()
+                        metrics.record_llm_call(
+                            provider=self.provider,
+                            model=self.model,
+                            scope=scope,
+                            duration=duration,
+                            input_tokens=input_tokens,
+                            output_tokens=output_tokens,
+                            success=True,
+                        )
+
+                        # Log slow calls
+                        if duration > 10.0 and usage:
+                            ratio = max(1, output_tokens) / max(1, input_tokens)
                             cached_tokens = 0
                             if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
                                 cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0
                             cache_info = f", cached_tokens={cached_tokens}" if cached_tokens > 0 else ""
+                            wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                             logger.info(
-                                f"slow llm call: model={self.provider}/{self.model}, "
-                                f"input_tokens={usage.prompt_tokens}, output_tokens={usage.completion_tokens}, "
-                                f"total_tokens={usage.total_tokens}{cache_info}, time={duration:.3f}s, ratio out/in={ratio:.2f}"
+                                f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
+                                f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
+                                f"total_tokens={total_tokens}{cache_info}, time={duration:.3f}s{wait_info}, ratio out/in={ratio:.2f}"
                             )
 
+                        if return_usage:
+                            token_usage = TokenUsage(
+                                input_tokens=input_tokens,
+                                output_tokens=output_tokens,
+                                total_tokens=total_tokens,
+                            )
+                            return result, token_usage
                         return result
 
                 except LengthFinishReasonError as e:
@@ -395,13 +468,11 @@ class LLMProvider:
 
                 except APIConnectionError as e:
                     last_exception = e
+                    status_code = getattr(e, "status_code", None) or getattr(
+                        getattr(e, "response", None), "status_code", None
+                    )
+                    logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
                     if attempt < max_retries:
-                        status_code = getattr(e, "status_code", None) or getattr(
-                            getattr(e, "response", None), "status_code", None
-                        )
-                        logger.warning(
-                            f"Connection error, retrying... (attempt {attempt + 1}/{max_retries + 1}) - status_code={status_code}, message={e}"
-                        )
                         backoff = min(initial_backoff * (2**attempt), max_backoff)
                         await asyncio.sleep(backoff)
                         continue
@@ -415,6 +486,45 @@ class LLMProvider:
                         logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                         raise
 
+                    # Handle tool_use_failed error - model outputted in tool call format
+                    # Convert to expected JSON format and continue
+                    if e.status_code == 400 and response_format is not None:
+                        try:
+                            error_body = e.body if hasattr(e, "body") else {}
+                            if isinstance(error_body, dict):
+                                error_info: dict[str, Any] = error_body.get("error") or {}
+                                if error_info.get("code") == "tool_use_failed":
+                                    failed_gen = error_info.get("failed_generation", "")
+                                    if failed_gen:
+                                        # Parse the tool call format and convert to actions format
+                                        tool_call = json.loads(failed_gen)
+                                        tool_name = tool_call.get("name", "")
+                                        tool_args = tool_call.get("arguments", {})
+                                        # Convert to actions format: {"actions": [{"tool": "name", ...args}]}
+                                        converted = {"actions": [{"tool": tool_name, **tool_args}]}
+                                        if skip_validation:
+                                            result = converted
+                                        else:
+                                            result = response_format.model_validate(converted)
+
+                                        # Record metrics for this successful recovery
+                                        duration = time.time() - start_time
+                                        metrics = get_metrics_collector()
+                                        metrics.record_llm_call(
+                                            provider=self.provider,
+                                            model=self.model,
+                                            scope=scope,
+                                            duration=duration,
+                                            input_tokens=0,
+                                            output_tokens=0,
+                                            success=True,
+                                        )
+                                        if return_usage:
+                                            return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
+                                        return result
+                        except (json.JSONDecodeError, KeyError, TypeError):
+                            pass  # Failed to parse tool_use_failed, continue with normal retry
+
                     last_exception = e
                     if attempt < max_retries:
                         backoff = min(initial_backoff * (2**attempt), max_backoff)
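Concretely, the recovery path above reshapes a Groq tool_use_failed payload into the actions format the schema path expects, for example (tool name invented for illustration):

    # failed_generation: {"name": "search_memories", "arguments": {"query": "acme", "limit": 5}}
    # converted to:      {"actions": [{"tool": "search_memories", "query": "acme", "limit": 5}]}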
@@ -425,14 +535,438 @@ class LLMProvider:
                     logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
                     raise
 
-                except Exception as e:
-                    logger.error(f"Unexpected error during LLM call: {type(e).__name__}: {str(e)}")
+                except Exception:
                     raise
 
             if last_exception:
                 raise last_exception
             raise RuntimeError("LLM call failed after all retries with no exception captured")
 
+    async def call_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None = None,
+        temperature: float | None = None,
+        scope: str = "tools",
+        max_retries: int = 5,
+        initial_backoff: float = 1.0,
+        max_backoff: float = 30.0,
+        tool_choice: str | dict[str, Any] = "auto",
+    ) -> "LLMToolCallResult":
+        """
+        Make an LLM API call with tool/function calling support.
+
+        Args:
+            messages: List of message dicts. Can include tool results with role='tool'.
+            tools: List of tool definitions in OpenAI format.
+            max_completion_tokens: Maximum tokens in response.
+            temperature: Sampling temperature (0.0-2.0).
+            scope: Scope identifier for tracking.
+            max_retries: Maximum retry attempts.
+            initial_backoff: Initial backoff time in seconds.
+            max_backoff: Maximum backoff time in seconds.
+            tool_choice: How to choose tools - "auto", "none", "required", or {"type": "function", "function": {"name": "..."}}
+
+        Returns:
+            LLMToolCallResult with content and/or tool_calls.
+        """
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        async with _global_llm_semaphore:
+            start_time = time.time()
+
+            # Handle Mock provider
+            if self.provider == "mock":
+                return await self._call_with_tools_mock(messages, tools, scope)
+
+            # Handle Anthropic separately (uses different tool format)
+            if self.provider == "anthropic":
+                return await self._call_with_tools_anthropic(
+                    messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # Handle Gemini (convert to Gemini tool format)
+            if self.provider == "gemini":
+                return await self._call_with_tools_gemini(
+                    messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # OpenAI-compatible providers (OpenAI, Groq, Ollama, LMStudio)
+            call_params: dict[str, Any] = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            }
+
+            if max_completion_tokens is not None:
+                call_params["max_completion_tokens"] = max_completion_tokens
+            if temperature is not None:
+                call_params["temperature"] = temperature
+
+            # Provider-specific parameters
+            if self.provider == "groq":
+                call_params["seed"] = DEFAULT_LLM_SEED
+
+            last_exception = None
+
+            for attempt in range(max_retries + 1):
+                try:
+                    response = await self._client.chat.completions.create(**call_params)
+
+                    message = response.choices[0].message
+                    finish_reason = response.choices[0].finish_reason
+
+                    # Extract tool calls if present
+                    tool_calls: list[LLMToolCall] = []
+                    if message.tool_calls:
+                        for tc in message.tool_calls:
+                            try:
+                                args = json.loads(tc.function.arguments) if tc.function.arguments else {}
+                            except json.JSONDecodeError:
+                                args = {"_raw": tc.function.arguments}
+                            tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
+
+                    content = message.content
+
+                    # Record metrics
+                    duration = time.time() - start_time
+                    usage = response.usage
+                    input_tokens = usage.prompt_tokens or 0 if usage else 0
+                    output_tokens = usage.completion_tokens or 0 if usage else 0
+
+                    metrics = get_metrics_collector()
+                    metrics.record_llm_call(
+                        provider=self.provider,
+                        model=self.model,
+                        scope=scope,
+                        duration=duration,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        success=True,
+                    )
+
+                    return LLMToolCallResult(
+                        content=content,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                    )
+
+                except APIConnectionError as e:
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except APIStatusError as e:
+                    if e.status_code in (401, 403):
+                        raise
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except Exception:
+                    raise
+
+            if last_exception:
+                raise last_exception
+            raise RuntimeError("Tool call failed after all retries")
+
+    async def _call_with_tools_mock(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle mock tool calls for testing."""
+        from .response_models import LLMToolCallResult
+
+        call_record = {
+            "provider": self.provider,
+            "model": self.model,
+            "messages": messages,
+            "tools": [t.get("function", {}).get("name") for t in tools],
+            "scope": scope,
+        }
+        self._mock_calls.append(call_record)
+
+        if self._mock_response is not None:
+            if isinstance(self._mock_response, LLMToolCallResult):
+                return self._mock_response
+            # Allow setting just tool calls as a list
+            if isinstance(self._mock_response, list):
+                from .response_models import LLMToolCall
+
+                return LLMToolCallResult(
+                    tool_calls=[
+                        LLMToolCall(id=f"mock_{i}", name=tc["name"], arguments=tc.get("arguments", {}))
+                        for i, tc in enumerate(self._mock_response)
+                    ],
+                    finish_reason="tool_calls",
+                )
+
+        return LLMToolCallResult(content="mock response", finish_reason="stop")
+
+    async def _call_with_tools_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None,
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Anthropic tool calling."""
+        from anthropic import APIConnectionError, APIStatusError
+
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert OpenAI tool format to Anthropic format
+        anthropic_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            anthropic_tools.append(
+                {
+                    "name": func.get("name", ""),
+                    "description": func.get("description", ""),
+                    "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
+                }
+            )
+
+        # Convert messages - handle tool results
+        system_prompt = None
+        anthropic_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_prompt = (system_prompt + "\n\n" + content) if system_prompt else content
+            elif role == "tool":
+                # Anthropic uses tool_result blocks
+                anthropic_messages.append(
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "tool_result", "tool_use_id": msg.get("tool_call_id", ""), "content": content}
+                        ],
+                    }
+                )
+            elif role == "assistant" and msg.get("tool_calls"):
+                # Convert assistant tool calls
+                tool_use_blocks = []
+                for tc in msg["tool_calls"]:
+                    tool_use_blocks.append(
+                        {
+                            "type": "tool_use",
+                            "id": tc.get("id", ""),
+                            "name": tc.get("function", {}).get("name", ""),
+                            "input": json.loads(tc.get("function", {}).get("arguments", "{}")),
+                        }
+                    )
+                anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
+            else:
+                anthropic_messages.append({"role": role, "content": content})
+
+        call_params: dict[str, Any] = {
+            "model": self.model,
+            "messages": anthropic_messages,
+            "tools": anthropic_tools,
+            "max_tokens": max_completion_tokens or 4096,
+        }
+        if system_prompt:
+            call_params["system"] = system_prompt
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._anthropic_client.messages.create(**call_params)
+
+                # Extract content and tool calls
+                content_parts = []
+                tool_calls: list[LLMToolCall] = []
+
+                for block in response.content:
+                    if block.type == "text":
+                        content_parts.append(block.text)
+                    elif block.type == "tool_use":
+                        tool_calls.append(LLMToolCall(id=block.id, name=block.name, arguments=block.input or {}))
+
+                content = "".join(content_parts) if content_parts else None
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Extract token usage
+                input_tokens = response.usage.input_tokens or 0
+                output_tokens = response.usage.output_tokens or 0
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except (APIConnectionError, APIStatusError) as e:
+                if isinstance(e, APIStatusError) and e.status_code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Anthropic tool call failed")
+
+    async def _call_with_tools_gemini(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Gemini tool calling."""
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert tools to Gemini format
+        gemini_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            gemini_tools.append(
+                genai_types.Tool(
+                    function_declarations=[
+                        genai_types.FunctionDeclaration(
+                            name=func.get("name", ""),
+                            description=func.get("description", ""),
+                            parameters=func.get("parameters"),
+                        )
+                    ]
+                )
+            )
+
+        # Convert messages
+        system_instruction = None
+        gemini_contents = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
+            elif role == "tool":
+                # Gemini uses function_response
+                gemini_contents.append(
+                    genai_types.Content(
+                        role="user",
+                        parts=[
+                            genai_types.Part(
+                                function_response=genai_types.FunctionResponse(
+                                    name=msg.get("name", ""),
+                                    response={"result": content},
+                                )
+                            )
+                        ],
+                    )
+                )
+            elif role == "assistant":
+                gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
+            else:
+                gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
+
+        config = genai_types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            tools=gemini_tools,
+        )
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._gemini_client.aio.models.generate_content(
+                    model=self.model,
+                    contents=gemini_contents,
+                    config=config,
+                )
+
+                # Extract content and tool calls
+                content = None
+                tool_calls: list[LLMToolCall] = []
+
+                if response.candidates and response.candidates[0].content:
+                    for part in response.candidates[0].content.parts:
+                        if hasattr(part, "text") and part.text:
+                            content = part.text
+                        if hasattr(part, "function_call") and part.function_call:
+                            fc = part.function_call
+                            tool_calls.append(
+                                LLMToolCall(
+                                    id=f"gemini_{len(tool_calls)}",
+                                    name=fc.name,
+                                    arguments=dict(fc.args) if fc.args else {},
+                                )
+                            )
+
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                input_tokens = response.usage_metadata.prompt_token_count if response.usage_metadata else 0
+                output_tokens = response.usage_metadata.candidates_token_count if response.usage_metadata else 0
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except genai_errors.APIError as e:
+                if e.code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Gemini tool call failed")
+
     async def _call_anthropic(
         self,
         messages: list[dict[str, str]],
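The new call_with_tools entry point accepts OpenAI-style tool definitions for every provider and normalizes the reply into an LLMToolCallResult. A sketch under that signature (the tool itself is invented for illustration):

    tools = [{
        "type": "function",
        "function": {
            "name": "lookup_entity",  # illustrative tool
            "description": "Look up what is known about an entity.",
            "parameters": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        },
    }]
    result = await provider.call_with_tools(
        messages=[{"role": "user", "content": "What do we know about Acme Corp?"}],
        tools=tools,
        tool_choice="auto",
    )
    if result.finish_reason == "tool_calls":
        for tc in result.tool_calls:
            print(tc.name, tc.arguments)  # LLMToolCall carries id, name, arguments
    else:
        print(result.content)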
@@ -443,6 +977,9 @@ class LLMProvider:
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """Handle Anthropic-specific API calls."""
         from anthropic import APIConnectionError, APIStatusError, RateLimitError
@@ -515,17 +1052,40 @@ class LLMProvider:
                 else:
                     result = content
 
-                # Log slow calls
+                # Record metrics and log slow calls
                 duration = time.time() - start_time
+                input_tokens = response.usage.input_tokens or 0 if response.usage else 0
+                output_tokens = response.usage.output_tokens or 0 if response.usage else 0
+                total_tokens = input_tokens + output_tokens
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                # Log slow calls
                 if duration > 10.0:
-                    input_tokens = response.usage.input_tokens
-                    output_tokens = response.usage.output_tokens
+                    wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                     logger.info(
-                        f"slow llm call: model={self.provider}/{self.model}, "
+                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
                         f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
-                        f"time={duration:.3f}s"
+                        f"time={duration:.3f}s{wait_info}"
                     )
 
+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    return result, token_usage
                 return result
 
             except json.JSONDecodeError as e:
@@ -580,6 +1140,9 @@ class LLMProvider:
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """
         Call Ollama using native API with JSON schema enforcement.
@@ -654,11 +1217,39 @@ class LLMProvider:
                     else:
                         raise
 
+                # Extract token usage from Ollama response
+                # Ollama returns prompt_eval_count (input) and eval_count (output)
+                duration = time.time() - start_time
+                input_tokens = result.get("prompt_eval_count", 0) or 0
+                output_tokens = result.get("eval_count", 0) or 0
+                total_tokens = input_tokens + output_tokens
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
                 # Validate against Pydantic model or return raw JSON
                 if skip_validation:
-                    return json_data
+                    validated_result = json_data
                 else:
-                    return response_format.model_validate(json_data)
+                    validated_result = response_format.model_validate(json_data)
+
+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=total_tokens,
+                    )
+                    return validated_result, token_usage
+                return validated_result
 
             except httpx.HTTPStatusError as e:
                 last_exception = e
@@ -701,6 +1292,9 @@ class LLMProvider:
         max_backoff: float,
         skip_validation: bool,
         start_time: float,
+        scope: str = "memory",
+        return_usage: bool = False,
+        semaphore_wait_time: float = 0.0,
     ) -> Any:
         """Handle Gemini-specific API calls."""
         # Convert OpenAI-style messages to Gemini format
@@ -777,16 +1371,43 @@ class LLMProvider:
                 else:
                     result = content
 
-                # Log slow calls
+                # Record metrics and log slow calls
                 duration = time.time() - start_time
-                if duration > 10.0 and hasattr(response, "usage_metadata") and response.usage_metadata:
+                input_tokens = 0
+                output_tokens = 0
+                if hasattr(response, "usage_metadata") and response.usage_metadata:
                     usage = response.usage_metadata
+                    input_tokens = usage.prompt_token_count or 0
+                    output_tokens = usage.candidates_token_count or 0
+
+                # Record LLM metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=duration,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                # Log slow calls
+                if duration > 10.0 and input_tokens > 0:
+                    wait_info = f", wait={semaphore_wait_time:.3f}s" if semaphore_wait_time > 0.1 else ""
                     logger.info(
-                        f"slow llm call: model={self.provider}/{self.model}, "
-                        f"input_tokens={usage.prompt_token_count}, output_tokens={usage.candidates_token_count}, "
-                        f"time={duration:.3f}s"
+                        f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
+                        f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
+                        f"time={duration:.3f}s{wait_info}"
                     )
 
+                if return_usage:
+                    token_usage = TokenUsage(
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        total_tokens=input_tokens + output_tokens,
+                    )
+                    return result, token_usage
                 return result
 
             except json.JSONDecodeError as e:
@@ -828,6 +1449,61 @@ class LLMProvider:
             raise last_exception
         raise RuntimeError("Gemini call failed after all retries")
 
+    async def _call_mock(
+        self,
+        messages: list[dict[str, str]],
+        response_format: Any | None,
+        scope: str,
+        return_usage: bool,
+    ) -> Any:
+        """
+        Handle mock provider calls for testing.
+
+        Records the call and returns a configurable mock response.
+        """
+        # Record the call for test verification
+        call_record = {
+            "provider": self.provider,
+            "model": self.model,
+            "messages": messages,
+            "response_format": response_format.__name__
+            if response_format and hasattr(response_format, "__name__")
+            else str(response_format),
+            "scope": scope,
+        }
+        self._mock_calls.append(call_record)
+        logger.debug(f"Mock LLM call recorded: scope={scope}, model={self.model}")
+
+        # Return mock response
+        if self._mock_response is not None:
+            result = self._mock_response
+        elif response_format is not None:
+            # Try to create a minimal valid instance of the response format
+            try:
+                # For Pydantic models, try to create with minimal valid data
+                result = {"mock": True}
+            except Exception:
+                result = {"mock": True}
+        else:
+            result = "mock response"
+
+        if return_usage:
+            token_usage = TokenUsage(input_tokens=10, output_tokens=5, total_tokens=15)
+            return result, token_usage
+        return result
+
+    def set_mock_response(self, response: Any) -> None:
+        """Set the response to return from mock calls."""
+        self._mock_response = response
+
+    def get_mock_calls(self) -> list[dict]:
+        """Get the list of recorded mock calls."""
+        return self._mock_calls
+
+    def clear_mock_calls(self) -> None:
+        """Clear the recorded mock calls."""
+        self._mock_calls = []
+
     @classmethod
     def for_memory(cls) -> "LLMProvider":
         """Create provider for memory operations from environment variables."""