vectara-agentic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/__init__.py +1 -0
- tests/benchmark_models.py +547 -372
- tests/conftest.py +14 -12
- tests/endpoint.py +9 -5
- tests/run_tests.py +1 -0
- tests/test_agent.py +22 -9
- tests/test_agent_fallback_memory.py +4 -4
- tests/test_agent_memory_consistency.py +4 -4
- tests/test_agent_type.py +2 -0
- tests/test_api_endpoint.py +13 -13
- tests/test_bedrock.py +9 -1
- tests/test_fallback.py +18 -7
- tests/test_gemini.py +14 -40
- tests/test_groq.py +43 -1
- tests/test_openai.py +160 -0
- tests/test_private_llm.py +19 -6
- tests/test_react_error_handling.py +293 -0
- tests/test_react_memory.py +257 -0
- tests/test_react_streaming.py +135 -0
- tests/test_react_workflow_events.py +395 -0
- tests/test_return_direct.py +1 -0
- tests/test_serialization.py +58 -20
- tests/test_session_memory.py +11 -11
- tests/test_streaming.py +0 -44
- tests/test_together.py +75 -1
- tests/test_tools.py +3 -1
- tests/test_vectara_llms.py +2 -2
- tests/test_vhc.py +7 -2
- tests/test_workflow.py +17 -11
- vectara_agentic/_callback.py +79 -21
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +65 -27
- vectara_agentic/agent_core/serialization.py +5 -9
- vectara_agentic/agent_core/streaming.py +245 -64
- vectara_agentic/agent_core/utils/schemas.py +2 -2
- vectara_agentic/llm_utils.py +64 -15
- vectara_agentic/tools.py +88 -31
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/METADATA +133 -36
- vectara_agentic-0.4.4.dist-info/RECORD +59 -0
- vectara_agentic-0.4.2.dist-info/RECORD +0 -54
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
from typing import Dict, Any
|
|
8
|
+
|
|
9
|
+
from vectara_agentic.agent import Agent, AgentStatusType
|
|
10
|
+
from vectara_agentic.tools import ToolsFactory
|
|
11
|
+
|
|
12
|
+
import nest_asyncio
|
|
13
|
+
|
|
14
|
+
nest_asyncio.apply()
|
|
15
|
+
|
|
16
|
+
from conftest import (
|
|
17
|
+
AgentTestMixin,
|
|
18
|
+
react_config_anthropic,
|
|
19
|
+
react_config_gemini,
|
|
20
|
+
react_config_together,
|
|
21
|
+
mult,
|
|
22
|
+
add,
|
|
23
|
+
STANDARD_TEST_TOPIC,
|
|
24
|
+
STANDARD_TEST_INSTRUCTIONS,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestReActWorkflowEvents(unittest.IsolatedAsyncioTestCase, AgentTestMixin):
    """Test workflow event handling and streaming for ReAct agents.

    Every test builds a ReAct agent whose progress callback records each
    event in ``self.captured_events``, streams one chat turn to completion,
    and then inspects the recorded events.  Most event assertions run only
    when the model's answer contains the expected number, so an off-target
    LLM reply leaves the test passing instead of failing it.
    """

    def setUp(self):
        """Create the calculator tools and a fresh, empty event log."""
        self.tools = [ToolsFactory().create_tool(fn) for fn in (mult, add)]
        self.topic = STANDARD_TEST_TOPIC
        self.instructions = STANDARD_TEST_INSTRUCTIONS
        # Re-created before every test, so tests never see each other's events.
        self.captured_events = []

    def capture_progress_callback(
        self, status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
    ):
        """Capture agent progress events for testing.

        Args:
            status_type: Kind of event reported by the agent.
            msg: Event payload.
            event_id: Identifier attached to the event; defaults to ``None``.
        """
        record = {"status_type": status_type, "msg": msg, "event_id": event_id}
        self.captured_events.append(record)

    # ------------------------------------------------------------------
    # Private helpers shared by the tests below
    # ------------------------------------------------------------------

    def _make_react_agent(
        self, config=react_config_anthropic, tools=None, callback=None
    ):
        """Build an agent with the standard topic and instructions.

        ``tools`` falls back to ``self.tools`` and ``callback`` falls back to
        ``self.capture_progress_callback`` when left as ``None``.
        """
        return Agent(
            agent_config=config,
            tools=self.tools if tools is None else tools,
            topic=self.topic,
            custom_instructions=self.instructions,
            agent_progress_callback=(
                self.capture_progress_callback if callback is None else callback
            ),
        )

    async def _stream_chat(self, agent, prompt, provider_name):
        """Stream ``prompt`` through ``agent`` and return the final response.

        The stream is drained before the final response is requested, and the
        response is handed to ``check_response_and_skip`` for ``provider_name``.
        Callers are expected to invoke this inside ``with_provider_fallback``.
        """
        stream = await agent.astream_chat(prompt)

        # The chunks themselves are irrelevant here; only the events matter.
        async for _ in stream.async_response_gen():
            pass

        response = await stream.aget_response()
        self.check_response_and_skip(response, provider_name)
        return response

    def _events_with_status(self, *status_types):
        """Return captured events whose status is one of ``status_types``."""
        return [
            event
            for event in self.captured_events
            if event["status_type"] in status_types
        ]

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    async def test_react_workflow_tool_call_events(self):
        """Test that ReAct workflow generates proper tool call events."""
        agent = self._make_react_agent()

        with self.with_provider_fallback("Anthropic"):
            response = await self._stream_chat(
                agent, "Calculate 8 times 9.", "Anthropic"
            )

            if response.response and "72" in response.response:
                tool_call_events = self._events_with_status(
                    AgentStatusType.TOOL_CALL
                )
                tool_output_events = self._events_with_status(
                    AgentStatusType.TOOL_OUTPUT
                )

                # Should have at least one tool call and one tool output
                self.assertGreater(
                    len(tool_call_events), 0, "Should capture tool call events"
                )
                self.assertGreater(
                    len(tool_output_events), 0, "Should capture tool output events"
                )

                # Both lists are known to be non-empty past the assertions
                # above, so the first element can be inspected directly.
                tool_call = tool_call_events[0]
                self.assertIn("tool_name", tool_call["msg"])
                self.assertIn("arguments", tool_call["msg"])
                self.assertIsNotNone(tool_call["event_id"])

                tool_output = tool_output_events[0]
                self.assertIn("tool_name", tool_output["msg"])
                self.assertIn("content", tool_output["msg"])
                self.assertIsNotNone(tool_output["event_id"])

    async def test_react_workflow_multi_step_events(self):
        """Test ReAct workflow events for multi-step reasoning tasks."""
        agent = self._make_react_agent()

        with self.with_provider_fallback("Anthropic"):
            response = await self._stream_chat(
                agent,
                "First multiply 7 by 6, then add 15 to that result.",
                "Anthropic",
            )

            # Should be (7*6)+15 = 42+15 = 57
            if response.response and "57" in response.response:
                tool_call_events = self._events_with_status(
                    AgentStatusType.TOOL_CALL
                )

                # The task involves two operations (mult, then add), but only
                # one tool call event is required here.
                # NOTE(review): raise the threshold to 2 if the agent is
                # guaranteed to call both tools -- confirm before tightening.
                self.assertGreaterEqual(
                    len(tool_call_events),
                    1,
                    "Should have tool call events for multi-step task",
                )

                # Verify event IDs are present for events that have them
                # With simplified logic, events without proper IDs are skipped
                events_with_ids = [
                    event for event in self.captured_events if event["event_id"]
                ]

                # At least some events should have IDs
                self.assertGreater(
                    len(events_with_ids), 0, "Should have events with proper IDs"
                )

    async def test_react_workflow_agent_update_events(self):
        """Test the structure of agent update events from a ReAct workflow."""
        agent = self._make_react_agent()

        with self.with_provider_fallback("Anthropic"):
            response = await self._stream_chat(
                agent, "Calculate 5 times 11.", "Anthropic"
            )

            if response.response and "55" in response.response:
                agent_update_events = self._events_with_status(
                    AgentStatusType.AGENT_UPDATE
                )

                # NOTE(review): this test used to assert
                # ``len(agent_update_events) >= 0``, which can never fail, so
                # it was dropped.  Whether a ReAct run is guaranteed to emit
                # AGENT_UPDATE events needs confirming before asserting > 0.

                # Verify structure of whatever agent update events were emitted
                for event in agent_update_events:
                    self.assertIn("content", event["msg"])
                    self.assertIsInstance(event["msg"]["content"], str)

    async def test_react_workflow_event_ordering(self):
        """Test that ReAct workflow events are generated in correct order."""
        agent = self._make_react_agent()

        with self.with_provider_fallback("Anthropic"):
            response = await self._stream_chat(
                agent, "Multiply 9 by 4.", "Anthropic"
            )

            if response.response and "36" in response.response:
                # Record, per event_id, the position of the first tool call
                # and of the first tool output in the captured sequence.
                first_call_at = {}
                first_output_at = {}
                for position, event in enumerate(self.captured_events):
                    status = event["status_type"]
                    if status == AgentStatusType.TOOL_CALL:
                        first_call_at.setdefault(event["event_id"], position)
                    elif status == AgentStatusType.TOOL_OUTPUT:
                        first_output_at.setdefault(event["event_id"], position)

                # Only event_ids that have both a call and an output are
                # comparable; for those, the call must come first.
                for event_id, call_index in first_call_at.items():
                    if event_id in first_output_at:
                        self.assertLess(
                            call_index,
                            first_output_at[event_id],
                            "Tool call should come before tool output",
                        )

    async def test_react_workflow_event_error_handling(self):
        """Test ReAct workflow event handling when tools fail."""

        def failing_tool(x: float) -> float:
            """A tool that fails with certain inputs."""
            if x == 0:
                raise ValueError("Cannot process zero")
            return x * 10

        agent = self._make_react_agent(
            tools=[ToolsFactory().create_tool(failing_tool)]
        )

        with self.with_provider_fallback("Anthropic"):
            await self._stream_chat(
                agent, "Use failing_tool with input 0.", "Anthropic"
            )

            # Even with tool errors, we should still capture events
            self.assertGreater(
                len(self.captured_events),
                0,
                "Should capture events even when tools fail",
            )

            tool_call_events = self._events_with_status(AgentStatusType.TOOL_CALL)
            self.assertGreater(
                len(tool_call_events),
                0,
                "Should capture tool call events even when tools fail",
            )

    async def test_react_workflow_event_callback_error_resilience(self):
        """Test that ReAct workflow continues even if progress callback raises errors."""

        def failing_callback(
            status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
        ):
            """A callback that always raises an error."""
            raise RuntimeError("Callback error")

        agent = self._make_react_agent(callback=failing_callback)

        with self.with_provider_fallback("Anthropic"):
            # Even with failing callback, agent should still work
            response = await self._stream_chat(
                agent, "Calculate 12 times 3.", "Anthropic"
            )

            # Getting here means the callback errors did not abort the chat.
            # This replaces a guarded ``assertTrue(True)`` that verified
            # nothing; the answer text is deliberately not asserted on.
            self.assertIsNotNone(
                response, "Agent should return a response despite callback errors"
            )

    async def test_react_workflow_event_consistency_across_providers(self):
        """Test that ReAct workflow events are consistent across different providers."""
        providers_to_test = [
            ("Anthropic", react_config_anthropic),
            ("Gemini", react_config_gemini),
            ("Together AI", react_config_together),
        ]

        for provider_name, config in providers_to_test:
            with self.subTest(provider=provider_name):
                # setUp runs once per test method, not per sub-test, so the
                # log must be emptied explicitly between providers.
                self.captured_events.clear()

                agent = self._make_react_agent(config=config)

                with self.with_provider_fallback(provider_name):
                    response = await self._stream_chat(
                        agent, "Calculate 6 times 8.", provider_name
                    )

                    if response.response and "48" in response.response:
                        # Should have some events
                        self.assertGreater(
                            len(self.captured_events),
                            0,
                            f"{provider_name} should generate events",
                        )

                        # All events should have proper structure
                        for event in self.captured_events:
                            self.assertIn("status_type", event)
                            self.assertIn("msg", event)
                            self.assertIsInstance(event["msg"], dict)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
# Allow this test module to be executed directly, outside a test runner.
if __name__ == "__main__":
    unittest.main()
|
tests/test_return_direct.py
CHANGED
tests/test_serialization.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -19,12 +20,13 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
|
19
20
|
|
|
20
21
|
ARIZE_LOCK = threading.Lock()
|
|
21
22
|
|
|
23
|
+
|
|
22
24
|
class TestAgentSerialization(unittest.TestCase):
|
|
23
25
|
|
|
24
26
|
@classmethod
|
|
25
27
|
def tearDown(cls):
|
|
26
28
|
try:
|
|
27
|
-
os.remove(
|
|
29
|
+
os.remove("ev_database.db")
|
|
28
30
|
except FileNotFoundError:
|
|
29
31
|
pass
|
|
30
32
|
|
|
@@ -34,20 +36,24 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
34
36
|
agent_type=AgentType.REACT,
|
|
35
37
|
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
36
38
|
tool_llm_provider=ModelProvider.TOGETHER,
|
|
37
|
-
observer=ObserverType.ARIZE_PHOENIX
|
|
39
|
+
observer=ObserverType.ARIZE_PHOENIX,
|
|
38
40
|
)
|
|
39
41
|
db_tools = ToolsFactory().database_tools(
|
|
40
|
-
tool_name_prefix
|
|
41
|
-
content_description
|
|
42
|
-
sql_database
|
|
42
|
+
tool_name_prefix="ev",
|
|
43
|
+
content_description="Electric Vehicles in the state of Washington and other population information",
|
|
44
|
+
sql_database=SQLDatabase(create_engine("sqlite:///ev_database.db")),
|
|
43
45
|
)
|
|
44
46
|
|
|
45
|
-
tools =
|
|
47
|
+
tools = (
|
|
48
|
+
[ToolsFactory().create_tool(mult)]
|
|
49
|
+
+ ToolsFactory().standard_tools()
|
|
50
|
+
+ db_tools
|
|
51
|
+
)
|
|
46
52
|
agent = Agent(
|
|
47
53
|
tools=tools,
|
|
48
54
|
topic=STANDARD_TEST_TOPIC,
|
|
49
55
|
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
50
|
-
agent_config=config
|
|
56
|
+
agent_config=config,
|
|
51
57
|
)
|
|
52
58
|
|
|
53
59
|
agent_reloaded = agent.loads(agent.dumps())
|
|
@@ -57,17 +63,33 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
57
63
|
self.assertEqual(agent, agent_reloaded)
|
|
58
64
|
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
59
65
|
|
|
60
|
-
self.assertEqual(
|
|
61
|
-
|
|
62
|
-
|
|
66
|
+
self.assertEqual(
|
|
67
|
+
agent.agent_config.observer, agent_reloaded.agent_config.observer
|
|
68
|
+
)
|
|
69
|
+
self.assertEqual(
|
|
70
|
+
agent.agent_config.main_llm_provider,
|
|
71
|
+
agent_reloaded.agent_config.main_llm_provider,
|
|
72
|
+
)
|
|
73
|
+
self.assertEqual(
|
|
74
|
+
agent.agent_config.tool_llm_provider,
|
|
75
|
+
agent_reloaded.agent_config.tool_llm_provider,
|
|
76
|
+
)
|
|
63
77
|
|
|
64
78
|
self.assertIsInstance(agent_reloaded, Agent)
|
|
65
79
|
self.assertEqual(agent, agent_reloaded_again)
|
|
66
80
|
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
67
81
|
|
|
68
|
-
self.assertEqual(
|
|
69
|
-
|
|
70
|
-
|
|
82
|
+
self.assertEqual(
|
|
83
|
+
agent.agent_config.observer, agent_reloaded_again.agent_config.observer
|
|
84
|
+
)
|
|
85
|
+
self.assertEqual(
|
|
86
|
+
agent.agent_config.main_llm_provider,
|
|
87
|
+
agent_reloaded_again.agent_config.main_llm_provider,
|
|
88
|
+
)
|
|
89
|
+
self.assertEqual(
|
|
90
|
+
agent.agent_config.tool_llm_provider,
|
|
91
|
+
agent_reloaded_again.agent_config.tool_llm_provider,
|
|
92
|
+
)
|
|
71
93
|
|
|
72
94
|
def test_serialization_from_corpus(self):
|
|
73
95
|
with ARIZE_LOCK:
|
|
@@ -75,7 +97,7 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
75
97
|
agent_type=AgentType.REACT,
|
|
76
98
|
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
77
99
|
tool_llm_provider=ModelProvider.TOGETHER,
|
|
78
|
-
observer=ObserverType.ARIZE_PHOENIX
|
|
100
|
+
observer=ObserverType.ARIZE_PHOENIX,
|
|
79
101
|
)
|
|
80
102
|
|
|
81
103
|
agent = Agent.from_corpus(
|
|
@@ -94,17 +116,33 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
94
116
|
self.assertEqual(agent, agent_reloaded)
|
|
95
117
|
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
96
118
|
|
|
97
|
-
self.assertEqual(
|
|
98
|
-
|
|
99
|
-
|
|
119
|
+
self.assertEqual(
|
|
120
|
+
agent.agent_config.observer, agent_reloaded.agent_config.observer
|
|
121
|
+
)
|
|
122
|
+
self.assertEqual(
|
|
123
|
+
agent.agent_config.main_llm_provider,
|
|
124
|
+
agent_reloaded.agent_config.main_llm_provider,
|
|
125
|
+
)
|
|
126
|
+
self.assertEqual(
|
|
127
|
+
agent.agent_config.tool_llm_provider,
|
|
128
|
+
agent_reloaded.agent_config.tool_llm_provider,
|
|
129
|
+
)
|
|
100
130
|
|
|
101
131
|
self.assertIsInstance(agent_reloaded, Agent)
|
|
102
132
|
self.assertEqual(agent, agent_reloaded_again)
|
|
103
133
|
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
104
134
|
|
|
105
|
-
self.assertEqual(
|
|
106
|
-
|
|
107
|
-
|
|
135
|
+
self.assertEqual(
|
|
136
|
+
agent.agent_config.observer, agent_reloaded_again.agent_config.observer
|
|
137
|
+
)
|
|
138
|
+
self.assertEqual(
|
|
139
|
+
agent.agent_config.main_llm_provider,
|
|
140
|
+
agent_reloaded_again.agent_config.main_llm_provider,
|
|
141
|
+
)
|
|
142
|
+
self.assertEqual(
|
|
143
|
+
agent.agent_config.tool_llm_provider,
|
|
144
|
+
agent_reloaded_again.agent_config.tool_llm_provider,
|
|
145
|
+
)
|
|
108
146
|
|
|
109
147
|
|
|
110
148
|
if __name__ == "__main__":
|
tests/test_session_memory.py
CHANGED
|
@@ -43,8 +43,8 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
43
43
|
# Verify the agent uses the provided session_id
|
|
44
44
|
self.assertEqual(agent.session_id, custom_session_id)
|
|
45
45
|
|
|
46
|
-
# Verify memory uses the same session_id
|
|
47
|
-
self.assertEqual(agent.memory.
|
|
46
|
+
# Verify memory uses the same session_id
|
|
47
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
48
48
|
|
|
49
49
|
def test_agent_init_without_session_id(self):
|
|
50
50
|
"""Test Agent initialization without session_id (auto-generation)"""
|
|
@@ -59,8 +59,8 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
59
59
|
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
60
60
|
self.assertEqual(agent.session_id, expected_pattern)
|
|
61
61
|
|
|
62
|
-
# Verify memory uses the same session_id
|
|
63
|
-
self.assertEqual(agent.memory.
|
|
62
|
+
# Verify memory uses the same session_id
|
|
63
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
64
64
|
|
|
65
65
|
def test_from_tools_with_session_id(self):
|
|
66
66
|
"""Test Agent.from_tools() with custom session_id"""
|
|
@@ -76,7 +76,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
76
76
|
|
|
77
77
|
# Verify the agent uses the provided session_id
|
|
78
78
|
self.assertEqual(agent.session_id, custom_session_id)
|
|
79
|
-
self.assertEqual(agent.memory.
|
|
79
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
80
80
|
|
|
81
81
|
def test_from_tools_without_session_id(self):
|
|
82
82
|
"""Test Agent.from_tools() without session_id (auto-generation)"""
|
|
@@ -90,7 +90,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
90
90
|
# Verify auto-generated session_id
|
|
91
91
|
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
92
92
|
self.assertEqual(agent.session_id, expected_pattern)
|
|
93
|
-
self.assertEqual(agent.memory.
|
|
93
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
94
94
|
|
|
95
95
|
def test_session_id_consistency_across_agents(self):
|
|
96
96
|
"""Test that agents with same session_id have consistent session_id attributes"""
|
|
@@ -118,9 +118,9 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
118
118
|
self.assertEqual(agent2.session_id, shared_session_id)
|
|
119
119
|
self.assertEqual(agent1.session_id, agent2.session_id)
|
|
120
120
|
|
|
121
|
-
# Verify their memory instances also have the correct session_id
|
|
122
|
-
self.assertEqual(agent1.memory.
|
|
123
|
-
self.assertEqual(agent2.memory.
|
|
121
|
+
# Verify their memory instances also have the correct session_id
|
|
122
|
+
self.assertEqual(agent1.memory.session_id, shared_session_id)
|
|
123
|
+
self.assertEqual(agent2.memory.session_id, shared_session_id)
|
|
124
124
|
|
|
125
125
|
# Note: Each agent gets its own Memory instance (this is expected behavior)
|
|
126
126
|
# In production, memory persistence happens through serialization/deserialization
|
|
@@ -204,7 +204,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
204
204
|
|
|
205
205
|
# Verify session_id is preserved
|
|
206
206
|
self.assertEqual(restored_agent.session_id, custom_session_id)
|
|
207
|
-
self.assertEqual(restored_agent.memory.
|
|
207
|
+
self.assertEqual(restored_agent.memory.session_id, custom_session_id)
|
|
208
208
|
|
|
209
209
|
# Verify memory is preserved
|
|
210
210
|
restored_messages = restored_agent.memory.get()
|
|
@@ -231,7 +231,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
|
|
|
231
231
|
|
|
232
232
|
# Verify session_id is correct
|
|
233
233
|
self.assertEqual(agent.session_id, custom_session_id)
|
|
234
|
-
self.assertEqual(agent.memory.
|
|
234
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
235
235
|
|
|
236
236
|
# Verify chat history was loaded into memory
|
|
237
237
|
messages = agent.memory.get()
|
tests/test_streaming.py
CHANGED
|
@@ -4,7 +4,6 @@ import warnings
|
|
|
4
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
5
|
|
|
6
6
|
import unittest
|
|
7
|
-
import asyncio
|
|
8
7
|
|
|
9
8
|
from vectara_agentic.agent import Agent
|
|
10
9
|
from vectara_agentic.tools import ToolsFactory
|
|
@@ -14,7 +13,6 @@ import nest_asyncio
|
|
|
14
13
|
nest_asyncio.apply()
|
|
15
14
|
|
|
16
15
|
from conftest import (
|
|
17
|
-
fc_config_openai,
|
|
18
16
|
fc_config_anthropic,
|
|
19
17
|
mult,
|
|
20
18
|
STANDARD_TEST_TOPIC,
|
|
@@ -62,48 +60,6 @@ class TestAgentStreaming(unittest.IsolatedAsyncioTestCase):
|
|
|
62
60
|
|
|
63
61
|
self.assertIn("1050", response3.response)
|
|
64
62
|
|
|
65
|
-
async def test_openai(self):
|
|
66
|
-
tools = [ToolsFactory().create_tool(mult)]
|
|
67
|
-
agent = Agent(
|
|
68
|
-
agent_config=fc_config_openai,
|
|
69
|
-
tools=tools,
|
|
70
|
-
topic=STANDARD_TEST_TOPIC,
|
|
71
|
-
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
# First calculation: 5 * 10 = 50
|
|
75
|
-
stream1 = await agent.astream_chat(
|
|
76
|
-
"What is 5 times 10. Only give the answer, nothing else"
|
|
77
|
-
)
|
|
78
|
-
# Consume the stream
|
|
79
|
-
async for chunk in stream1.async_response_gen():
|
|
80
|
-
pass
|
|
81
|
-
_ = await stream1.aget_response()
|
|
82
|
-
|
|
83
|
-
# Second calculation: 3 * 7 = 21
|
|
84
|
-
stream2 = await agent.astream_chat(
|
|
85
|
-
"what is 3 times 7. Only give the answer, nothing else"
|
|
86
|
-
)
|
|
87
|
-
# Consume the stream
|
|
88
|
-
async for chunk in stream2.async_response_gen():
|
|
89
|
-
pass
|
|
90
|
-
_ = await stream2.aget_response()
|
|
91
|
-
|
|
92
|
-
# Final calculation: 50 * 21 = 1050
|
|
93
|
-
stream3 = await agent.astream_chat(
|
|
94
|
-
"multiply the results of the last two multiplications. Only give the answer, nothing else."
|
|
95
|
-
)
|
|
96
|
-
# Consume the stream
|
|
97
|
-
async for chunk in stream3.async_response_gen():
|
|
98
|
-
pass
|
|
99
|
-
response3 = await stream3.aget_response()
|
|
100
|
-
|
|
101
|
-
self.assertIn("1050", response3.response)
|
|
102
|
-
|
|
103
|
-
def test_openai_sync(self):
|
|
104
|
-
"""Synchronous wrapper for the async test"""
|
|
105
|
-
asyncio.run(self.test_openai())
|
|
106
|
-
|
|
107
63
|
|
|
108
64
|
if __name__ == "__main__":
|
|
109
65
|
unittest.main()
|