vectara-agentic 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/benchmark_models.py +945 -0
- tests/conftest.py +9 -5
- tests/run_tests.py +3 -0
- tests/test_agent.py +57 -29
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +4 -0
- tests/test_bedrock.py +46 -31
- tests/test_fallback.py +1 -1
- tests/test_gemini.py +7 -22
- tests/test_groq.py +46 -31
- tests/test_private_llm.py +1 -1
- tests/test_serialization.py +3 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +58 -37
- tests/test_together.py +62 -0
- tests/test_vhc.py +3 -2
- tests/test_workflow.py +9 -28
- vectara_agentic/_observability.py +19 -0
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +246 -37
- vectara_agentic/agent_core/factory.py +34 -153
- vectara_agentic/agent_core/prompts.py +19 -13
- vectara_agentic/agent_core/serialization.py +17 -8
- vectara_agentic/agent_core/streaming.py +27 -43
- vectara_agentic/agent_core/utils/__init__.py +0 -5
- vectara_agentic/agent_core/utils/hallucination.py +54 -99
- vectara_agentic/llm_utils.py +4 -2
- vectara_agentic/sub_query_workflow.py +3 -2
- vectara_agentic/tools.py +0 -19
- vectara_agentic/types.py +9 -3
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.2.dist-info}/METADATA +79 -39
- vectara_agentic-0.4.2.dist-info/RECORD +54 -0
- vectara_agentic/agent_core/utils/prompt_formatting.py +0 -56
- vectara_agentic-0.4.0.dist-info/RECORD +0 -50
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.2.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.2.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
import threading
|
|
8
|
+
|
|
9
|
+
from vectara_agentic.agent import Agent, AgentType
|
|
10
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
11
|
+
from vectara_agentic.types import ModelProvider, AgentConfigType
|
|
12
|
+
from vectara_agentic.tools import ToolsFactory
|
|
13
|
+
from llama_index.core.llms import ChatMessage, MessageRole
|
|
14
|
+
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Module-level lock mirroring the sibling provider test modules; no test in this
# module acquires it. NOTE(review): drop it (and the threading import) if it stays unused.
ARIZE_LOCK = threading.Lock()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestAgentMemoryConsistency(unittest.TestCase):
    """Test memory consistency behavior for main/fallback agent switching.

    Every test builds an Agent with a FUNCTION_CALLING main config and a REACT
    fallback config (both Anthropic) that share one session_id. It then checks
    what happens when switching between them with ``_switch_agent_config``:
    Agent-level memory and the session id must be preserved, and the
    per-config agent instances must be cleared so they are rebuilt.
    """

    def setUp(self):
        """Set up tools, prompt text, the main/fallback configs and a session id."""
        self.tools = [ToolsFactory().create_tool(mult)]
        self.topic = STANDARD_TEST_TOPIC
        self.custom_instructions = STANDARD_TEST_INSTRUCTIONS

        # Main agent config
        self.main_config = AgentConfig(
            agent_type=AgentType.FUNCTION_CALLING,
            main_llm_provider=ModelProvider.ANTHROPIC,
        )

        # Fallback agent config
        self.fallback_config = AgentConfig(
            agent_type=AgentType.REACT, main_llm_provider=ModelProvider.ANTHROPIC
        )

        self.session_id = "test-memory-consistency-123"

    def _create_agent(self):
        """Build an Agent from the setUp fixtures (identical for every test)."""
        return Agent(
            tools=self.tools,
            topic=self.topic,
            custom_instructions=self.custom_instructions,
            agent_config=self.main_config,
            fallback_agent_config=self.fallback_config,
            session_id=self.session_id,
        )

    @staticmethod
    def _history(agent):
        """Return the contents of the agent's stored messages, in order."""
        return [message.content for message in agent.memory.get()]

    def _assert_session(self, agent):
        """Assert the agent and its memory are still bound to the test session id."""
        self.assertEqual(agent.session_id, self.session_id)
        self.assertEqual(agent.memory.chat_store_key, self.session_id)

    def test_agent_recreation_on_config_switch(self):
        """Test that agent instances are properly recreated when switching configurations"""
        agent = self._create_agent()

        # Load main agent first
        original_main_agent = agent.agent
        self.assertIsNotNone(original_main_agent)
        self.assertEqual(agent.agent_config_type, AgentConfigType.DEFAULT)

        # Switch to fallback - should clear fallback agent instance for recreation
        agent._switch_agent_config()
        self.assertEqual(agent.agent_config_type, AgentConfigType.FALLBACK)
        self.assertIsNone(agent._fallback_agent)

        # Load fallback agent - should be built lazily on access
        self.assertIsNotNone(agent.fallback_agent)

        # Switch back to main - should clear main agent instance for recreation
        agent._switch_agent_config()
        self.assertEqual(agent.agent_config_type, AgentConfigType.DEFAULT)
        self.assertIsNone(agent._agent)

        # Load main agent again - should be a new instance
        recreated_main_agent = agent.agent
        self.assertIsNotNone(recreated_main_agent)
        self.assertIsNot(recreated_main_agent, original_main_agent)

    def test_memory_persistence_across_config_switches(self):
        """Test that Agent memory persists correctly when switching configurations"""
        agent = self._create_agent()

        # Add initial memory
        agent.memory.put_messages(
            [
                ChatMessage(role=MessageRole.USER, content="Initial question"),
                ChatMessage(role=MessageRole.ASSISTANT, content="Initial response"),
            ]
        )
        initial_history = ["Initial question", "Initial response"]
        self.assertEqual(self._history(agent), initial_history)

        # Switch to fallback configuration; memory lives at the Agent level
        agent._switch_agent_config()
        self.assertEqual(agent.agent_config_type, AgentConfigType.FALLBACK)
        self.assertEqual(self._history(agent), initial_history)

        # Add more memory while in fallback mode
        agent.memory.put_messages(
            [
                ChatMessage(role=MessageRole.USER, content="Fallback question"),
                ChatMessage(role=MessageRole.ASSISTANT, content="Fallback response"),
            ]
        )
        combined_history = initial_history + ["Fallback question", "Fallback response"]
        self.assertEqual(self._history(agent), combined_history)

        # Switch back to main configuration; all memory should still be present
        agent._switch_agent_config()
        self.assertEqual(agent.agent_config_type, AgentConfigType.DEFAULT)
        self.assertEqual(self._history(agent), combined_history)

    def test_clear_memory_resets_agent_instances(self):
        """Test that clearing memory properly resets agent instances"""
        agent = self._create_agent()

        agent.memory.put_messages(
            [
                ChatMessage(role=MessageRole.USER, content="Test question"),
                ChatMessage(role=MessageRole.ASSISTANT, content="Test response"),
            ]
        )

        # Load both agents so that both cached instances exist before clearing
        _ = agent.agent
        _ = agent.fallback_agent
        self.assertEqual(len(agent.memory.get()), 2)

        agent.clear_memory()

        # Memory is emptied and both agent instances are reset
        self.assertEqual(agent.memory.get(), [])
        self.assertIsNone(agent._agent)
        self.assertIsNone(agent._fallback_agent)

    def test_session_id_consistency(self):
        """Test that session_id remains consistent throughout agent lifecycle"""
        agent = self._create_agent()
        self._assert_session(agent)

        # Switch configurations multiple times (to fallback, then back to main)
        for _ in range(2):
            agent._switch_agent_config()
            self._assert_session(agent)

        # Clearing memory must not detach the agent from its session
        agent.clear_memory()
        self._assert_session(agent)

    def test_serialization_preserves_consistency(self):
        """Test that serialization/deserialization preserves memory consistency behavior"""
        agent = self._create_agent()

        # Add memory and switch configurations
        agent.memory.put_messages(
            [
                ChatMessage(role=MessageRole.USER, content="Serialization test"),
                ChatMessage(role=MessageRole.ASSISTANT, content="Serialization response"),
            ]
        )
        agent._switch_agent_config()  # Switch to fallback

        # Serialize and deserialize
        restored_agent = Agent.loads(agent.dumps())

        # Restored agent keeps session and memory; config type resets to DEFAULT
        self.assertEqual(restored_agent.session_id, self.session_id)
        self.assertEqual(len(restored_agent.memory.get()), 2)
        self.assertEqual(restored_agent.memory.get()[0].content, "Serialization test")
        self.assertEqual(restored_agent.agent_config_type, AgentConfigType.DEFAULT)

        # Memory consistency behavior is preserved after restoring
        restored_agent._switch_agent_config()  # Switch to fallback
        self.assertEqual(restored_agent.agent_config_type, AgentConfigType.FALLBACK)
        self.assertEqual(len(restored_agent.memory.get()), 2)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# Allow running this module directly with the stdlib unittest runner.
