vectara-agentic 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/test_groq.py +120 -0
- tests/test_return_direct.py +49 -0
- tests/test_tools.py +42 -6
- tests/test_vectara_llms.py +4 -12
- vectara_agentic/_observability.py +43 -21
- vectara_agentic/_prompts.py +1 -1
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +83 -10
- vectara_agentic/llm_utils.py +174 -0
- vectara_agentic/tool_utils.py +536 -0
- vectara_agentic/tools.py +37 -475
- vectara_agentic/tools_catalog.py +3 -2
- vectara_agentic/utils.py +0 -153
- {vectara_agentic-0.2.13.dist-info → vectara_agentic-0.2.15.dist-info}/METADATA +25 -11
- vectara_agentic-0.2.15.dist-info/RECORD +34 -0
- {vectara_agentic-0.2.13.dist-info → vectara_agentic-0.2.15.dist-info}/WHEEL +1 -1
- vectara_agentic-0.2.13.dist-info/RECORD +0 -30
- {vectara_agentic-0.2.13.dist-info → vectara_agentic-0.2.15.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.2.13.dist-info → vectara_agentic-0.2.15.dist-info}/top_level.txt +0 -0
tests/test_groq.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
from pydantic import Field, BaseModel
|
|
4
|
+
|
|
5
|
+
from vectara_agentic.agent import Agent, AgentType
|
|
6
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
7
|
+
from vectara_agentic.tools import VectaraToolFactory
|
|
8
|
+
from vectara_agentic.types import ModelProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import nest_asyncio
|
|
12
|
+
nest_asyncio.apply()
|
|
13
|
+
|
|
14
|
+
tickers = {
|
|
15
|
+
"C": "Citigroup",
|
|
16
|
+
"COF": "Capital One",
|
|
17
|
+
"JPM": "JPMorgan Chase",
|
|
18
|
+
"AAPL": "Apple Computer",
|
|
19
|
+
"GOOG": "Google",
|
|
20
|
+
"AMZN": "Amazon",
|
|
21
|
+
"SNOW": "Snowflake",
|
|
22
|
+
"TEAM": "Atlassian",
|
|
23
|
+
"TSLA": "Tesla",
|
|
24
|
+
"NVDA": "Nvidia",
|
|
25
|
+
"MSFT": "Microsoft",
|
|
26
|
+
"AMD": "Advanced Micro Devices",
|
|
27
|
+
"INTC": "Intel",
|
|
28
|
+
"NFLX": "Netflix",
|
|
29
|
+
"STT": "State Street",
|
|
30
|
+
"BK": "Bank of New York Mellon",
|
|
31
|
+
}
|
|
32
|
+
years = list(range(2015, 2025))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def mult(x: float, y: float) -> float:
|
|
36
|
+
"Multiply two numbers"
|
|
37
|
+
return x * y
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_company_info() -> list[str]:
|
|
41
|
+
"""
|
|
42
|
+
Returns a dictionary of companies you can query about. Always check this before using any other tool.
|
|
43
|
+
The output is a dictionary of valid ticker symbols mapped to company names.
|
|
44
|
+
You can use this to identify the companies you can query about, and their ticker information.
|
|
45
|
+
"""
|
|
46
|
+
return tickers
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_valid_years() -> list[str]:
|
|
50
|
+
"""
|
|
51
|
+
Returns a list of the years for which financial reports are available.
|
|
52
|
+
Always check this before using any other tool.
|
|
53
|
+
"""
|
|
54
|
+
return years
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
config_gemini = AgentConfig(
|
|
58
|
+
agent_type=AgentType.FUNCTION_CALLING,
|
|
59
|
+
main_llm_provider=ModelProvider.GEMINI,
|
|
60
|
+
tool_llm_provider=ModelProvider.GEMINI,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
fc_config_groq = AgentConfig(
|
|
65
|
+
agent_type=AgentType.FUNCTION_CALLING,
|
|
66
|
+
main_llm_provider=ModelProvider.GROQ,
|
|
67
|
+
tool_llm_provider=ModelProvider.GROQ,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class TestGROQ(unittest.TestCase):
|
|
72
|
+
|
|
73
|
+
def test_tool_with_many_arguments(self):
|
|
74
|
+
|
|
75
|
+
vectara_corpus_key = "vectara-docs_1"
|
|
76
|
+
vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
|
|
77
|
+
vec_factory = VectaraToolFactory(vectara_corpus_key, vectara_api_key)
|
|
78
|
+
|
|
79
|
+
class QueryToolArgs(BaseModel):
|
|
80
|
+
arg1: str = Field(description="the first argument", examples=["val1"])
|
|
81
|
+
arg2: str = Field(description="the second argument", examples=["val2"])
|
|
82
|
+
arg3: str = Field(description="the third argument", examples=["val3"])
|
|
83
|
+
arg4: str = Field(description="the fourth argument", examples=["val4"])
|
|
84
|
+
arg5: str = Field(description="the fifth argument", examples=["val5"])
|
|
85
|
+
arg6: str = Field(description="the sixth argument", examples=["val6"])
|
|
86
|
+
arg7: str = Field(description="the seventh argument", examples=["val7"])
|
|
87
|
+
arg8: str = Field(description="the eighth argument", examples=["val8"])
|
|
88
|
+
arg9: str = Field(description="the ninth argument", examples=["val9"])
|
|
89
|
+
arg10: str = Field(description="the tenth argument", examples=["val10"])
|
|
90
|
+
arg11: str = Field(description="the eleventh argument", examples=["val11"])
|
|
91
|
+
arg12: str = Field(description="the twelfth argument", examples=["val12"])
|
|
92
|
+
arg13: str = Field(
|
|
93
|
+
description="the thirteenth argument", examples=["val13"]
|
|
94
|
+
)
|
|
95
|
+
arg14: str = Field(
|
|
96
|
+
description="the fourteenth argument", examples=["val14"]
|
|
97
|
+
)
|
|
98
|
+
arg15: str = Field(description="the fifteenth argument", examples=["val15"])
|
|
99
|
+
|
|
100
|
+
query_tool_1 = vec_factory.create_rag_tool(
|
|
101
|
+
tool_name="rag_tool",
|
|
102
|
+
tool_description="""
|
|
103
|
+
A dummy tool that takes 15 arguments and returns a response (str) to the user query based on the data in this corpus.
|
|
104
|
+
We are using this tool to test the tool factory works and does not crash with OpenAI.
|
|
105
|
+
""",
|
|
106
|
+
tool_args_schema=QueryToolArgs,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
agent = Agent(
|
|
110
|
+
tools=[query_tool_1],
|
|
111
|
+
topic="Sample topic",
|
|
112
|
+
custom_instructions="Call the tool with 15 arguments",
|
|
113
|
+
agent_config=fc_config_groq,
|
|
114
|
+
)
|
|
115
|
+
res = agent.chat("What is the stock price?")
|
|
116
|
+
self.assertIn("I don't know", str(res))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
if __name__ == "__main__":
|
|
120
|
+
unittest.main()
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
from vectara_agentic.agent import Agent
|
|
4
|
+
from vectara_agentic.tools import VectaraToolFactory
|
|
5
|
+
|
|
6
|
+
vectara_corpus_key = "vectara-docs_1"
|
|
7
|
+
vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestAgentPackage(unittest.TestCase):
|
|
11
|
+
|
|
12
|
+
def test_return_direct1(self):
|
|
13
|
+
vec_factory = VectaraToolFactory(vectara_corpus_key, vectara_api_key)
|
|
14
|
+
|
|
15
|
+
query_tool = vec_factory.create_rag_tool(
|
|
16
|
+
tool_name="rag_tool",
|
|
17
|
+
tool_description="""
|
|
18
|
+
A dummy tool for testing return_direct.
|
|
19
|
+
""",
|
|
20
|
+
return_direct=True,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
agent = Agent(
|
|
24
|
+
tools=[query_tool],
|
|
25
|
+
topic="Sample topic",
|
|
26
|
+
custom_instructions="You are a helpful assistant.",
|
|
27
|
+
)
|
|
28
|
+
res = agent.chat("What is Vectara?")
|
|
29
|
+
self.assertIn("Response:", str(res))
|
|
30
|
+
self.assertIn("fcs_score", str(res))
|
|
31
|
+
self.assertIn("References:", str(res))
|
|
32
|
+
|
|
33
|
+
def test_from_corpus(self):
|
|
34
|
+
agent = Agent.from_corpus(
|
|
35
|
+
tool_name="rag_tool",
|
|
36
|
+
vectara_corpus_key=vectara_corpus_key,
|
|
37
|
+
vectara_api_key=vectara_api_key,
|
|
38
|
+
data_description="stuff about Vectara",
|
|
39
|
+
assistant_specialty="question answering",
|
|
40
|
+
return_direct=True,
|
|
41
|
+
)
|
|
42
|
+
res = agent.chat("What is Vectara?")
|
|
43
|
+
self.assertIn("Response:", str(res))
|
|
44
|
+
self.assertIn("fcs_score", str(res))
|
|
45
|
+
self.assertIn("References:", str(res))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
if __name__ == "__main__":
|
|
49
|
+
unittest.main()
|
tests/test_tools.py
CHANGED
|
@@ -9,6 +9,7 @@ from vectara_agentic.tools import (
|
|
|
9
9
|
)
|
|
10
10
|
from vectara_agentic.agent import Agent
|
|
11
11
|
from vectara_agentic.agent_config import AgentConfig
|
|
12
|
+
from vectara_agentic.types import AgentType, ModelProvider
|
|
12
13
|
|
|
13
14
|
from llama_index.core.tools import FunctionTool
|
|
14
15
|
|
|
@@ -179,22 +180,57 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
179
180
|
query_tool_1 = vec_factory.create_rag_tool(
|
|
180
181
|
tool_name="rag_tool",
|
|
181
182
|
tool_description="""
|
|
182
|
-
A dummy tool that takes
|
|
183
|
+
A dummy tool that takes 15 arguments and returns a response (str) to the user query based on the data in this corpus.
|
|
183
184
|
We are using this tool to test the tool factory works and does not crash with OpenAI.
|
|
184
185
|
""",
|
|
185
186
|
tool_args_schema=QueryToolArgs,
|
|
186
187
|
)
|
|
187
188
|
|
|
188
|
-
|
|
189
|
+
# Test with 15 arguments, which exceeds the 1024-character limit.
|
|
190
|
+
config = AgentConfig(
|
|
191
|
+
agent_type=AgentType.OPENAI
|
|
192
|
+
)
|
|
189
193
|
agent = Agent(
|
|
190
194
|
tools=[query_tool_1],
|
|
191
195
|
topic="Sample topic",
|
|
192
|
-
custom_instructions="Call the tool with
|
|
196
|
+
custom_instructions="Call the tool with 15 arguments for OPENAI",
|
|
193
197
|
agent_config=config,
|
|
194
198
|
)
|
|
195
|
-
res = agent.chat("What is the stock price?")
|
|
199
|
+
res = agent.chat("What is the stock price for Yahoo on 12/31/22?")
|
|
196
200
|
self.assertIn("maximum length of 1024 characters", str(res))
|
|
197
201
|
|
|
202
|
+
# Same test but with GROQ
|
|
203
|
+
config = AgentConfig(
|
|
204
|
+
agent_type=AgentType.FUNCTION_CALLING,
|
|
205
|
+
main_llm_provider=ModelProvider.GROQ,
|
|
206
|
+
tool_llm_provider=ModelProvider.GROQ,
|
|
207
|
+
)
|
|
208
|
+
agent = Agent(
|
|
209
|
+
tools=[query_tool_1],
|
|
210
|
+
topic="Sample topic",
|
|
211
|
+
custom_instructions="Call the tool with 15 arguments for GROQ",
|
|
212
|
+
agent_config=config,
|
|
213
|
+
)
|
|
214
|
+
res = agent.chat("What is the stock price?")
|
|
215
|
+
self.assertNotIn("maximum length of 1024 characters", str(res))
|
|
216
|
+
|
|
217
|
+
# Same test but with ANTHROPIC
|
|
218
|
+
config = AgentConfig(
|
|
219
|
+
agent_type=AgentType.FUNCTION_CALLING,
|
|
220
|
+
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
221
|
+
tool_llm_provider=ModelProvider.ANTHROPIC,
|
|
222
|
+
)
|
|
223
|
+
agent = Agent(
|
|
224
|
+
tools=[query_tool_1],
|
|
225
|
+
topic="Sample topic",
|
|
226
|
+
custom_instructions="Call the tool with 15 arguments for ANTHROPIC",
|
|
227
|
+
agent_config=config,
|
|
228
|
+
)
|
|
229
|
+
res = agent.chat("What is the stock price?")
|
|
230
|
+
# ANTHROPIC does not have that 1024 limit
|
|
231
|
+
self.assertIn("stock price", str(res))
|
|
232
|
+
|
|
233
|
+
# But using compact_docstring=True, we can pass 15 arguments successfully.
|
|
198
234
|
vec_factory = VectaraToolFactory(
|
|
199
235
|
vectara_corpus_key, vectara_api_key, compact_docstring=True
|
|
200
236
|
)
|
|
@@ -211,7 +247,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
211
247
|
agent = Agent(
|
|
212
248
|
tools=[query_tool_2],
|
|
213
249
|
topic="Sample topic",
|
|
214
|
-
custom_instructions="Call the tool with
|
|
250
|
+
custom_instructions="Call the tool with 15 arguments",
|
|
215
251
|
agent_config=config,
|
|
216
252
|
)
|
|
217
253
|
res = agent.chat("What is the stock price?")
|
|
@@ -227,7 +263,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
227
263
|
tool_name="ask_vectara",
|
|
228
264
|
data_description="data from Vectara website",
|
|
229
265
|
assistant_specialty="RAG as a service",
|
|
230
|
-
vectara_summarizer="mockingbird-
|
|
266
|
+
vectara_summarizer="mockingbird-2.0",
|
|
231
267
|
)
|
|
232
268
|
|
|
233
269
|
self.assertIn(
|
tests/test_vectara_llms.py
CHANGED
|
@@ -15,7 +15,10 @@ vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
|
|
|
15
15
|
class TestLLMPackage(unittest.TestCase):
|
|
16
16
|
|
|
17
17
|
def test_vectara_openai(self):
|
|
18
|
-
vec_factory = VectaraToolFactory(
|
|
18
|
+
vec_factory = VectaraToolFactory(
|
|
19
|
+
vectara_corpus_key=vectara_corpus_key,
|
|
20
|
+
vectara_api_key=vectara_api_key
|
|
21
|
+
)
|
|
19
22
|
|
|
20
23
|
self.assertEqual(vectara_corpus_key, vec_factory.vectara_corpus_key)
|
|
21
24
|
self.assertEqual(vectara_api_key, vec_factory.vectara_api_key)
|
|
@@ -51,17 +54,6 @@ class TestLLMPackage(unittest.TestCase):
|
|
|
51
54
|
|
|
52
55
|
def test_vectara_mockingbird(self):
|
|
53
56
|
vec_factory = VectaraToolFactory(vectara_corpus_key, vectara_api_key)
|
|
54
|
-
|
|
55
|
-
query_tool = vec_factory.create_rag_tool(
|
|
56
|
-
tool_name="rag_tool",
|
|
57
|
-
tool_description="""
|
|
58
|
-
Returns a response (str) to the user query based on the data in this corpus.
|
|
59
|
-
""",
|
|
60
|
-
vectara_summarizer="mockingbird-1.0-2024-07-16",
|
|
61
|
-
)
|
|
62
|
-
res = query_tool(query="What is Vectara?")
|
|
63
|
-
self.assertIn("Vectara is an end-to-end platform", str(res))
|
|
64
|
-
|
|
65
57
|
query_tool = vec_factory.create_rag_tool(
|
|
66
58
|
tool_name="rag_tool",
|
|
67
59
|
tool_description="""
|
|
@@ -12,28 +12,50 @@ def setup_observer(config: AgentConfig, verbose: bool) -> bool:
|
|
|
12
12
|
'''
|
|
13
13
|
Setup the observer.
|
|
14
14
|
'''
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
15
|
+
if config.observer != ObserverType.ARIZE_PHOENIX:
|
|
16
|
+
if verbose:
|
|
17
|
+
print("No Phoenix observer set.")
|
|
18
|
+
return False
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import phoenix as px
|
|
22
|
+
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
|
|
23
|
+
from phoenix.otel import register
|
|
24
|
+
except ImportError:
|
|
25
|
+
print(
|
|
26
|
+
(
|
|
27
|
+
"Phoenix libraries not found. Please install with"
|
|
28
|
+
"'pip install arize-phoenix openinference-instrumentation-llama-index'"
|
|
29
|
+
)
|
|
30
|
+
)
|
|
31
|
+
return False
|
|
32
|
+
|
|
33
|
+
phoenix_endpoint = os.getenv("PHOENIX_ENDPOINT", None)
|
|
34
|
+
if not phoenix_endpoint:
|
|
35
|
+
print("Phoenix endpoint not set. Attempting to launch local Phoenix UI...")
|
|
36
|
+
px.launch_app()
|
|
37
|
+
print("Local Phoenix UI launched. You can view traces at the UI address (usually http://localhost:6006).")
|
|
38
|
+
|
|
39
|
+
if phoenix_endpoint and 'app.phoenix.arize.com' in phoenix_endpoint:
|
|
40
|
+
phoenix_api_key = os.getenv("PHOENIX_API_KEY")
|
|
41
|
+
if not phoenix_api_key:
|
|
42
|
+
raise ValueError(
|
|
43
|
+
"Arize Phoenix API key not set. Please set PHOENIX_API_KEY."
|
|
44
|
+
)
|
|
45
|
+
os.environ["PHOENIX_CLIENT_HEADERS"] = f"api_key={phoenix_api_key}"
|
|
46
|
+
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"
|
|
47
|
+
|
|
48
|
+
reg_kwargs = {
|
|
49
|
+
"endpoint": phoenix_endpoint or 'http://localhost:6006/v1/traces',
|
|
50
|
+
"project_name": "vectara-agentic",
|
|
51
|
+
"batch": True,
|
|
52
|
+
"set_global_tracer_provider": False,
|
|
53
|
+
}
|
|
54
|
+
tracer_provider = register(**reg_kwargs)
|
|
55
|
+
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
|
|
34
56
|
if verbose:
|
|
35
|
-
print("
|
|
36
|
-
return
|
|
57
|
+
print(f"Phoenix observer configured for project 'vectara-agentic' at endpoint: {reg_kwargs['endpoint']}")
|
|
58
|
+
return True
|
|
37
59
|
|
|
38
60
|
|
|
39
61
|
def _extract_fcs_value(output: Union[str, dict]) -> Optional[float]:
|
vectara_agentic/_prompts.py
CHANGED
|
@@ -5,7 +5,7 @@ This file contains the prompt templates for the different types of agents.
|
|
|
5
5
|
# General (shared) instructions
|
|
6
6
|
GENERAL_INSTRUCTIONS = """
|
|
7
7
|
- Use tools as your main source of information, do not respond without using a tool. Do not respond based on pre-trained knowledge.
|
|
8
|
-
- Use the 'get_bad_topics' tool to determine the topics you are not allowed to discuss or respond to.
|
|
8
|
+
- Use the 'get_bad_topics' (if it exists) tool to determine the topics you are not allowed to discuss or respond to.
|
|
9
9
|
- Before responding to a user query that requires knowledge of the current date, call the 'get_current_date' tool to get the current date.
|
|
10
10
|
Never rely on previous knowledge of the current date.
|
|
11
11
|
Example queries that require the current date: "What is the revenue of Apple last october?" or "What was the stock price 5 days ago?".
|
vectara_agentic/_version.py
CHANGED
vectara_agentic/agent.py
CHANGED
|
@@ -12,6 +12,8 @@ import logging
|
|
|
12
12
|
import asyncio
|
|
13
13
|
import importlib
|
|
14
14
|
from collections import Counter
|
|
15
|
+
import inspect
|
|
16
|
+
from inspect import Signature, Parameter, ismethod
|
|
15
17
|
|
|
16
18
|
import cloudpickle as pickle
|
|
17
19
|
|
|
@@ -19,6 +21,7 @@ from dotenv import load_dotenv
|
|
|
19
21
|
|
|
20
22
|
from pydantic import Field, create_model, ValidationError
|
|
21
23
|
|
|
24
|
+
|
|
22
25
|
from llama_index.core.memory import ChatMemoryBuffer
|
|
23
26
|
from llama_index.core.llms import ChatMessage, MessageRole
|
|
24
27
|
from llama_index.core.tools import FunctionTool
|
|
@@ -47,7 +50,7 @@ from .types import (
|
|
|
47
50
|
AgentStreamingResponse,
|
|
48
51
|
AgentConfigType,
|
|
49
52
|
)
|
|
50
|
-
from .
|
|
53
|
+
from .llm_utils import get_llm, get_tokenizer_for_model
|
|
51
54
|
from ._prompts import (
|
|
52
55
|
REACT_PROMPT_TEMPLATE,
|
|
53
56
|
GENERAL_PROMPT_TEMPLATE,
|
|
@@ -230,6 +233,10 @@ class Agent:
|
|
|
230
233
|
self.workflow_cls = workflow_cls
|
|
231
234
|
self.workflow_timeout = workflow_timeout
|
|
232
235
|
|
|
236
|
+
# Sanitize tools for Gemini if needed
|
|
237
|
+
if self.agent_config.main_llm_provider == ModelProvider.GEMINI:
|
|
238
|
+
self.tools = self._sanitize_tools_for_gemini(self.tools)
|
|
239
|
+
|
|
233
240
|
# Validate tools
|
|
234
241
|
# Check for:
|
|
235
242
|
# 1. multiple copies of the same tool
|
|
@@ -241,19 +248,25 @@ class Agent:
|
|
|
241
248
|
|
|
242
249
|
if validate_tools:
|
|
243
250
|
prompt = f"""
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
251
|
+
You are provided these tools:
|
|
252
|
+
<tools>{','.join(tool_names)}</tools>
|
|
253
|
+
And these instructions:
|
|
254
|
+
<instructions>
|
|
247
255
|
{self._custom_instructions}
|
|
248
|
-
|
|
249
|
-
Your
|
|
250
|
-
|
|
256
|
+
</instructions>
|
|
257
|
+
Your task is to identify invalid tools.
|
|
258
|
+
A tool is invalid if it is mentioned in the instructions but not in the tools list.
|
|
259
|
+
A tool's name must have at least two characters.
|
|
260
|
+
Your response should be a comma-separated list of the invalid tools.
|
|
261
|
+
If no invalid tools exist, respond with "<OKAY>".
|
|
251
262
|
"""
|
|
252
263
|
llm = get_llm(LLMRole.MAIN, config=self.agent_config)
|
|
253
|
-
|
|
254
|
-
if
|
|
264
|
+
bad_tools_str = llm.complete(prompt).text
|
|
265
|
+
if bad_tools_str and bad_tools_str != "<OKAY>":
|
|
266
|
+
bad_tools = [tool.strip() for tool in bad_tools_str.split(",")]
|
|
267
|
+
numbered = ", ".join(f"({i}) {tool}" for i, tool in enumerate(bad_tools, 1))
|
|
255
268
|
raise ValueError(
|
|
256
|
-
f"The Agent custom instructions mention these invalid tools: {
|
|
269
|
+
f"The Agent custom instructions mention these invalid tools: {numbered}"
|
|
257
270
|
)
|
|
258
271
|
|
|
259
272
|
# Create token counters for the main and tool LLMs
|
|
@@ -311,6 +324,63 @@ class Agent:
|
|
|
311
324
|
print(f"Failed to set up observer ({e}), ignoring")
|
|
312
325
|
self.observability_enabled = False
|
|
313
326
|
|
|
327
|
+
def _sanitize_tools_for_gemini(
|
|
328
|
+
self, tools: list[FunctionTool]
|
|
329
|
+
) -> list[FunctionTool]:
|
|
330
|
+
"""
|
|
331
|
+
Strip all default values from:
|
|
332
|
+
- tool.fn
|
|
333
|
+
- tool.async_fn
|
|
334
|
+
- tool.metadata.fn_schema
|
|
335
|
+
so Gemini sees *only* required parameters, no defaults.
|
|
336
|
+
"""
|
|
337
|
+
for tool in tools:
|
|
338
|
+
# 1) strip defaults off the actual callables
|
|
339
|
+
for func in (tool.fn, tool.async_fn):
|
|
340
|
+
if not func:
|
|
341
|
+
continue
|
|
342
|
+
orig_sig = inspect.signature(func)
|
|
343
|
+
new_params = [
|
|
344
|
+
p.replace(default=Parameter.empty)
|
|
345
|
+
for p in orig_sig.parameters.values()
|
|
346
|
+
]
|
|
347
|
+
new_sig = Signature(
|
|
348
|
+
new_params, return_annotation=orig_sig.return_annotation
|
|
349
|
+
)
|
|
350
|
+
if ismethod(func):
|
|
351
|
+
func.__func__.__signature__ = new_sig
|
|
352
|
+
else:
|
|
353
|
+
func.__signature__ = new_sig
|
|
354
|
+
|
|
355
|
+
# 2) rebuild the Pydantic schema so that *every* field is required
|
|
356
|
+
schema_cls = getattr(tool.metadata, "fn_schema", None)
|
|
357
|
+
if schema_cls and hasattr(schema_cls, "model_fields"):
|
|
358
|
+
# collect (name → (type, Field(...))) for all fields
|
|
359
|
+
new_fields: dict[str, tuple[type, Any]] = {}
|
|
360
|
+
for name, mf in schema_cls.model_fields.items():
|
|
361
|
+
typ = mf.annotation
|
|
362
|
+
desc = getattr(mf, "description", "")
|
|
363
|
+
# force required (no default) with Field(...)
|
|
364
|
+
new_fields[name] = (typ, Field(..., description=desc))
|
|
365
|
+
|
|
366
|
+
# make a brand-new schema class where every field is required
|
|
367
|
+
no_default_schema = create_model(
|
|
368
|
+
f"{schema_cls.__name__}", # new class name
|
|
369
|
+
**new_fields, # type: ignore
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
# give it a clean __signature__ so inspect.signature sees no defaults
|
|
373
|
+
params = [
|
|
374
|
+
Parameter(n, Parameter.POSITIONAL_OR_KEYWORD, annotation=typ)
|
|
375
|
+
for n, (typ, _) in new_fields.items()
|
|
376
|
+
]
|
|
377
|
+
no_default_schema.__signature__ = Signature(params)
|
|
378
|
+
|
|
379
|
+
# swap it back onto the tool
|
|
380
|
+
tool.metadata.fn_schema = no_default_schema
|
|
381
|
+
|
|
382
|
+
return tools
|
|
383
|
+
|
|
314
384
|
def _create_agent(
|
|
315
385
|
self, config: AgentConfig, llm_callback_manager: CallbackManager
|
|
316
386
|
) -> Union[BaseAgent, AgentRunner]:
|
|
@@ -625,6 +695,7 @@ class Agent:
|
|
|
625
695
|
vectara_frequency_penalty: Optional[float] = None,
|
|
626
696
|
vectara_presence_penalty: Optional[float] = None,
|
|
627
697
|
vectara_save_history: bool = True,
|
|
698
|
+
return_direct: bool = False,
|
|
628
699
|
) -> "Agent":
|
|
629
700
|
"""
|
|
630
701
|
Create an agent from a single Vectara corpus
|
|
@@ -674,6 +745,7 @@ class Agent:
|
|
|
674
745
|
vectara_presence_penalty (float, optional): How much to penalize repeating tokens in the response,
|
|
675
746
|
higher values increasing the diversity of topics.
|
|
676
747
|
vectara_save_history (bool, optional): Whether to save the query in history.
|
|
748
|
+
return_direct (bool, optional): Whether the agent should return the tool's response directly.
|
|
677
749
|
|
|
678
750
|
Returns:
|
|
679
751
|
Agent: An instance of the Agent class.
|
|
@@ -727,6 +799,7 @@ class Agent:
|
|
|
727
799
|
save_history=vectara_save_history,
|
|
728
800
|
include_citations=True,
|
|
729
801
|
verbose=verbose,
|
|
802
|
+
return_direct=return_direct,
|
|
730
803
|
)
|
|
731
804
|
|
|
732
805
|
assistant_instructions = f"""
|