vectara-agentic 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of vectara-agentic might be problematic.
- tests/__init__.py +7 -0
- tests/conftest.py +312 -0
- tests/endpoint.py +54 -17
- tests/run_tests.py +111 -0
- tests/test_agent.py +10 -5
- tests/test_agent_type.py +82 -143
- tests/test_api_endpoint.py +4 -0
- tests/test_bedrock.py +4 -0
- tests/test_fallback.py +4 -0
- tests/test_gemini.py +28 -45
- tests/test_groq.py +4 -0
- tests/test_private_llm.py +11 -2
- tests/test_return_direct.py +6 -2
- tests/test_serialization.py +4 -0
- tests/test_streaming.py +88 -0
- tests/test_tools.py +10 -82
- tests/test_vectara_llms.py +4 -0
- tests/test_vhc.py +66 -0
- tests/test_workflow.py +4 -0
- vectara_agentic/__init__.py +27 -4
- vectara_agentic/_callback.py +65 -67
- vectara_agentic/_observability.py +30 -30
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +375 -848
- vectara_agentic/agent_config.py +15 -14
- vectara_agentic/agent_core/__init__.py +22 -0
- vectara_agentic/agent_core/factory.py +501 -0
- vectara_agentic/{_prompts.py → agent_core/prompts.py} +3 -35
- vectara_agentic/agent_core/serialization.py +345 -0
- vectara_agentic/agent_core/streaming.py +495 -0
- vectara_agentic/agent_core/utils/__init__.py +34 -0
- vectara_agentic/agent_core/utils/hallucination.py +202 -0
- vectara_agentic/agent_core/utils/logging.py +52 -0
- vectara_agentic/agent_core/utils/prompt_formatting.py +56 -0
- vectara_agentic/agent_core/utils/schemas.py +87 -0
- vectara_agentic/agent_core/utils/tools.py +125 -0
- vectara_agentic/agent_endpoint.py +4 -6
- vectara_agentic/db_tools.py +37 -12
- vectara_agentic/llm_utils.py +41 -42
- vectara_agentic/sub_query_workflow.py +9 -14
- vectara_agentic/tool_utils.py +138 -83
- vectara_agentic/tools.py +36 -21
- vectara_agentic/tools_catalog.py +16 -16
- vectara_agentic/types.py +98 -6
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.0.dist-info}/METADATA +69 -30
- vectara_agentic-0.4.0.dist-info/RECORD +50 -0
- tests/test_agent_planning.py +0 -64
- tests/test_hhem.py +0 -100
- vectara_agentic/hhem.py +0 -82
- vectara_agentic-0.3.3.dist-info/RECORD +0 -39
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.0.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.0.dist-info}/top_level.txt +0 -0
vectara_agentic/agent_core/utils/hallucination.py (new file)
@@ -0,0 +1,202 @@
+"""Vectara Hallucination Detection and Correction client."""
+
+import logging
+from typing import List, Dict, Optional, Tuple
+import requests
+
+from llama_index.core.llms import MessageRole
+
+class Hallucination:
+    """Vectara Hallucination Correction."""
+
+    def __init__(self, vectara_api_key: str):
+        self._vectara_api_key = vectara_api_key
+
+    def compute(
+        self, query: str, context: list[str], hypothesis: str
+    ) -> Tuple[str, list[str]]:
+        """
+        Calls the Vectara VHC (Vectara Hallucination Correction)
+
+        Returns:
+            str: The corrected hypothesis text.
+            list[str]: the list of corrections from VHC
+        """
+
+        payload = {
+            "generated_text": hypothesis,
+            "query": query,
+            "documents": [{"text": c} for c in context],
+            "model_name": "vhc-large-1.0",
+        }
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "x-api-key": self._vectara_api_key,
+        }
+
+        response = requests.post(
+            "https://api.vectara.io/v2/hallucination_correctors/correct_hallucinations",
+            json=payload,
+            headers=headers,
+            timeout=30,
+        )
+        response.raise_for_status()
+        data = response.json()
+        corrected_text = data.get("corrected_text", "")
+        corrections = data.get("corrections", [])
+
+        logging.debug(
+            f"VHC: query={query}\n"
+        )
+        logging.debug(
+            f"VHC: response={hypothesis}\n"
+        )
+        logging.debug("VHC: Context:")
+        for i, ctx in enumerate(context):
+            logging.info(f"VHC: context {i}: {ctx}\n\n")
+
+        logging.debug(
+            f"VHC: outputs: {len(corrections)} corrections"
+        )
+        logging.debug(
+            f"VHC: corrected_text: {corrected_text}\n"
+        )
+        for correction in corrections:
+            logging.debug(f"VHC: correction: {correction}\n")
+
+        return corrected_text, corrections
+
+def extract_tool_call_mapping(chat_history) -> Dict[str, str]:
+    """Extract tool_call_id to tool_name mapping from chat history."""
+    tool_call_id_to_name = {}
+    for msg in chat_history:
+        if (
+            msg.role == MessageRole.ASSISTANT
+            and hasattr(msg, "additional_kwargs")
+            and msg.additional_kwargs
+        ):
+            tool_calls = msg.additional_kwargs.get("tool_calls", [])
+            for tool_call in tool_calls:
+                if (
+                    isinstance(tool_call, dict)
+                    and "id" in tool_call
+                    and "function" in tool_call
+                ):
+                    tool_call_id = tool_call["id"]
+                    tool_name = tool_call["function"].get("name")
+                    if tool_call_id and tool_name:
+                        tool_call_id_to_name[tool_call_id] = tool_name
+
+    return tool_call_id_to_name
+
+
+def identify_tool_name(msg, tool_call_id_to_name: Dict[str, str]) -> Optional[str]:
+    """Identify tool name from message using multiple strategies."""
+    tool_name = None
+
+    # First try: standard tool_name attribute (for backwards compatibility)
+    tool_name = getattr(msg, "tool_name", None)
+
+    # Second try: additional_kwargs (LlamaIndex standard location)
+    if (
+        tool_name is None
+        and hasattr(msg, "additional_kwargs")
+        and msg.additional_kwargs
+    ):
+        tool_name = msg.additional_kwargs.get("name") or msg.additional_kwargs.get(
+            "tool_name"
+        )
+
+        # If no direct tool name, try to map from tool_call_id
+        if tool_name is None:
+            tool_call_id = msg.additional_kwargs.get("tool_call_id")
+            if tool_call_id and tool_call_id in tool_call_id_to_name:
+                tool_name = tool_call_id_to_name[tool_call_id]
+
+    # Third try: extract from content if it's a ToolOutput object
+    if tool_name is None and hasattr(msg.content, "tool_name"):
+        tool_name = msg.content.tool_name
+
+    return tool_name
+
+
+def check_tool_eligibility(tool_name: Optional[str], tools: List) -> bool:
+    """Check if a tool output is eligible to be included in VHC, by looking up in tools list."""
+    if not tool_name or not tools:
+        return False
+
+    # Try to find the tool and check its VHC eligibility
+    for tool in tools:
+        if (
+            hasattr(tool, "metadata")
+            and hasattr(tool.metadata, "name")
+            and tool.metadata.name == tool_name
+        ):
+            if hasattr(tool.metadata, "vhc_eligible"):
+                is_vhc_eligible = tool.metadata.vhc_eligible
+                return is_vhc_eligible
+            break
+
+    return True
+
+def analyze_hallucinations(
+    query: str, chat_history: List,
+    agent_response: str, tools: List, vectara_api_key: str
+) -> Tuple[Optional[str], List[str]]:
+    """Use VHC to compute corrected_text and corrections."""
+    if not vectara_api_key:
+        logging.debug("No Vectara API key - returning None")
+        return None, []
+
+    # Build a mapping from tool_call_id to tool_name for better tool identification
+    tool_call_id_to_name = extract_tool_call_mapping(chat_history)
+
+    context = []
+    last_assistant_index = -1
+    for i, msg in enumerate(chat_history):
+        if msg.role == MessageRole.ASSISTANT and msg.content:
+            last_assistant_index = i
+
+    for i, msg in enumerate(chat_history):
+        if msg.role == MessageRole.TOOL:
+            tool_name = identify_tool_name(msg, tool_call_id_to_name)
+            is_vhc_eligible = check_tool_eligibility(tool_name, tools)
+
+            # Only count tool calls from VHC-eligible tools
+            if is_vhc_eligible:
+                content = msg.content
+
+                # Since tools with human-readable output now convert to formatted strings immediately
+                # in VectaraTool._format_tool_output(), we just use the content directly
+                content = str(content) if content is not None else ""
+
+                # Only add non-empty content to context
+                if content and content.strip():
+                    context.append(content)
+
+        elif msg.role == MessageRole.USER and msg.content:
+            context.append(msg.content)
+
+        elif msg.role == MessageRole.ASSISTANT and msg.content:
+            if i == last_assistant_index:  # do not include the last assistant message
+                continue
+            context.append(msg.content)
+
+    # If no context or no tool calls, we cannot compute VHC
+    if len(context) == 0:
+        return None, []
+
+    try:
+        h = Hallucination(vectara_api_key)
+        corrected_text, corrections = h.compute(
+            query=query, context=context, hypothesis=agent_response
+        )
+        return corrected_text, corrections
+
+    except Exception as e:
+        logging.error(
+            f"VHC call failed: {e}. "
+            "Ensure you have a valid Vectara API key and the Hallucination Correction service is available."
+        )
+        return None, []
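For context, a minimal usage sketch of the new VHC helper (not part of the diff). The chat history, agent response, and API key below are placeholder values; with an invalid key the call simply logs an error and returns (None, []).

from llama_index.core.llms import ChatMessage, MessageRole
from vectara_agentic.agent_core.utils.hallucination import analyze_hallucinations

# Placeholder conversation; inside the agent this comes from session memory.
history = [
    ChatMessage(role=MessageRole.USER, content="What does Vectara sell?"),
    ChatMessage(role=MessageRole.ASSISTANT, content="Vectara sells a RAG platform."),
]

# Returns the corrected response text plus the individual corrections.
corrected_text, corrections = analyze_hallucinations(
    query="What does Vectara sell?",
    chat_history=history,
    agent_response="Vectara sells a RAG platform and was founded in 1492.",
    tools=[],                   # no tools: only user turns feed the VHC context
    vectara_api_key="zut-...",  # placeholder API key
)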
vectara_agentic/agent_core/utils/logging.py (new file)
@@ -0,0 +1,52 @@
+"""
+Logging configuration and utilities for agent functionality.
+
+This module provides logging filters, configuration, and setup utilities
+specifically tailored for agent operations and debugging.
+"""
+
+import logging
+from dotenv import load_dotenv
+
+
+class IgnoreUnpickleableAttributeFilter(logging.Filter):
+    """
+    Filter to ignore log messages that contain certain strings.
+
+    This filter is used to suppress common unpickleable attribute warnings
+    that occur during agent serialization/deserialization operations.
+    """
+
+    def filter(self, record):
+        """
+        Filter log records based on message content.
+
+        Args:
+            record: LogRecord to evaluate
+
+        Returns:
+            bool: True if record should be logged, False if it should be ignored
+        """
+        msgs_to_ignore = [
+            "Removing unpickleable private attribute _split_fns",
+            "Removing unpickleable private attribute _sub_sentence_split_fns",
+        ]
+        return all(msg not in record.getMessage() for msg in msgs_to_ignore)
+
+
+def setup_agent_logging():
+    """
+    Set up logging configuration for agent operations.
+
+    This configures logging filters and levels to reduce noise from
+    agent-related operations while maintaining useful debug information.
+    """
+    # Add filter to suppress unpickleable attribute warnings
+    logging.getLogger().addFilter(IgnoreUnpickleableAttributeFilter())
+
+    # Set critical level for OTLP trace exporter to reduce noise
+    logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
+    logger.setLevel(logging.CRITICAL)
+
+    # Load environment variables with override
+    load_dotenv(override=True)
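A quick sketch (not part of the diff) of how this helper might be called when bootstrapping an agent process:

import logging
from vectara_agentic.agent_core.utils.logging import setup_agent_logging

setup_agent_logging()  # install the unpickleable-attribute filter, quiet the OTLP exporter, load .env
logging.info("agent logging configured")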
vectara_agentic/agent_core/utils/prompt_formatting.py (new file)
@@ -0,0 +1,56 @@
+"""
+Prompt formatting and templating utilities.
+
+This module handles prompt template processing, placeholder replacement,
+and LLM-specific prompt formatting for different agent types.
+"""
+
+from datetime import date
+
+def format_prompt(
+    prompt_template: str,
+    general_instructions: str,
+    topic: str,
+    custom_instructions: str,
+) -> str:
+    """
+    Generate a prompt by replacing placeholders with topic and date.
+
+    Args:
+        prompt_template: The template for the prompt
+        general_instructions: General instructions to be included in the prompt
+        topic: The topic to be included in the prompt
+        custom_instructions: The custom instructions to be included in the prompt
+
+    Returns:
+        str: The formatted prompt
+    """
+    return (
+        prompt_template.replace("{chat_topic}", topic)
+        .replace("{today}", date.today().strftime("%A, %B %d, %Y"))
+        .replace("{custom_instructions}", custom_instructions)
+        .replace("{INSTRUCTIONS}", general_instructions)
+    )
+
+
+def format_llm_compiler_prompt(
+    prompt: str, general_instructions: str, topic: str, custom_instructions: str
+) -> str:
+    """
+    Add custom instructions to the prompt for LLM compiler agents.
+
+    Args:
+        prompt: The base prompt to which custom instructions should be added
+        general_instructions: General instructions for the agent
+        topic: Topic expertise for the agent
+        custom_instructions: Custom user instructions
+
+    Returns:
+        str: The prompt with custom instructions added
+    """
+    prompt += "\nAdditional Instructions:\n"
+    prompt += f"You have expertise in {topic}.\n"
+    prompt += general_instructions
+    prompt += custom_instructions
+    prompt += f"Today is {date.today().strftime('%A, %B %d, %Y')}"
+    return prompt
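A small sketch (not part of the diff) showing how format_prompt substitutes the four placeholders; the template below is illustrative, while the shipped templates live in agent_core/prompts.py:

from vectara_agentic.agent_core.utils.prompt_formatting import format_prompt

template = (
    "You are an assistant for {chat_topic}.\n"
    "{INSTRUCTIONS}\n"
    "{custom_instructions}\n"
    "Today is {today}."
)
print(
    format_prompt(
        prompt_template=template,
        general_instructions="Always cite your sources.",
        topic="quarterly financial reports",
        custom_instructions="Answer in one short paragraph.",
    )
)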
vectara_agentic/agent_core/utils/schemas.py (new file)
@@ -0,0 +1,87 @@
+"""
+Schema and type conversion utilities for agent functionality.
+
+This module handles JSON schema to Python type conversion,
+Pydantic model reconstruction, and type mapping operations.
+"""
+
+from typing import Any, Union, List
+
+
+# Type mapping constants
+JSON_TYPE_TO_PYTHON = {
+    "string": str,
+    "integer": int,
+    "boolean": bool,
+    "array": list,
+    "object": dict,
+    "number": float,
+    "null": type(None),
+}
+
+PY_TYPES = {
+    "str": str,
+    "int": int,
+    "float": float,
+    "bool": bool,
+    "dict": dict,
+    "list": list,
+}
+
+
+def get_field_type(field_schema: dict) -> Any:
+    """
+    Convert a JSON schema field definition to a Python type.
+    Handles 'type' and 'anyOf' cases.
+
+    Args:
+        field_schema: JSON schema field definition
+
+    Returns:
+        Any: Corresponding Python type
+    """
+    if not field_schema:  # Handles empty schema {}
+        return Any
+
+    if "anyOf" in field_schema:
+        types = []
+        for option_schema in field_schema["anyOf"]:
+            types.append(get_field_type(option_schema))  # Recursive call
+        if not types:
+            return Any
+        return Union[tuple(types)]
+
+    if "type" in field_schema and isinstance(field_schema["type"], list):
+        types = []
+        for type_name in field_schema["type"]:
+            if type_name == "array":
+                item_schema = field_schema.get("items", {})
+                types.append(List[get_field_type(item_schema)])
+            elif type_name in JSON_TYPE_TO_PYTHON:
+                types.append(JSON_TYPE_TO_PYTHON[type_name])
+            else:
+                types.append(Any)  # Fallback for unknown types in the list
+        if not types:
+            return Any
+        return Union[tuple(types)]  # type: ignore
+
+    if "type" in field_schema:
+        schema_type_name = field_schema["type"]
+        if schema_type_name == "array":
+            item_schema = field_schema.get(
+                "items", {}
+            )  # Default to Any if "items" is missing
+            return List[get_field_type(item_schema)]
+
+        return JSON_TYPE_TO_PYTHON.get(schema_type_name, Any)
+
+    # If only "items" is present (implies array by some conventions, but less standard)
+    # Or if it's a schema with other keywords like 'properties' (implying object)
+    # For simplicity, if no "type" or "anyOf" at this point, default to Any or add more specific handling.
+    # If 'properties' in field_schema or 'additionalProperties' in field_schema, it's likely an object.
+    if "properties" in field_schema or "additionalProperties" in field_schema:
+        # This path might need to reconstruct a nested Pydantic model if you encounter such schemas.
+        # For now, treating as 'dict' or 'Any' might be a simpler placeholder.
+        return dict  # Or Any, or more sophisticated object reconstruction.
+
+    return Any
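Illustrative inputs (not part of the diff) for the JSON-schema-to-Python mapping above:

from typing import get_args
from vectara_agentic.agent_core.utils.schemas import get_field_type

print(get_field_type({"type": "integer"}))                             # <class 'int'>
print(get_field_type({"type": "array", "items": {"type": "string"}}))  # typing.List[str]
optional_str = get_field_type({"anyOf": [{"type": "string"}, {"type": "null"}]})
print(get_args(optional_str))                                          # (<class 'str'>, <class 'NoneType'>)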
vectara_agentic/agent_core/utils/tools.py (new file)
@@ -0,0 +1,125 @@
+"""
+Tool processing and validation utilities for agent functionality.
+
+This module provides utilities for tool validation, processing, and
+compatibility adjustments for different LLM providers.
+"""
+
+import inspect
+from typing import Any, List
+from inspect import Signature, Parameter, ismethod
+from collections import Counter
+
+from pydantic import Field, create_model
+from llama_index.core.tools import FunctionTool
+from ...llm_utils import get_llm
+from ...types import LLMRole
+
+
+def sanitize_tools_for_gemini(tools: List[FunctionTool]) -> List[FunctionTool]:
+    """
+    Strip all default values from tools for Gemini LLM compatibility.
+
+    Gemini requires that tools only show required parameters without defaults.
+    This function modifies:
+    - tool.fn signature
+    - tool.async_fn signature
+    - tool.metadata.fn_schema
+
+    Args:
+        tools: List of FunctionTool objects to sanitize
+
+    Returns:
+        List[FunctionTool]: Sanitized tools with no default values
+    """
+    for tool in tools:
+        # 1) Strip defaults off the actual callables
+        for func in (tool.fn, tool.async_fn):
+            if not func:
+                continue
+            orig_sig = inspect.signature(func)
+            new_params = [
+                p.replace(default=Parameter.empty) for p in orig_sig.parameters.values()
+            ]
+            new_sig = Signature(
+                new_params, return_annotation=orig_sig.return_annotation
+            )
+            if ismethod(func):
+                func.__func__.__signature__ = new_sig
+            else:
+                func.__signature__ = new_sig
+
+        # 2) Rebuild the Pydantic schema so that *every* field is required
+        schema_cls = getattr(tool.metadata, "fn_schema", None)
+        if schema_cls and hasattr(schema_cls, "model_fields"):
+            # Collect (name → (type, Field(...))) for all fields
+            new_fields: dict[str, tuple[type, Any]] = {}
+            for name, mf in schema_cls.model_fields.items():
+                typ = mf.annotation
+                desc = getattr(mf, "description", "")
+                # Force required (no default) with Field(...)
+                new_fields[name] = (typ, Field(..., description=desc))
+
+            # Make a brand-new schema class where every field is required
+            no_default_schema = create_model(
+                f"{schema_cls.__name__}",  # new class name
+                **new_fields,  # type: ignore
+            )
+
+            # Give it a clean __signature__ so inspect.signature sees no defaults
+            params = [
+                Parameter(n, Parameter.POSITIONAL_OR_KEYWORD, annotation=typ)
+                for n, (typ, _) in new_fields.items()
+            ]
+            no_default_schema.__signature__ = Signature(params)
+
+            # Swap it back onto the tool
+            tool.metadata.fn_schema = no_default_schema
+
+    return tools
+
+
+def validate_tool_consistency(
+    tools: List[FunctionTool], custom_instructions: str, agent_config
+) -> None:
+    """
+    Validate that tools mentioned in instructions actually exist.
+
+    Args:
+        tools: List of available tools
+        custom_instructions: Custom instructions that may reference tools
+        agent_config: Agent configuration for LLM access
+
+    Raises:
+        ValueError: If invalid tools are referenced in instructions
+    """
+    tool_names = [tool.metadata.name for tool in tools]
+
+    # Check for duplicate tools
+    duplicates = [tool for tool, count in Counter(tool_names).items() if count > 1]
+    if duplicates:
+        raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")
+
+    # Validate tools mentioned in instructions exist
+    if custom_instructions:
+        prompt = f"""
+You are provided these tools:
+<tools>{','.join(tool_names)}</tools>
+And these instructions:
+<instructions>
+{custom_instructions}
+</instructions>
+Your task is to identify invalid tools.
+A tool is invalid if it is mentioned in the instructions but not in the tools list.
+A tool's name must have at least two characters.
+Your response should be a comma-separated list of the invalid tools.
+If no invalid tools exist, respond with "<OKAY>" (and nothing else).
+"""
+        llm = get_llm(LLMRole.MAIN, config=agent_config)
+        bad_tools_str = llm.complete(prompt).text.strip("\n")
+        if bad_tools_str and bad_tools_str != "<OKAY>":
+            bad_tools = [tool.strip() for tool in bad_tools_str.split(",")]
+            numbered = ", ".join(f"({i}) {tool}" for i, tool in enumerate(bad_tools, 1))
+            raise ValueError(
+                f"The Agent custom instructions mention these invalid tools: {numbered}"
+            )
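A hedged sketch (not part of the diff) of running the Gemini sanitizer over a simple tool; the multiply function is made up for illustration:

from llama_index.core.tools import FunctionTool
from vectara_agentic.agent_core.utils.tools import sanitize_tools_for_gemini

def multiply(x: float, y: float = 2.0) -> float:
    """Multiply two numbers."""
    return x * y

tools = sanitize_tools_for_gemini([FunctionTool.from_defaults(fn=multiply)])
# Both x and y are now required in the rebuilt schema; y's default is gone.
print(tools[0].metadata.fn_schema.model_fields["y"].is_required())  # True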
vectara_agentic/agent_endpoint.py
@@ -16,12 +16,6 @@ from .agent import Agent
 from .agent_config import AgentConfig
 
 
-class ChatRequest(BaseModel):
-    """Request schema for the /chat endpoint."""
-
-    message: str
-
-
 class CompletionRequest(BaseModel):
     """Request schema for the /v1/completions endpoint."""
 
@@ -64,12 +58,14 @@ class CompletionResponse(BaseModel):
 
 class ChatMessage(BaseModel):
     """Schema for individual chat messages in ChatCompletionRequest."""
+
     role: Literal["system", "user", "assistant"]
     content: str
 
 
 class ChatCompletionRequest(BaseModel):
     """Request schema for the /v1/chat endpoint."""
+
     model: str
     messages: List[ChatMessage]
     temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
@@ -79,6 +75,7 @@ class ChatCompletionRequest(BaseModel):
 
 class ChatCompletionChoice(BaseModel):
     """Choice schema returned in ChatCompletionResponse."""
+
     index: int
     message: ChatMessage
     finish_reason: Literal["stop", "length", "error", None]
@@ -86,6 +83,7 @@ class ChatCompletionChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     """Response schema for the /v1/chat endpoint."""
+
     id: str
     object: Literal["chat.completion"]
     created: int