if __name__ == "__main__":
|
|
229
|
+
unittest.main()
|
tests/test_agent_type.py
CHANGED
tests/test_bedrock.py
CHANGED
|
@@ -3,43 +3,58 @@ import warnings
|
|
|
3
3
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
|
+
import threading
|
|
6
7
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
8
|
+
from vectara_agentic.agent import Agent
|
|
9
9
|
from vectara_agentic.tools import ToolsFactory
|
|
10
|
-
from vectara_agentic.types import ModelProvider
|
|
11
10
|
|
|
12
11
|
import nest_asyncio
|
|
13
12
|
nest_asyncio.apply()
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
14
|
+
from conftest import mult, fc_config_bedrock, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
# Held by test_multiturn for the whole conversation. NOTE(review): presumably
# serializes tests that share Arize observability state — confirm.
ARIZE_LOCK = threading.Lock()
|
|
17
|
+
|
|
18
|
+
class TestBedrock(unittest.IsolatedAsyncioTestCase):
    """Multi-turn streaming chat test using the Bedrock function-calling config."""

    async def _ask(self, agent, prompt):
        """Send one streamed chat turn, drain the stream, and return the final response."""
        stream = await agent.astream_chat(prompt)
        # Consume the stream fully before requesting the aggregated response.
        async for _ in stream.async_response_gen():
            pass
        return await stream.aget_response()

    async def test_multiturn(self):
        """The agent must combine results from two earlier turns (50 * 21 = 1050)."""
        with ARIZE_LOCK:
            tools = [ToolsFactory().create_tool(mult)]
            agent = Agent(
                tools=tools,
                topic=STANDARD_TEST_TOPIC,
                custom_instructions=STANDARD_TEST_INSTRUCTIONS,
                agent_config=fc_config_bedrock,
            )

            # First calculation: 5 * 10 = 50
            await self._ask(
                agent, "What is 5 times 10. Only give the answer, nothing else"
            )

            # Second calculation: 3 * 7 = 21
            await self._ask(
                agent, "what is 3 times 7. Only give the answer, nothing else"
            )

            # Final calculation: 50 * 21 = 1050
            response3 = await self._ask(
                agent,
                "multiply the results of the last two questions. Output only the answer.",
            )

            # Ignore surrounding whitespace/newlines the LLM may emit.
            self.assertEqual(response3.response.strip(), "1050")
