vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/__init__.py +7 -0
- tests/conftest.py +316 -0
- tests/endpoint.py +54 -17
- tests/run_tests.py +112 -0
- tests/test_agent.py +35 -33
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +86 -143
- tests/test_api_endpoint.py +4 -0
- tests/test_bedrock.py +50 -31
- tests/test_fallback.py +4 -0
- tests/test_gemini.py +27 -59
- tests/test_groq.py +50 -31
- tests/test_private_llm.py +11 -2
- tests/test_return_direct.py +6 -2
- tests/test_serialization.py +7 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +109 -0
- tests/test_together.py +62 -0
- tests/test_tools.py +10 -82
- tests/test_vectara_llms.py +4 -0
- tests/test_vhc.py +67 -0
- tests/test_workflow.py +13 -28
- vectara_agentic/__init__.py +27 -4
- vectara_agentic/_callback.py +65 -67
- vectara_agentic/_observability.py +30 -30
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +565 -859
- vectara_agentic/agent_config.py +15 -14
- vectara_agentic/agent_core/__init__.py +22 -0
- vectara_agentic/agent_core/factory.py +383 -0
- vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
- vectara_agentic/agent_core/serialization.py +348 -0
- vectara_agentic/agent_core/streaming.py +483 -0
- vectara_agentic/agent_core/utils/__init__.py +29 -0
- vectara_agentic/agent_core/utils/hallucination.py +157 -0
- vectara_agentic/agent_core/utils/logging.py +52 -0
- vectara_agentic/agent_core/utils/schemas.py +87 -0
- vectara_agentic/agent_core/utils/tools.py +125 -0
- vectara_agentic/agent_endpoint.py +4 -6
- vectara_agentic/db_tools.py +37 -12
- vectara_agentic/llm_utils.py +42 -43
- vectara_agentic/sub_query_workflow.py +9 -14
- vectara_agentic/tool_utils.py +138 -83
- vectara_agentic/tools.py +36 -21
- vectara_agentic/tools_catalog.py +16 -16
- vectara_agentic/types.py +106 -8
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
- vectara_agentic-0.4.1.dist-info/RECORD +53 -0
- tests/test_agent_planning.py +0 -64
- tests/test_hhem.py +0 -100
- vectara_agentic/hhem.py +0 -82
- vectara_agentic-0.3.3.dist-info/RECORD +0 -39
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_serialization.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
|
+
|
|
1
5
|
import unittest
|
|
2
6
|
import threading
|
|
3
7
|
import os
|
|
@@ -10,8 +14,7 @@ from vectara_agentic.tools import ToolsFactory
|
|
|
10
14
|
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
11
15
|
from sqlalchemy import create_engine
|
|
12
16
|
|
|
13
|
-
|
|
14
|
-
return x * y
|
|
17
|
+
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
18
|
|
|
16
19
|
|
|
17
20
|
ARIZE_LOCK = threading.Lock()
|
|
@@ -40,12 +43,10 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
40
43
|
)
|
|
41
44
|
|
|
42
45
|
tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
|
|
43
|
-
topic = "AI topic"
|
|
44
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
45
46
|
agent = Agent(
|
|
46
47
|
tools=tools,
|
|
47
|
-
topic=
|
|
48
|
-
custom_instructions=
|
|
48
|
+
topic=STANDARD_TEST_TOPIC,
|
|
49
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
49
50
|
agent_config=config
|
|
50
51
|
)
|
|
51
52
|
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
import threading
|
|
8
|
+
from datetime import date
|
|
9
|
+
|
|
10
|
+
from vectara_agentic.agent import Agent
|
|
11
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
12
|
+
from vectara_agentic.types import ModelProvider
|
|
13
|
+
from vectara_agentic.tools import ToolsFactory
|
|
14
|
+
from llama_index.core.llms import ChatMessage, MessageRole
|
|
15
|
+
from conftest import mult, add
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
ARIZE_LOCK = threading.Lock()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TestSessionMemoryManagement(unittest.TestCase):
|
|
22
|
+
"""Test session_id parameter and memory management functionality"""
|
|
23
|
+
|
|
24
|
+
def setUp(self):
|
|
25
|
+
"""Set up test fixtures"""
|
|
26
|
+
self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
|
|
27
|
+
self.topic = "Mathematics"
|
|
28
|
+
self.custom_instructions = "You are a helpful math assistant."
|
|
29
|
+
self.config = AgentConfig(main_llm_provider=ModelProvider.ANTHROPIC)
|
|
30
|
+
|
|
31
|
+
def test_agent_init_with_session_id(self):
|
|
32
|
+
"""Test Agent initialization with custom session_id"""
|
|
33
|
+
custom_session_id = "test-session-123"
|
|
34
|
+
|
|
35
|
+
agent = Agent(
|
|
36
|
+
tools=self.tools,
|
|
37
|
+
topic=self.topic,
|
|
38
|
+
custom_instructions=self.custom_instructions,
|
|
39
|
+
agent_config=self.config,
|
|
40
|
+
session_id=custom_session_id,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Verify the agent uses the provided session_id
|
|
44
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
45
|
+
|
|
46
|
+
# Verify memory uses the same session_id
|
|
47
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
48
|
+
|
|
49
|
+
def test_agent_init_without_session_id(self):
|
|
50
|
+
"""Test Agent initialization without session_id (auto-generation)"""
|
|
51
|
+
agent = Agent(
|
|
52
|
+
tools=self.tools,
|
|
53
|
+
topic=self.topic,
|
|
54
|
+
custom_instructions=self.custom_instructions,
|
|
55
|
+
agent_config=self.config,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Verify auto-generated session_id follows expected pattern
|
|
59
|
+
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
60
|
+
self.assertEqual(agent.session_id, expected_pattern)
|
|
61
|
+
|
|
62
|
+
# Verify memory uses the same session_id
|
|
63
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
64
|
+
|
|
65
|
+
def test_from_tools_with_session_id(self):
|
|
66
|
+
"""Test Agent.from_tools() with custom session_id"""
|
|
67
|
+
custom_session_id = "from-tools-session-456"
|
|
68
|
+
|
|
69
|
+
agent = Agent.from_tools(
|
|
70
|
+
tools=self.tools,
|
|
71
|
+
topic=self.topic,
|
|
72
|
+
custom_instructions=self.custom_instructions,
|
|
73
|
+
agent_config=self.config,
|
|
74
|
+
session_id=custom_session_id,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
# Verify the agent uses the provided session_id
|
|
78
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
79
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
80
|
+
|
|
81
|
+
def test_from_tools_without_session_id(self):
|
|
82
|
+
"""Test Agent.from_tools() without session_id (auto-generation)"""
|
|
83
|
+
agent = Agent.from_tools(
|
|
84
|
+
tools=self.tools,
|
|
85
|
+
topic=self.topic,
|
|
86
|
+
custom_instructions=self.custom_instructions,
|
|
87
|
+
agent_config=self.config,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Verify auto-generated session_id
|
|
91
|
+
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
92
|
+
self.assertEqual(agent.session_id, expected_pattern)
|
|
93
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
94
|
+
|
|
95
|
+
def test_session_id_consistency_across_agents(self):
|
|
96
|
+
"""Test that agents with same session_id have consistent session_id attributes"""
|
|
97
|
+
shared_session_id = "shared-session-id-test"
|
|
98
|
+
|
|
99
|
+
# Create two agents with the same session_id
|
|
100
|
+
agent1 = Agent(
|
|
101
|
+
tools=self.tools,
|
|
102
|
+
topic=self.topic,
|
|
103
|
+
custom_instructions=self.custom_instructions,
|
|
104
|
+
agent_config=self.config,
|
|
105
|
+
session_id=shared_session_id,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
agent2 = Agent(
|
|
109
|
+
tools=self.tools,
|
|
110
|
+
topic=self.topic,
|
|
111
|
+
custom_instructions=self.custom_instructions,
|
|
112
|
+
agent_config=self.config,
|
|
113
|
+
session_id=shared_session_id,
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
# Verify both agents have the same session_id
|
|
117
|
+
self.assertEqual(agent1.session_id, shared_session_id)
|
|
118
|
+
self.assertEqual(agent2.session_id, shared_session_id)
|
|
119
|
+
self.assertEqual(agent1.session_id, agent2.session_id)
|
|
120
|
+
|
|
121
|
+
# Verify their memory instances also have the correct session_id
|
|
122
|
+
self.assertEqual(agent1.memory.session_id, shared_session_id)
|
|
123
|
+
self.assertEqual(agent2.memory.session_id, shared_session_id)
|
|
124
|
+
|
|
125
|
+
# Note: Each agent gets its own Memory instance (this is expected behavior)
|
|
126
|
+
# In production, memory persistence happens through serialization/deserialization
|
|
127
|
+
|
|
128
|
+
def test_memory_isolation_different_sessions(self):
|
|
129
|
+
"""Test that agents with different session_ids have isolated memory"""
|
|
130
|
+
session_id_1 = "isolated-session-1"
|
|
131
|
+
session_id_2 = "isolated-session-2"
|
|
132
|
+
|
|
133
|
+
# Create two agents with different session_ids
|
|
134
|
+
agent1 = Agent(
|
|
135
|
+
tools=self.tools,
|
|
136
|
+
topic=self.topic,
|
|
137
|
+
custom_instructions=self.custom_instructions,
|
|
138
|
+
agent_config=self.config,
|
|
139
|
+
session_id=session_id_1,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
agent2 = Agent(
|
|
143
|
+
tools=self.tools,
|
|
144
|
+
topic=self.topic,
|
|
145
|
+
custom_instructions=self.custom_instructions,
|
|
146
|
+
agent_config=self.config,
|
|
147
|
+
session_id=session_id_2,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
# Add messages to agent1's memory
|
|
151
|
+
agent1_messages = [
|
|
152
|
+
ChatMessage(role=MessageRole.USER, content="Agent 1 question"),
|
|
153
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Agent 1 response"),
|
|
154
|
+
]
|
|
155
|
+
agent1.memory.put_messages(agent1_messages)
|
|
156
|
+
|
|
157
|
+
# Add different messages to agent2's memory
|
|
158
|
+
agent2_messages = [
|
|
159
|
+
ChatMessage(role=MessageRole.USER, content="Agent 2 question"),
|
|
160
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Agent 2 response"),
|
|
161
|
+
]
|
|
162
|
+
agent2.memory.put_messages(agent2_messages)
|
|
163
|
+
|
|
164
|
+
# Verify memory isolation
|
|
165
|
+
retrieved_agent1_messages = agent1.memory.get()
|
|
166
|
+
retrieved_agent2_messages = agent2.memory.get()
|
|
167
|
+
|
|
168
|
+
self.assertEqual(len(retrieved_agent1_messages), 2)
|
|
169
|
+
self.assertEqual(len(retrieved_agent2_messages), 2)
|
|
170
|
+
|
|
171
|
+
# Verify agent1 only has its own messages
|
|
172
|
+
self.assertEqual(retrieved_agent1_messages[0].content, "Agent 1 question")
|
|
173
|
+
self.assertEqual(retrieved_agent1_messages[1].content, "Agent 1 response")
|
|
174
|
+
|
|
175
|
+
# Verify agent2 only has its own messages
|
|
176
|
+
self.assertEqual(retrieved_agent2_messages[0].content, "Agent 2 question")
|
|
177
|
+
self.assertEqual(retrieved_agent2_messages[1].content, "Agent 2 response")
|
|
178
|
+
|
|
179
|
+
def test_serialization_preserves_session_id(self):
|
|
180
|
+
"""Test that agent serialization preserves custom session_id"""
|
|
181
|
+
custom_session_id = "serialization-test-session"
|
|
182
|
+
|
|
183
|
+
# Create agent with custom session_id
|
|
184
|
+
original_agent = Agent(
|
|
185
|
+
tools=self.tools,
|
|
186
|
+
topic=self.topic,
|
|
187
|
+
custom_instructions=self.custom_instructions,
|
|
188
|
+
agent_config=self.config,
|
|
189
|
+
session_id=custom_session_id,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
# Add some memory
|
|
193
|
+
test_messages = [
|
|
194
|
+
ChatMessage(role=MessageRole.USER, content="Test question"),
|
|
195
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Test answer"),
|
|
196
|
+
]
|
|
197
|
+
original_agent.memory.put_messages(test_messages)
|
|
198
|
+
|
|
199
|
+
# Serialize the agent
|
|
200
|
+
serialized_data = original_agent.dumps()
|
|
201
|
+
|
|
202
|
+
# Deserialize the agent
|
|
203
|
+
restored_agent = Agent.loads(serialized_data)
|
|
204
|
+
|
|
205
|
+
# Verify session_id is preserved
|
|
206
|
+
self.assertEqual(restored_agent.session_id, custom_session_id)
|
|
207
|
+
self.assertEqual(restored_agent.memory.session_id, custom_session_id)
|
|
208
|
+
|
|
209
|
+
# Verify memory is preserved
|
|
210
|
+
restored_messages = restored_agent.memory.get()
|
|
211
|
+
self.assertEqual(len(restored_messages), 2)
|
|
212
|
+
self.assertEqual(restored_messages[0].content, "Test question")
|
|
213
|
+
self.assertEqual(restored_messages[1].content, "Test answer")
|
|
214
|
+
|
|
215
|
+
def test_chat_history_initialization_with_session_id(self):
|
|
216
|
+
"""Test Agent initialization with chat_history and custom session_id"""
|
|
217
|
+
custom_session_id = "chat-history-session"
|
|
218
|
+
chat_history = [
|
|
219
|
+
("Hello", "Hi there!"),
|
|
220
|
+
("How are you?", "I'm doing well, thank you!"),
|
|
221
|
+
]
|
|
222
|
+
|
|
223
|
+
agent = Agent(
|
|
224
|
+
tools=self.tools,
|
|
225
|
+
topic=self.topic,
|
|
226
|
+
custom_instructions=self.custom_instructions,
|
|
227
|
+
agent_config=self.config,
|
|
228
|
+
session_id=custom_session_id,
|
|
229
|
+
chat_history=chat_history,
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Verify session_id is correct
|
|
233
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
234
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
235
|
+
|
|
236
|
+
# Verify chat history was loaded into memory
|
|
237
|
+
messages = agent.memory.get()
|
|
238
|
+
self.assertEqual(len(messages), 4) # 2 user + 2 assistant messages
|
|
239
|
+
|
|
240
|
+
# Verify message content and roles
|
|
241
|
+
self.assertEqual(messages[0].role, MessageRole.USER)
|
|
242
|
+
self.assertEqual(messages[0].content, "Hello")
|
|
243
|
+
self.assertEqual(messages[1].role, MessageRole.ASSISTANT)
|
|
244
|
+
self.assertEqual(messages[1].content, "Hi there!")
|
|
245
|
+
self.assertEqual(messages[2].role, MessageRole.USER)
|
|
246
|
+
self.assertEqual(messages[2].content, "How are you?")
|
|
247
|
+
self.assertEqual(messages[3].role, MessageRole.ASSISTANT)
|
|
248
|
+
self.assertEqual(messages[3].content, "I'm doing well, thank you!")
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
if __name__ == "__main__":
|
|
252
|
+
unittest.main()
|
tests/test_streaming.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
import asyncio
|
|
8
|
+
|
|
9
|
+
from vectara_agentic.agent import Agent
|
|
10
|
+
from vectara_agentic.tools import ToolsFactory
|
|
11
|
+
|
|
12
|
+
import nest_asyncio
|
|
13
|
+
|
|
14
|
+
nest_asyncio.apply()
|
|
15
|
+
|
|
16
|
+
from conftest import (
|
|
17
|
+
fc_config_openai,
|
|
18
|
+
fc_config_anthropic,
|
|
19
|
+
mult,
|
|
20
|
+
STANDARD_TEST_TOPIC,
|
|
21
|
+
STANDARD_TEST_INSTRUCTIONS,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestAgentStreaming(unittest.IsolatedAsyncioTestCase):
|
|
26
|
+
|
|
27
|
+
async def test_anthropic(self):
|
|
28
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
29
|
+
agent = Agent(
|
|
30
|
+
agent_config=fc_config_anthropic,
|
|
31
|
+
tools=tools,
|
|
32
|
+
topic=STANDARD_TEST_TOPIC,
|
|
33
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# First calculation: 5 * 10 = 50
|
|
37
|
+
stream1 = await agent.astream_chat(
|
|
38
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
39
|
+
)
|
|
40
|
+
# Consume the stream
|
|
41
|
+
async for chunk in stream1.async_response_gen():
|
|
42
|
+
pass
|
|
43
|
+
_ = await stream1.aget_response()
|
|
44
|
+
|
|
45
|
+
# Second calculation: 3 * 7 = 21
|
|
46
|
+
stream2 = await agent.astream_chat(
|
|
47
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
48
|
+
)
|
|
49
|
+
# Consume the stream
|
|
50
|
+
async for chunk in stream2.async_response_gen():
|
|
51
|
+
pass
|
|
52
|
+
_ = await stream2.aget_response()
|
|
53
|
+
|
|
54
|
+
# Final calculation: 50 * 21 = 1050
|
|
55
|
+
stream3 = await agent.astream_chat(
|
|
56
|
+
"multiply the results of the last two multiplications. Only give the answer, nothing else."
|
|
57
|
+
)
|
|
58
|
+
# Consume the stream
|
|
59
|
+
async for chunk in stream3.async_response_gen():
|
|
60
|
+
pass
|
|
61
|
+
response3 = await stream3.aget_response()
|
|
62
|
+
|
|
63
|
+
self.assertIn("1050", response3.response)
|
|
64
|
+
|
|
65
|
+
async def test_openai(self):
|
|
66
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
67
|
+
agent = Agent(
|
|
68
|
+
agent_config=fc_config_openai,
|
|
69
|
+
tools=tools,
|
|
70
|
+
topic=STANDARD_TEST_TOPIC,
|
|
71
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
# First calculation: 5 * 10 = 50
|
|
75
|
+
stream1 = await agent.astream_chat(
|
|
76
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
77
|
+
)
|
|
78
|
+
# Consume the stream
|
|
79
|
+
async for chunk in stream1.async_response_gen():
|
|
80
|
+
pass
|
|
81
|
+
_ = await stream1.aget_response()
|
|
82
|
+
|
|
83
|
+
# Second calculation: 3 * 7 = 21
|
|
84
|
+
stream2 = await agent.astream_chat(
|
|
85
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
86
|
+
)
|
|
87
|
+
# Consume the stream
|
|
88
|
+
async for chunk in stream2.async_response_gen():
|
|
89
|
+
pass
|
|
90
|
+
_ = await stream2.aget_response()
|
|
91
|
+
|
|
92
|
+
# Final calculation: 50 * 21 = 1050
|
|
93
|
+
stream3 = await agent.astream_chat(
|
|
94
|
+
"multiply the results of the last two multiplications. Only give the answer, nothing else."
|
|
95
|
+
)
|
|
96
|
+
# Consume the stream
|
|
97
|
+
async for chunk in stream3.async_response_gen():
|
|
98
|
+
pass
|
|
99
|
+
response3 = await stream3.aget_response()
|
|
100
|
+
|
|
101
|
+
self.assertIn("1050", response3.response)
|
|
102
|
+
|
|
103
|
+
def test_openai_sync(self):
|
|
104
|
+
"""Synchronous wrapper for the async test"""
|
|
105
|
+
asyncio.run(self.test_openai())
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
if __name__ == "__main__":
|
|
109
|
+
unittest.main()
|
tests/test_together.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
import threading
|
|
7
|
+
|
|
8
|
+
from vectara_agentic.agent import Agent
|
|
9
|
+
from vectara_agentic.tools import ToolsFactory
|
|
10
|
+
|
|
11
|
+
import nest_asyncio
|
|
12
|
+
nest_asyncio.apply()
|
|
13
|
+
|
|
14
|
+
from conftest import fc_config_together, mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
ARIZE_LOCK = threading.Lock()
|
|
18
|
+
|
|
19
|
+
class TestTogether(unittest.IsolatedAsyncioTestCase):
|
|
20
|
+
|
|
21
|
+
async def test_multiturn(self):
|
|
22
|
+
with ARIZE_LOCK:
|
|
23
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
24
|
+
agent = Agent(
|
|
25
|
+
agent_config=fc_config_together,
|
|
26
|
+
tools=tools,
|
|
27
|
+
topic=STANDARD_TEST_TOPIC,
|
|
28
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# First calculation: 5 * 10 = 50
|
|
32
|
+
stream1 = await agent.astream_chat(
|
|
33
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
34
|
+
)
|
|
35
|
+
# Consume the stream
|
|
36
|
+
async for chunk in stream1.async_response_gen():
|
|
37
|
+
pass
|
|
38
|
+
_ = await stream1.aget_response()
|
|
39
|
+
|
|
40
|
+
# Second calculation: 3 * 7 = 21
|
|
41
|
+
stream2 = await agent.astream_chat(
|
|
42
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
43
|
+
)
|
|
44
|
+
# Consume the stream
|
|
45
|
+
async for chunk in stream2.async_response_gen():
|
|
46
|
+
pass
|
|
47
|
+
_ = await stream2.aget_response()
|
|
48
|
+
|
|
49
|
+
# Final calculation: 50 * 21 = 1050
|
|
50
|
+
stream3 = await agent.astream_chat(
|
|
51
|
+
"multiply the results of the last two questions. Output only the answer."
|
|
52
|
+
)
|
|
53
|
+
# Consume the stream
|
|
54
|
+
async for chunk in stream3.async_response_gen():
|
|
55
|
+
pass
|
|
56
|
+
response3 = await stream3.aget_response()
|
|
57
|
+
|
|
58
|
+
self.assertEqual(response3.response, "1050")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
if __name__ == "__main__":
|
|
62
|
+
unittest.main()
|
tests/test_tools.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
1
6
|
import unittest
|
|
2
7
|
from pydantic import Field, BaseModel
|
|
3
8
|
from unittest.mock import patch, MagicMock
|
|
@@ -13,7 +18,6 @@ from vectara_agentic.tools import (
|
|
|
13
18
|
)
|
|
14
19
|
from vectara_agentic.agent import Agent
|
|
15
20
|
from vectara_agentic.agent_config import AgentConfig
|
|
16
|
-
from vectara_agentic.types import AgentType, ModelProvider
|
|
17
21
|
|
|
18
22
|
from llama_index.core.tools import FunctionTool
|
|
19
23
|
|
|
@@ -96,7 +100,6 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
96
100
|
examples=["AAPL", "GOOG"],
|
|
97
101
|
)
|
|
98
102
|
year: Optional[int | str] = Field(
|
|
99
|
-
default=None,
|
|
100
103
|
description="The year this query relates to. An integer between 2015 and 2024 or a string specifying a condition on the year",
|
|
101
104
|
examples=[
|
|
102
105
|
2020,
|
|
@@ -154,8 +157,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
154
157
|
description="The ticker symbol for the company",
|
|
155
158
|
examples=["AAPL", "GOOG"],
|
|
156
159
|
)
|
|
157
|
-
year: int | str = Field(
|
|
158
|
-
default=None,
|
|
160
|
+
year: Optional[int | str] = Field(
|
|
159
161
|
description="The year this query relates to. An integer between 2015 and 2024 or a string specifying a condition on the year",
|
|
160
162
|
examples=[
|
|
161
163
|
2020,
|
|
@@ -235,80 +237,6 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
235
237
|
self.assertIsInstance(tool, FunctionTool)
|
|
236
238
|
self.assertEqual(tool.metadata.tool_type, ToolType.QUERY)
|
|
237
239
|
|
|
238
|
-
def test_tool_with_many_arguments(self):
|
|
239
|
-
vec_factory = VectaraToolFactory(vectara_corpus_key, vectara_api_key)
|
|
240
|
-
|
|
241
|
-
class QueryToolArgs(BaseModel):
|
|
242
|
-
arg1: str = Field(description="the first argument", examples=["val1"])
|
|
243
|
-
arg2: str = Field(description="the second argument", examples=["val2"])
|
|
244
|
-
arg3: str = Field(description="the third argument", examples=["val3"])
|
|
245
|
-
arg4: str = Field(description="the fourth argument", examples=["val4"])
|
|
246
|
-
arg5: str = Field(description="the fifth argument", examples=["val5"])
|
|
247
|
-
arg6: str = Field(description="the sixth argument", examples=["val6"])
|
|
248
|
-
arg7: str = Field(description="the seventh argument", examples=["val7"])
|
|
249
|
-
arg8: str = Field(description="the eighth argument", examples=["val8"])
|
|
250
|
-
arg9: str = Field(description="the ninth argument", examples=["val9"])
|
|
251
|
-
arg10: str = Field(description="the tenth argument", examples=["val10"])
|
|
252
|
-
arg11: str = Field(description="the eleventh argument", examples=["val11"])
|
|
253
|
-
arg12: str = Field(description="the twelfth argument", examples=["val12"])
|
|
254
|
-
arg13: str = Field(
|
|
255
|
-
description="the thirteenth argument", examples=["val13"]
|
|
256
|
-
)
|
|
257
|
-
arg14: str = Field(
|
|
258
|
-
description="the fourteenth argument", examples=["val14"]
|
|
259
|
-
)
|
|
260
|
-
arg15: str = Field(description="the fifteenth argument", examples=["val15"])
|
|
261
|
-
|
|
262
|
-
query_tool_1 = vec_factory.create_rag_tool(
|
|
263
|
-
tool_name="rag_tool",
|
|
264
|
-
tool_description="""
|
|
265
|
-
A dummy tool that takes 15 arguments and returns a response (str) to the user query based on the data in this corpus.
|
|
266
|
-
We are using this tool to test the tool factory works and does not crash with OpenAI.
|
|
267
|
-
""",
|
|
268
|
-
tool_args_schema=QueryToolArgs,
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
# Test with 15 arguments to make sure no issues occur
|
|
272
|
-
config = AgentConfig(agent_type=AgentType.OPENAI)
|
|
273
|
-
agent = Agent(
|
|
274
|
-
tools=[query_tool_1],
|
|
275
|
-
topic="Sample topic",
|
|
276
|
-
custom_instructions="Call the tool with 15 arguments for OPENAI",
|
|
277
|
-
agent_config=config,
|
|
278
|
-
)
|
|
279
|
-
res = agent.chat("What is the stock price for Yahoo on 12/31/22?")
|
|
280
|
-
self.assertNotIn("maximum length of 1024 characters", str(res))
|
|
281
|
-
|
|
282
|
-
# Same test but with GROQ, should not have this limit
|
|
283
|
-
config = AgentConfig(
|
|
284
|
-
agent_type=AgentType.FUNCTION_CALLING,
|
|
285
|
-
main_llm_provider=ModelProvider.GROQ,
|
|
286
|
-
tool_llm_provider=ModelProvider.GROQ,
|
|
287
|
-
)
|
|
288
|
-
agent = Agent(
|
|
289
|
-
tools=[query_tool_1],
|
|
290
|
-
topic="Sample topic",
|
|
291
|
-
custom_instructions="Call the tool with 15 arguments for GROQ",
|
|
292
|
-
agent_config=config,
|
|
293
|
-
)
|
|
294
|
-
res = agent.chat("What is the stock price?")
|
|
295
|
-
self.assertNotIn("maximum length of 1024 characters", str(res))
|
|
296
|
-
|
|
297
|
-
# Same test but with ANTHROPIC, should not have this limit
|
|
298
|
-
config = AgentConfig(
|
|
299
|
-
agent_type=AgentType.FUNCTION_CALLING,
|
|
300
|
-
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
301
|
-
tool_llm_provider=ModelProvider.ANTHROPIC,
|
|
302
|
-
)
|
|
303
|
-
agent = Agent(
|
|
304
|
-
tools=[query_tool_1],
|
|
305
|
-
topic="Sample topic",
|
|
306
|
-
custom_instructions="Call the tool with 15 arguments for ANTHROPIC",
|
|
307
|
-
agent_config=config,
|
|
308
|
-
)
|
|
309
|
-
res = agent.chat("What is the stock price?")
|
|
310
|
-
self.assertIn("stock price", str(res))
|
|
311
|
-
|
|
312
240
|
@patch.object(VectaraIndex, "as_query_engine")
|
|
313
241
|
def test_vectara_tool_args_type(
|
|
314
242
|
self,
|
|
@@ -384,7 +312,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
384
312
|
def __init__(self):
|
|
385
313
|
pass
|
|
386
314
|
|
|
387
|
-
def mult(self, x, y):
|
|
315
|
+
def mult(self, x: float, y: float) -> float:
|
|
388
316
|
return x * y
|
|
389
317
|
|
|
390
318
|
test_class = TestClass()
|
|
@@ -410,7 +338,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
410
338
|
class DummyArgs(BaseModel):
|
|
411
339
|
foo: int = Field(..., description="how many foos", examples=[1, 2, 3])
|
|
412
340
|
bar: str = Field(
|
|
413
|
-
"baz",
|
|
341
|
+
default="baz",
|
|
414
342
|
description="what bar to use",
|
|
415
343
|
examples=["x", "y"],
|
|
416
344
|
)
|
|
@@ -425,7 +353,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
425
353
|
doc = dummy_tool.metadata.description
|
|
426
354
|
self.assertTrue(
|
|
427
355
|
doc.startswith(
|
|
428
|
-
"dummy_tool(query: str, foo: int, bar: str) -> dict[str, Any]"
|
|
356
|
+
"dummy_tool(query: str, foo: int, bar: str | None) -> dict[str, Any]"
|
|
429
357
|
)
|
|
430
358
|
)
|
|
431
359
|
self.assertIn("Args:", doc)
|
|
@@ -433,7 +361,7 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
433
361
|
"query (str): The search query to perform, in the form of a question", doc
|
|
434
362
|
)
|
|
435
363
|
self.assertIn("foo (int): how many foos (e.g., 1, 2, 3)", doc)
|
|
436
|
-
self.assertIn("bar (str, default='baz'): what bar to use (e.g., 'x', 'y')", doc)
|
|
364
|
+
self.assertIn("bar (str | None, default='baz'): what bar to use (e.g., 'x', 'y')", doc)
|
|
437
365
|
self.assertIn("Returns:", doc)
|
|
438
366
|
self.assertIn("dict[str, Any]: A dictionary containing the result data.", doc)
|
|
439
367
|
|
tests/test_vectara_llms.py
CHANGED
tests/test_vhc.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
|
+
|
|
5
|
+
import unittest
|
|
6
|
+
|
|
7
|
+
from vectara_agentic.agent import Agent, AgentType
|
|
8
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
9
|
+
from vectara_agentic.tools import ToolsFactory
|
|
10
|
+
from vectara_agentic.types import ModelProvider
|
|
11
|
+
|
|
12
|
+
import nest_asyncio
|
|
13
|
+
nest_asyncio.apply()
|
|
14
|
+
|
|
15
|
+
statements = [
|
|
16
|
+
"The sky is blue.",
|
|
17
|
+
"Cats are better than dogs.",
|
|
18
|
+
"Python is a great programming language.",
|
|
19
|
+
"The Earth revolves around the Sun.",
|
|
20
|
+
"Chocolate is the best ice cream flavor.",
|
|
21
|
+
]
|
|
22
|
+
st_inx = 0
|
|
23
|
+
def get_statement() -> str:
|
|
24
|
+
"Generate next statement"
|
|
25
|
+
global st_inx
|
|
26
|
+
st = statements[st_inx]
|
|
27
|
+
st_inx += 1
|
|
28
|
+
return st
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
fc_config = AgentConfig(
|
|
32
|
+
agent_type=AgentType.FUNCTION_CALLING,
|
|
33
|
+
main_llm_provider=ModelProvider.OPENAI,
|
|
34
|
+
tool_llm_provider=ModelProvider.OPENAI,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
vectara_api_key = 'zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA'
|
|
38
|
+
|
|
39
|
+
class TestVHC(unittest.TestCase):
|
|
40
|
+
|
|
41
|
+
def test_vhc(self):
|
|
42
|
+
tools = [ToolsFactory().create_tool(get_statement)]
|
|
43
|
+
topic = "statements"
|
|
44
|
+
instructions = (
|
|
45
|
+
f"Call the get_statement tool multiple times to get all {len(statements)} statements."
|
|
46
|
+
f"Respond to the user question based exclusively on the statements you receive - do not use any other knowledge or information."
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
agent = Agent(
|
|
50
|
+
tools=tools,
|
|
51
|
+
topic=topic,
|
|
52
|
+
agent_config=fc_config,
|
|
53
|
+
custom_instructions=instructions,
|
|
54
|
+
vectara_api_key=vectara_api_key,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
_ = agent.chat("Are large cats better than small dogs?")
|
|
58
|
+
vhc_res = agent.compute_vhc()
|
|
59
|
+
vhc_corrections = vhc_res.get("corrections", [])
|
|
60
|
+
self.assertTrue(
|
|
61
|
+
len(vhc_corrections) >= 0 and len(vhc_corrections) <= 2,
|
|
62
|
+
"Corrections should be between 0 and 2"
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
if __name__ == "__main__":
|
|
67
|
+
unittest.main()
|