hindsight-api 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +1 -1
- hindsight_api/admin/cli.py +59 -0
- hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
- hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
- hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
- hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
- hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
- hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
- hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
- hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
- hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
- hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
- hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
- hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
- hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
- hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
- hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
- hindsight_api/api/http.py +1120 -93
- hindsight_api/api/mcp.py +11 -191
- hindsight_api/config.py +174 -46
- hindsight_api/engine/consolidation/__init__.py +5 -0
- hindsight_api/engine/consolidation/consolidator.py +926 -0
- hindsight_api/engine/consolidation/prompts.py +77 -0
- hindsight_api/engine/cross_encoder.py +153 -22
- hindsight_api/engine/directives/__init__.py +5 -0
- hindsight_api/engine/directives/models.py +37 -0
- hindsight_api/engine/embeddings.py +136 -13
- hindsight_api/engine/interface.py +32 -13
- hindsight_api/engine/llm_wrapper.py +505 -43
- hindsight_api/engine/memory_engine.py +2101 -1094
- hindsight_api/engine/mental_models/__init__.py +14 -0
- hindsight_api/engine/mental_models/models.py +53 -0
- hindsight_api/engine/reflect/__init__.py +18 -0
- hindsight_api/engine/reflect/agent.py +933 -0
- hindsight_api/engine/reflect/models.py +109 -0
- hindsight_api/engine/reflect/observations.py +186 -0
- hindsight_api/engine/reflect/prompts.py +483 -0
- hindsight_api/engine/reflect/tools.py +437 -0
- hindsight_api/engine/reflect/tools_schema.py +250 -0
- hindsight_api/engine/response_models.py +130 -4
- hindsight_api/engine/retain/bank_utils.py +79 -201
- hindsight_api/engine/retain/fact_extraction.py +81 -48
- hindsight_api/engine/retain/fact_storage.py +5 -8
- hindsight_api/engine/retain/link_utils.py +5 -8
- hindsight_api/engine/retain/orchestrator.py +1 -55
- hindsight_api/engine/retain/types.py +2 -2
- hindsight_api/engine/search/graph_retrieval.py +2 -2
- hindsight_api/engine/search/link_expansion_retrieval.py +164 -29
- hindsight_api/engine/search/mpfp_retrieval.py +1 -1
- hindsight_api/engine/search/retrieval.py +14 -14
- hindsight_api/engine/search/think_utils.py +41 -140
- hindsight_api/engine/search/trace.py +0 -1
- hindsight_api/engine/search/tracer.py +2 -5
- hindsight_api/engine/search/types.py +0 -3
- hindsight_api/engine/task_backend.py +112 -196
- hindsight_api/engine/utils.py +0 -151
- hindsight_api/extensions/__init__.py +10 -1
- hindsight_api/extensions/builtin/tenant.py +11 -4
- hindsight_api/extensions/operation_validator.py +81 -4
- hindsight_api/extensions/tenant.py +26 -0
- hindsight_api/main.py +28 -5
- hindsight_api/mcp_local.py +12 -53
- hindsight_api/mcp_tools.py +494 -0
- hindsight_api/models.py +0 -2
- hindsight_api/worker/__init__.py +11 -0
- hindsight_api/worker/main.py +296 -0
- hindsight_api/worker/poller.py +486 -0
- {hindsight_api-0.3.0.dist-info → hindsight_api-0.4.1.dist-info}/METADATA +12 -6
- hindsight_api-0.4.1.dist-info/RECORD +112 -0
- {hindsight_api-0.3.0.dist-info → hindsight_api-0.4.1.dist-info}/entry_points.txt +1 -0
- hindsight_api/engine/retain/observation_regeneration.py +0 -254
- hindsight_api/engine/search/observation_utils.py +0 -125
- hindsight_api/engine/search/scoring.py +0 -159
- hindsight_api-0.3.0.dist-info/RECORD +0 -82
- {hindsight_api-0.3.0.dist-info → hindsight_api-0.4.1.dist-info}/WHEEL +0 -0
```diff
@@ -209,10 +209,10 @@ class LLMProvider:
             OutputTooLongError: If output exceeds token limits.
             Exception: Re-raises API errors after retries exhausted.
         """
-
+        semaphore_start = time.time()
         async with _global_llm_semaphore:
+            semaphore_wait_time = time.time() - semaphore_start
             start_time = time.time()
-            semaphore_wait_time = start_time - queue_start_time

             # Handle Mock provider (for testing)
             if self.provider == "mock":
```
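Note: the hunk above now brackets only the semaphore acquisition when measuring queueing time, instead of deriving it from an earlier `queue_start_time`. A minimal sketch of the same pattern, with illustrative names (only `_global_llm_semaphore` is echoed from the diff; the rest is assumed):

```python
import asyncio
import time

# Illustrative global cap on concurrent LLM calls, mirroring _global_llm_semaphore.
_llm_semaphore = asyncio.Semaphore(8)

async def timed_call(do_call):
    semaphore_start = time.time()                        # before trying to acquire
    async with _llm_semaphore:
        semaphore_wait = time.time() - semaphore_start   # pure queueing time
        start = time.time()
        result = await do_call()
        return result, semaphore_wait, time.time() - start
```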
```diff
@@ -318,43 +318,44 @@ class LLMProvider:

             last_exception = None

+            # Prepare response format ONCE before the retry loop
+            # (to avoid appending schema to messages on every retry)
+            if response_format is not None:
+                schema = None
+                if hasattr(response_format, "model_json_schema"):
+                    schema = response_format.model_json_schema()
+
+                if strict_schema and schema is not None:
+                    # Use OpenAI's strict JSON schema enforcement
+                    # This guarantees all required fields are returned
+                    call_params["response_format"] = {
+                        "type": "json_schema",
+                        "json_schema": {
+                            "name": "response",
+                            "strict": True,
+                            "schema": schema,
+                        },
+                    }
+                else:
+                    # Soft enforcement: add schema to prompt and use json_object mode
+                    if schema is not None:
+                        schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
+
+                        if call_params["messages"] and call_params["messages"][0].get("role") == "system":
+                            call_params["messages"][0]["content"] += schema_msg
+                        elif call_params["messages"]:
+                            call_params["messages"][0]["content"] = (
+                                schema_msg + "\n\n" + call_params["messages"][0]["content"]
+                            )
+                    if self.provider not in ("lmstudio", "ollama"):
+                        # LM Studio and Ollama don't support json_object response format reliably
+                        # We rely on the schema in the system message instead
+                        call_params["response_format"] = {"type": "json_object"}
+
             for attempt in range(max_retries + 1):
                 try:
                     if response_format is not None:
-                        schema = None
-                        if hasattr(response_format, "model_json_schema"):
-                            schema = response_format.model_json_schema()
-
-                        if strict_schema and schema is not None:
-                            # Use OpenAI's strict JSON schema enforcement
-                            # This guarantees all required fields are returned
-                            call_params["response_format"] = {
-                                "type": "json_schema",
-                                "json_schema": {
-                                    "name": "response",
-                                    "strict": True,
-                                    "schema": schema,
-                                },
-                            }
-                        else:
-                            # Soft enforcement: add schema to prompt and use json_object mode
-                            if schema is not None:
-                                schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
-
-                                if call_params["messages"] and call_params["messages"][0].get("role") == "system":
-                                    call_params["messages"][0]["content"] += schema_msg
-                                elif call_params["messages"]:
-                                    call_params["messages"][0]["content"] = (
-                                        schema_msg + "\n\n" + call_params["messages"][0]["content"]
-                                    )
-                            if self.provider not in ("lmstudio", "ollama"):
-                                # LM Studio and Ollama don't support json_object response format reliably
-                                # We rely on the schema in the system message instead
-                                call_params["response_format"] = {"type": "json_object"}
-
-                        logger.debug(f"Sending request to {self.provider}/{self.model} (timeout={self.timeout})")
                         response = await self._client.chat.completions.create(**call_params)
-                        logger.debug(f"Received response from {self.provider}/{self.model}")

                         content = response.choices[0].message.content

```
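Note: the hunk above hoists schema preparation out of the retry loop. Strict mode uses the OpenAI-style `json_schema` response format built from a Pydantic model's JSON schema; soft mode falls back to `json_object` plus the schema restated in the prompt. A standalone sketch of the two payloads, with a made-up Pydantic model (the wrapper's own models are not shown here):

```python
import json
from pydantic import BaseModel

class ExtractedFact(BaseModel):  # hypothetical response model
    subject: str
    claim: str

schema = ExtractedFact.model_json_schema()

# Strict enforcement: the provider validates output against the schema.
strict_format = {
    "type": "json_schema",
    "json_schema": {"name": "response", "strict": True, "schema": schema},
}

# Soft enforcement: generic JSON mode, schema repeated in the system prompt.
soft_format = {"type": "json_object"}
schema_msg = f"You must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
```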
```diff
@@ -467,13 +468,11 @@ class LLMProvider:

             except APIConnectionError as e:
                 last_exception = e
+                status_code = getattr(e, "status_code", None) or getattr(
+                    getattr(e, "response", None), "status_code", None
+                )
+                logger.warning(f"APIConnectionError (HTTP {status_code}), attempt {attempt + 1}: {str(e)[:200]}")
                 if attempt < max_retries:
-                    status_code = getattr(e, "status_code", None) or getattr(
-                        getattr(e, "response", None), "status_code", None
-                    )
-                    logger.warning(
-                        f"Connection error, retrying... (attempt {attempt + 1}/{max_retries + 1}) - status_code={status_code}, message={e}"
-                    )
                     backoff = min(initial_backoff * (2**attempt), max_backoff)
                     await asyncio.sleep(backoff)
                     continue
```
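Note: both the old and new handlers retry with capped exponential backoff, `min(initial_backoff * 2**attempt, max_backoff)`. With the defaults used elsewhere in this file (`initial_backoff=1.0`, `max_backoff=30.0`) the sleep sequence is 1, 2, 4, 8, 16, 30, 30, ... seconds. A small sketch:

```python
def backoff_schedule(attempts: int, initial: float = 1.0, cap: float = 30.0) -> list[float]:
    """Capped exponential backoff, as used in the retry loops above."""
    return [min(initial * (2 ** attempt), cap) for attempt in range(attempts)]

print(backoff_schedule(7))  # [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0]
```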
```diff
@@ -487,6 +486,45 @@ class LLMProvider:
                     logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                     raise

+                # Handle tool_use_failed error - model outputted in tool call format
+                # Convert to expected JSON format and continue
+                if e.status_code == 400 and response_format is not None:
+                    try:
+                        error_body = e.body if hasattr(e, "body") else {}
+                        if isinstance(error_body, dict):
+                            error_info: dict[str, Any] = error_body.get("error") or {}
+                            if error_info.get("code") == "tool_use_failed":
+                                failed_gen = error_info.get("failed_generation", "")
+                                if failed_gen:
+                                    # Parse the tool call format and convert to actions format
+                                    tool_call = json.loads(failed_gen)
+                                    tool_name = tool_call.get("name", "")
+                                    tool_args = tool_call.get("arguments", {})
+                                    # Convert to actions format: {"actions": [{"tool": "name", ...args}]}
+                                    converted = {"actions": [{"tool": tool_name, **tool_args}]}
+                                    if skip_validation:
+                                        result = converted
+                                    else:
+                                        result = response_format.model_validate(converted)
+
+                                    # Record metrics for this successful recovery
+                                    duration = time.time() - start_time
+                                    metrics = get_metrics_collector()
+                                    metrics.record_llm_call(
+                                        provider=self.provider,
+                                        model=self.model,
+                                        scope=scope,
+                                        duration=duration,
+                                        input_tokens=0,
+                                        output_tokens=0,
+                                        success=True,
+                                    )
+                                    if return_usage:
+                                        return result, TokenUsage(input_tokens=0, output_tokens=0, total_tokens=0)
+                                    return result
+                    except (json.JSONDecodeError, KeyError, TypeError):
+                        pass  # Failed to parse tool_use_failed, continue with normal retry
+
                 last_exception = e
                 if attempt < max_retries:
                     backoff = min(initial_backoff * (2**attempt), max_backoff)
```
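Note: the recovery path above targets OpenAI-compatible providers (Groq-style APIs) that reject a structured-output request with a 400 `tool_use_failed` error whose body carries the model's raw tool call in `failed_generation`. A minimal sketch of just the conversion step, using a made-up payload:

```python
import json

# Hypothetical failed_generation payload of the kind the handler above inspects.
failed_generation = '{"name": "search_memory", "arguments": {"query": "deadline", "limit": 5}}'

tool_call = json.loads(failed_generation)
converted = {
    "actions": [{"tool": tool_call.get("name", ""), **tool_call.get("arguments", {})}]
}
# -> {"actions": [{"tool": "search_memory", "query": "deadline", "limit": 5}]}
```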
```diff
@@ -497,14 +535,438 @@ class LLMProvider:
                     logger.error(f"API error after {max_retries + 1} attempts: {str(e)}")
                     raise

-            except Exception as e:
-                logger.error(f"Unexpected error during LLM call: {type(e).__name__}: {str(e)}")
+            except Exception:
                 raise

         if last_exception:
             raise last_exception
         raise RuntimeError("LLM call failed after all retries with no exception captured")

+    async def call_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None = None,
+        temperature: float | None = None,
+        scope: str = "tools",
+        max_retries: int = 5,
+        initial_backoff: float = 1.0,
+        max_backoff: float = 30.0,
+        tool_choice: str | dict[str, Any] = "auto",
+    ) -> "LLMToolCallResult":
+        """
+        Make an LLM API call with tool/function calling support.
+
+        Args:
+            messages: List of message dicts. Can include tool results with role='tool'.
+            tools: List of tool definitions in OpenAI format.
+            max_completion_tokens: Maximum tokens in response.
+            temperature: Sampling temperature (0.0-2.0).
+            scope: Scope identifier for tracking.
+            max_retries: Maximum retry attempts.
+            initial_backoff: Initial backoff time in seconds.
+            max_backoff: Maximum backoff time in seconds.
+            tool_choice: How to choose tools - "auto", "none", "required", or {"type": "function", "function": {"name": "..."}}
+
+        Returns:
+            LLMToolCallResult with content and/or tool_calls.
+        """
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        async with _global_llm_semaphore:
+            start_time = time.time()
+
+            # Handle Mock provider
+            if self.provider == "mock":
+                return await self._call_with_tools_mock(messages, tools, scope)
+
+            # Handle Anthropic separately (uses different tool format)
+            if self.provider == "anthropic":
+                return await self._call_with_tools_anthropic(
+                    messages, tools, max_completion_tokens, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # Handle Gemini (convert to Gemini tool format)
+            if self.provider == "gemini":
+                return await self._call_with_tools_gemini(
+                    messages, tools, max_retries, initial_backoff, max_backoff, start_time, scope
+                )
+
+            # OpenAI-compatible providers (OpenAI, Groq, Ollama, LMStudio)
+            call_params: dict[str, Any] = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            }
+
+            if max_completion_tokens is not None:
+                call_params["max_completion_tokens"] = max_completion_tokens
+            if temperature is not None:
+                call_params["temperature"] = temperature
+
+            # Provider-specific parameters
+            if self.provider == "groq":
+                call_params["seed"] = DEFAULT_LLM_SEED
+
+            last_exception = None
+
+            for attempt in range(max_retries + 1):
+                try:
+                    response = await self._client.chat.completions.create(**call_params)
+
+                    message = response.choices[0].message
+                    finish_reason = response.choices[0].finish_reason
+
+                    # Extract tool calls if present
+                    tool_calls: list[LLMToolCall] = []
+                    if message.tool_calls:
+                        for tc in message.tool_calls:
+                            try:
+                                args = json.loads(tc.function.arguments) if tc.function.arguments else {}
+                            except json.JSONDecodeError:
+                                args = {"_raw": tc.function.arguments}
+                            tool_calls.append(LLMToolCall(id=tc.id, name=tc.function.name, arguments=args))
+
+                    content = message.content
+
+                    # Record metrics
+                    duration = time.time() - start_time
+                    usage = response.usage
+                    input_tokens = usage.prompt_tokens or 0 if usage else 0
+                    output_tokens = usage.completion_tokens or 0 if usage else 0
+
+                    metrics = get_metrics_collector()
+                    metrics.record_llm_call(
+                        provider=self.provider,
+                        model=self.model,
+                        scope=scope,
+                        duration=duration,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        success=True,
+                    )
+
+                    return LLMToolCallResult(
+                        content=content,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                    )
+
+                except APIConnectionError as e:
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except APIStatusError as e:
+                    if e.status_code in (401, 403):
+                        raise
+                    last_exception = e
+                    if attempt < max_retries:
+                        await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                        continue
+                    raise
+
+                except Exception:
+                    raise
+
+            if last_exception:
+                raise last_exception
+            raise RuntimeError("Tool call failed after all retries")
+
+    async def _call_with_tools_mock(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle mock tool calls for testing."""
+        from .response_models import LLMToolCallResult
+
+        call_record = {
+            "provider": self.provider,
+            "model": self.model,
+            "messages": messages,
+            "tools": [t.get("function", {}).get("name") for t in tools],
+            "scope": scope,
+        }
+        self._mock_calls.append(call_record)
+
+        if self._mock_response is not None:
+            if isinstance(self._mock_response, LLMToolCallResult):
+                return self._mock_response
+            # Allow setting just tool calls as a list
+            if isinstance(self._mock_response, list):
+                from .response_models import LLMToolCall
+
+                return LLMToolCallResult(
+                    tool_calls=[
+                        LLMToolCall(id=f"mock_{i}", name=tc["name"], arguments=tc.get("arguments", {}))
+                        for i, tc in enumerate(self._mock_response)
+                    ],
+                    finish_reason="tool_calls",
+                )
+
+        return LLMToolCallResult(content="mock response", finish_reason="stop")
+
+    async def _call_with_tools_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_completion_tokens: int | None,
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Anthropic tool calling."""
+        from anthropic import APIConnectionError, APIStatusError
+
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert OpenAI tool format to Anthropic format
+        anthropic_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            anthropic_tools.append(
+                {
+                    "name": func.get("name", ""),
+                    "description": func.get("description", ""),
+                    "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
+                }
+            )
+
+        # Convert messages - handle tool results
+        system_prompt = None
+        anthropic_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_prompt = (system_prompt + "\n\n" + content) if system_prompt else content
+            elif role == "tool":
+                # Anthropic uses tool_result blocks
+                anthropic_messages.append(
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "tool_result", "tool_use_id": msg.get("tool_call_id", ""), "content": content}
+                        ],
+                    }
+                )
+            elif role == "assistant" and msg.get("tool_calls"):
+                # Convert assistant tool calls
+                tool_use_blocks = []
+                for tc in msg["tool_calls"]:
+                    tool_use_blocks.append(
+                        {
+                            "type": "tool_use",
+                            "id": tc.get("id", ""),
+                            "name": tc.get("function", {}).get("name", ""),
+                            "input": json.loads(tc.get("function", {}).get("arguments", "{}")),
+                        }
+                    )
+                anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
+            else:
+                anthropic_messages.append({"role": role, "content": content})
+
+        call_params: dict[str, Any] = {
+            "model": self.model,
+            "messages": anthropic_messages,
+            "tools": anthropic_tools,
+            "max_tokens": max_completion_tokens or 4096,
+        }
+        if system_prompt:
+            call_params["system"] = system_prompt
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._anthropic_client.messages.create(**call_params)
+
+                # Extract content and tool calls
+                content_parts = []
+                tool_calls: list[LLMToolCall] = []
+
+                for block in response.content:
+                    if block.type == "text":
+                        content_parts.append(block.text)
+                    elif block.type == "tool_use":
+                        tool_calls.append(LLMToolCall(id=block.id, name=block.name, arguments=block.input or {}))
+
+                content = "".join(content_parts) if content_parts else None
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Extract token usage
+                input_tokens = response.usage.input_tokens or 0
+                output_tokens = response.usage.output_tokens or 0
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except (APIConnectionError, APIStatusError) as e:
+                if isinstance(e, APIStatusError) and e.status_code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Anthropic tool call failed")
+
+    async def _call_with_tools_gemini(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        max_retries: int,
+        initial_backoff: float,
+        max_backoff: float,
+        start_time: float,
+        scope: str,
+    ) -> "LLMToolCallResult":
+        """Handle Gemini tool calling."""
+        from .response_models import LLMToolCall, LLMToolCallResult
+
+        # Convert tools to Gemini format
+        gemini_tools = []
+        for tool in tools:
+            func = tool.get("function", {})
+            gemini_tools.append(
+                genai_types.Tool(
+                    function_declarations=[
+                        genai_types.FunctionDeclaration(
+                            name=func.get("name", ""),
+                            description=func.get("description", ""),
+                            parameters=func.get("parameters"),
+                        )
+                    ]
+                )
+            )
+
+        # Convert messages
+        system_instruction = None
+        gemini_contents = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
+            elif role == "tool":
+                # Gemini uses function_response
+                gemini_contents.append(
+                    genai_types.Content(
+                        role="user",
+                        parts=[
+                            genai_types.Part(
+                                function_response=genai_types.FunctionResponse(
+                                    name=msg.get("name", ""),
+                                    response={"result": content},
+                                )
+                            )
+                        ],
+                    )
+                )
+            elif role == "assistant":
+                gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
+            else:
+                gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
+
+        config = genai_types.GenerateContentConfig(
+            system_instruction=system_instruction,
+            tools=gemini_tools,
+        )
+
+        last_exception = None
+        for attempt in range(max_retries + 1):
+            try:
+                response = await self._gemini_client.aio.models.generate_content(
+                    model=self.model,
+                    contents=gemini_contents,
+                    config=config,
+                )
+
+                # Extract content and tool calls
+                content = None
+                tool_calls: list[LLMToolCall] = []
+
+                if response.candidates and response.candidates[0].content:
+                    for part in response.candidates[0].content.parts:
+                        if hasattr(part, "text") and part.text:
+                            content = part.text
+                        if hasattr(part, "function_call") and part.function_call:
+                            fc = part.function_call
+                            tool_calls.append(
+                                LLMToolCall(
+                                    id=f"gemini_{len(tool_calls)}",
+                                    name=fc.name,
+                                    arguments=dict(fc.args) if fc.args else {},
+                                )
+                            )
+
+                finish_reason = "tool_calls" if tool_calls else "stop"
+
+                # Record metrics
+                metrics = get_metrics_collector()
+                input_tokens = response.usage_metadata.prompt_token_count if response.usage_metadata else 0
+                output_tokens = response.usage_metadata.candidates_token_count if response.usage_metadata else 0
+                metrics.record_llm_call(
+                    provider=self.provider,
+                    model=self.model,
+                    scope=scope,
+                    duration=time.time() - start_time,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    success=True,
+                )
+
+                return LLMToolCallResult(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                )
+
+            except genai_errors.APIError as e:
+                if e.code in (401, 403):
+                    raise
+                last_exception = e
+                if attempt < max_retries:
+                    await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
+                    continue
+                raise
+
+        if last_exception:
+            raise last_exception
+        raise RuntimeError("Gemini tool call failed")
+
     async def _call_anthropic(
         self,
         messages: list[dict[str, str]],
```