|
|
43
58
|
|
|
44
59
|
|
|
45
60
|
if __name__ == "__main__":
|
tests/test_fallback.py
CHANGED
|
@@ -54,7 +54,7 @@ class TestFallback(unittest.TestCase):
|
|
|
54
54
|
config = AgentConfig(
|
|
55
55
|
agent_type=AgentType.REACT,
|
|
56
56
|
main_llm_provider=ModelProvider.PRIVATE,
|
|
57
|
-
main_llm_model_name="gpt-
|
|
57
|
+
main_llm_model_name="gpt-4.1-mini",
|
|
58
58
|
private_llm_api_base=f"http://127.0.0.1:{FLASK_PORT}/v1",
|
|
59
59
|
private_llm_api_key="TEST_API_KEY",
|
|
60
60
|
)
|
tests/test_gemini.py
CHANGED
|
@@ -4,15 +4,15 @@ warnings.simplefilter("ignore", DeprecationWarning)
|
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
6
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
9
|
-
from vectara_agentic.types import ModelProvider
|
|
7
|
+
from vectara_agentic.agent import Agent
|
|
10
8
|
from vectara_agentic.tools import ToolsFactory
|
|
11
9
|
|
|
12
10
|
|
|
13
11
|
import nest_asyncio
|
|
14
12
|
nest_asyncio.apply()
|
|
15
13
|
|
|
14
|
+
from conftest import mult, fc_config_gemini, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
16
|
tickers = {
|
|
17
17
|
"C": "Citigroup",
|
|
18
18
|
"COF": "Capital One",
|
|
@@ -34,11 +34,6 @@ tickers = {
|
|
|
34
34
|
years = list(range(2015, 2025))
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
def mult(x: float, y: float) -> float:
|
|
38
|
-
"Multiply two numbers"
|
|
39
|
-
return x * y
|
|
40
|
-
|
|
41
|
-
|
|
42
37
|
def get_company_info() -> list[str]:
|
|
43
38
|
"""
|
|
44
39
|
Returns a dictionary of companies you can query about. Always check this before using any other tool.
|
|
@@ -56,23 +51,15 @@ def get_valid_years() -> list[str]:
|
|
|
56
51
|
return years
|
|
57
52
|
|
|
58
53
|
|
|
59
|
-
fc_config_gemini = AgentConfig(
|
|
60
|
-
agent_type=AgentType.FUNCTION_CALLING,
|
|
61
|
-
main_llm_provider=ModelProvider.GEMINI,
|
|
62
|
-
tool_llm_provider=ModelProvider.GEMINI,
|
|
63
|
-
)
|
|
64
|
-
|
|
65
54
|
class TestGEMINI(unittest.TestCase):
|
|
66
55
|
def test_gemini(self):
|
|
67
56
|
tools = [ToolsFactory().create_tool(mult)]
|
|
68
|
-
topic = "AI topic"
|
|
69
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
70
57
|
|
|
71
58
|
agent = Agent(
|
|
72
59
|
agent_config=fc_config_gemini,
|
|
73
60
|
tools=tools,
|
|
74
|
-
topic=
|
|
75
|
-
custom_instructions=
|
|
61
|
+
topic=STANDARD_TEST_TOPIC,
|
|
62
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
76
63
|
)
|
|
77
64
|
_ = agent.chat("What is 5 times 10. Only give the answer, nothing else")
|
|
78
65
|
_ = agent.chat("what is 3 times 7. Only give the answer, nothing else")
|
|
@@ -81,14 +68,12 @@ class TestGEMINI(unittest.TestCase):
|
|
|
81
68
|
|
|
82
69
|
def test_gemini_single_prompt(self):
|
|
83
70
|
tools = [ToolsFactory().create_tool(mult)]
|
|
84
|
-
topic = "AI topic"
|
|
85
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
86
71
|
|
|
87
72
|
agent = Agent(
|
|
88
73
|
agent_config=fc_config_gemini,
|
|
89
74
|
tools=tools,
|
|
90
|
-
topic=
|
|
91
|
-
custom_instructions=
|
|
75
|
+
topic=STANDARD_TEST_TOPIC,
|
|
76
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
92
77
|
)
|
|
93
78
|
res = agent.chat("First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations.")
|
|
94
79
|
self.assertIn("1050", res.response)
|
tests/test_groq.py
CHANGED
|
@@ -3,43 +3,58 @@ import warnings
|
|
|
3
3
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
|
+
import threading
|
|
6
7
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
8
|
+
from vectara_agentic.agent import Agent
|
|
9
9
|
from vectara_agentic.tools import ToolsFactory
|
|
10
|
-
from vectara_agentic.types import ModelProvider
|
|
11
10
|
|
|
12
11
|
import nest_asyncio
|
|
13
12
|
nest_asyncio.apply()
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
14
|
+
from conftest import mult, fc_config_groq, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
# Held by test_multiturn for the whole conversation. NOTE(review): presumably
# serializes tests that share Arize observability state — confirm.
ARIZE_LOCK = threading.Lock()
|
|
17
|
+
|
|
18
|
+
class TestGROQ(unittest.IsolatedAsyncioTestCase):
    """Multi-turn streaming chat test using the Groq function-calling config."""

    async def _ask(self, agent, prompt):
        """Send one streamed chat turn, drain the stream, and return the final response."""
        stream = await agent.astream_chat(prompt)
        # Consume the stream fully before requesting the aggregated response.
        async for _ in stream.async_response_gen():
            pass
        return await stream.aget_response()

    async def test_multiturn(self):
        """The agent must combine results from two earlier turns (50 * 21 = 1050)."""
        with ARIZE_LOCK:
            tools = [ToolsFactory().create_tool(mult)]
            agent = Agent(
                tools=tools,
                topic=STANDARD_TEST_TOPIC,
                custom_instructions=STANDARD_TEST_INSTRUCTIONS,
                agent_config=fc_config_groq,
            )

            # First calculation: 5 * 10 = 50
            await self._ask(
                agent, "What is 5 times 10. Only give the answer, nothing else"
            )

            # Second calculation: 3 * 7 = 21
            await self._ask(
                agent, "what is 3 times 7. Only give the answer, nothing else"
            )

            # Final calculation: 50 * 21 = 1050
            response3 = await self._ask(
                agent,
                "multiply the results of the last two questions. Output only the answer.",
            )

            # Ignore surrounding whitespace/newlines the LLM may emit.
            self.assertEqual(response3.response.strip(), "1050")
|
|
43
58
|
|
|
44
59
|
|
|
45
60
|
if __name__ == "__main__":
|
tests/test_private_llm.py
CHANGED
|
@@ -54,7 +54,7 @@ class TestPrivateLLM(unittest.TestCase):
|
|
|
54
54
|
config = AgentConfig(
|
|
55
55
|
agent_type=AgentType.FUNCTION_CALLING,
|
|
56
56
|
main_llm_provider=ModelProvider.PRIVATE,
|
|
57
|
-
main_llm_model_name="gpt-4.1",
|
|
57
|
+
main_llm_model_name="gpt-4.1-mini",
|
|
58
58
|
private_llm_api_base=f"http://127.0.0.1:{FLASK_PORT}/v1",
|
|
59
59
|
private_llm_api_key="TEST_API_KEY",
|
|
60
60
|
)
|
tests/test_serialization.py
CHANGED
|
@@ -14,8 +14,7 @@ from vectara_agentic.tools import ToolsFactory
|
|
|
14
14
|
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
15
15
|
from sqlalchemy import create_engine
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
return x * y
|
|
17
|
+
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
ARIZE_LOCK = threading.Lock()
|
|
@@ -44,12 +43,10 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
44
43
|
)
|
|
45
44
|
|
|
46
45
|
tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
|
|
47
|
-
topic = "AI topic"
|
|
48
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
49
46
|
agent = Agent(
|
|
50
47
|
tools=tools,
|
|
51
|
-
topic=
|
|
52
|
-
custom_instructions=
|
|
48
|
+
topic=STANDARD_TEST_TOPIC,
|
|
49
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
53
50
|
agent_config=config
|
|
54
51
|
)
|
|
55
52
|
|