vectara-agentic 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +1 -0
- tests/benchmark_models.py +1120 -0
- tests/conftest.py +18 -16
- tests/endpoint.py +9 -5
- tests/run_tests.py +3 -0
- tests/test_agent.py +52 -8
- tests/test_agent_type.py +2 -0
- tests/test_api_endpoint.py +13 -13
- tests/test_bedrock.py +9 -1
- tests/test_fallback.py +19 -8
- tests/test_gemini.py +14 -40
- tests/test_groq.py +9 -1
- tests/test_private_llm.py +20 -7
- tests/test_react_error_handling.py +293 -0
- tests/test_react_memory.py +257 -0
- tests/test_react_streaming.py +135 -0
- tests/test_react_workflow_events.py +395 -0
- tests/test_return_direct.py +1 -0
- tests/test_serialization.py +58 -20
- tests/test_together.py +9 -1
- tests/test_tools.py +3 -1
- tests/test_vectara_llms.py +2 -2
- tests/test_vhc.py +7 -2
- tests/test_workflow.py +17 -11
- vectara_agentic/_callback.py +79 -21
- vectara_agentic/_observability.py +19 -0
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +89 -21
- vectara_agentic/agent_core/factory.py +5 -6
- vectara_agentic/agent_core/prompts.py +3 -4
- vectara_agentic/agent_core/serialization.py +12 -10
- vectara_agentic/agent_core/streaming.py +245 -68
- vectara_agentic/agent_core/utils/schemas.py +2 -2
- vectara_agentic/llm_utils.py +6 -2
- vectara_agentic/sub_query_workflow.py +3 -2
- vectara_agentic/tools.py +0 -19
- {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/METADATA +156 -61
- vectara_agentic-0.4.3.dist-info/RECORD +58 -0
- vectara_agentic-0.4.1.dist-info/RECORD +0 -53
- {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
from typing import Dict, Any
|
|
8
|
+
|
|
9
|
+
from vectara_agentic.agent import Agent, AgentStatusType
|
|
10
|
+
from vectara_agentic.tools import ToolsFactory
|
|
11
|
+
|
|
12
|
+
import nest_asyncio
|
|
13
|
+
|
|
14
|
+
nest_asyncio.apply()
|
|
15
|
+
|
|
16
|
+
from conftest import (
|
|
17
|
+
AgentTestMixin,
|
|
18
|
+
react_config_anthropic,
|
|
19
|
+
react_config_gemini,
|
|
20
|
+
react_config_together,
|
|
21
|
+
mult,
|
|
22
|
+
add,
|
|
23
|
+
STANDARD_TEST_TOPIC,
|
|
24
|
+
STANDARD_TEST_INSTRUCTIONS,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestReActWorkflowEvents(unittest.IsolatedAsyncioTestCase, AgentTestMixin):
|
|
29
|
+
"""Test workflow event handling and streaming for ReAct agents."""
|
|
30
|
+
|
|
31
|
+
def setUp(self):
|
|
32
|
+
self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
|
|
33
|
+
self.topic = STANDARD_TEST_TOPIC
|
|
34
|
+
self.instructions = STANDARD_TEST_INSTRUCTIONS
|
|
35
|
+
self.captured_events = []
|
|
36
|
+
|
|
37
|
+
def capture_progress_callback(
|
|
38
|
+
self, status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
|
|
39
|
+
):
|
|
40
|
+
"""Capture agent progress events for testing."""
|
|
41
|
+
self.captured_events.append(
|
|
42
|
+
{
|
|
43
|
+
"status_type": status_type,
|
|
44
|
+
"msg": msg,
|
|
45
|
+
"event_id": event_id,
|
|
46
|
+
}
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
async def test_react_workflow_tool_call_events(self):
|
|
50
|
+
"""Test that ReAct workflow generates proper tool call events."""
|
|
51
|
+
agent = Agent(
|
|
52
|
+
agent_config=react_config_anthropic,
|
|
53
|
+
tools=self.tools,
|
|
54
|
+
topic=self.topic,
|
|
55
|
+
custom_instructions=self.instructions,
|
|
56
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
with self.with_provider_fallback("Anthropic"):
|
|
60
|
+
stream = await agent.astream_chat("Calculate 8 times 9.")
|
|
61
|
+
|
|
62
|
+
# Consume the stream
|
|
63
|
+
async for chunk in stream.async_response_gen():
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
response = await stream.aget_response()
|
|
67
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
68
|
+
|
|
69
|
+
if response.response and "72" in response.response:
|
|
70
|
+
# Verify we captured tool-related events
|
|
71
|
+
tool_call_events = [
|
|
72
|
+
event
|
|
73
|
+
for event in self.captured_events
|
|
74
|
+
if event["status_type"] == AgentStatusType.TOOL_CALL
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
tool_output_events = [
|
|
78
|
+
event
|
|
79
|
+
for event in self.captured_events
|
|
80
|
+
if event["status_type"] == AgentStatusType.TOOL_OUTPUT
|
|
81
|
+
]
|
|
82
|
+
|
|
83
|
+
# Should have at least one tool call and one tool output
|
|
84
|
+
self.assertGreater(
|
|
85
|
+
len(tool_call_events), 0, "Should capture tool call events"
|
|
86
|
+
)
|
|
87
|
+
self.assertGreater(
|
|
88
|
+
len(tool_output_events), 0, "Should capture tool output events"
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# Verify tool call event structure
|
|
92
|
+
if tool_call_events:
|
|
93
|
+
tool_call = tool_call_events[0]
|
|
94
|
+
self.assertIn("tool_name", tool_call["msg"])
|
|
95
|
+
self.assertIn("arguments", tool_call["msg"])
|
|
96
|
+
self.assertIsNotNone(tool_call["event_id"])
|
|
97
|
+
|
|
98
|
+
# Verify tool output event structure
|
|
99
|
+
if tool_output_events:
|
|
100
|
+
tool_output = tool_output_events[0]
|
|
101
|
+
self.assertIn("tool_name", tool_output["msg"])
|
|
102
|
+
self.assertIn("content", tool_output["msg"])
|
|
103
|
+
self.assertIsNotNone(tool_output["event_id"])
|
|
104
|
+
|
|
105
|
+
async def test_react_workflow_multi_step_events(self):
|
|
106
|
+
"""Test ReAct workflow events for multi-step reasoning tasks."""
|
|
107
|
+
self.captured_events.clear()
|
|
108
|
+
|
|
109
|
+
agent = Agent(
|
|
110
|
+
agent_config=react_config_anthropic,
|
|
111
|
+
tools=self.tools,
|
|
112
|
+
topic=self.topic,
|
|
113
|
+
custom_instructions=self.instructions,
|
|
114
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
with self.with_provider_fallback("Anthropic"):
|
|
118
|
+
stream = await agent.astream_chat(
|
|
119
|
+
"First multiply 7 by 6, then add 15 to that result."
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
# Consume the stream
|
|
123
|
+
async for chunk in stream.async_response_gen():
|
|
124
|
+
pass
|
|
125
|
+
|
|
126
|
+
response = await stream.aget_response()
|
|
127
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
128
|
+
|
|
129
|
+
# Should be (7*6)+15 = 42+15 = 57
|
|
130
|
+
if response.response and "57" in response.response:
|
|
131
|
+
# Should have multiple tool calls (multiplication and addition)
|
|
132
|
+
tool_call_events = [
|
|
133
|
+
event
|
|
134
|
+
for event in self.captured_events
|
|
135
|
+
if event["status_type"] == AgentStatusType.TOOL_CALL
|
|
136
|
+
]
|
|
137
|
+
|
|
138
|
+
# Should have at least 2 tool calls (mult and add)
|
|
139
|
+
self.assertGreaterEqual(
|
|
140
|
+
len(tool_call_events),
|
|
141
|
+
1,
|
|
142
|
+
"Should have tool call events for multi-step task",
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# Verify event IDs are present for events that have them
|
|
146
|
+
# With simplified logic, events without proper IDs are skipped
|
|
147
|
+
events_with_ids = [
|
|
148
|
+
event
|
|
149
|
+
for event in self.captured_events
|
|
150
|
+
if event["event_id"]
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
# At least some events should have IDs
|
|
154
|
+
self.assertGreater(
|
|
155
|
+
len(events_with_ids), 0, "Should have events with proper IDs"
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
async def test_react_workflow_agent_update_events(self):
|
|
159
|
+
"""Test that ReAct workflow generates agent update events."""
|
|
160
|
+
self.captured_events.clear()
|
|
161
|
+
|
|
162
|
+
agent = Agent(
|
|
163
|
+
agent_config=react_config_anthropic,
|
|
164
|
+
tools=self.tools,
|
|
165
|
+
topic=self.topic,
|
|
166
|
+
custom_instructions=self.instructions,
|
|
167
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
with self.with_provider_fallback("Anthropic"):
|
|
171
|
+
stream = await agent.astream_chat("Calculate 5 times 11.")
|
|
172
|
+
|
|
173
|
+
# Consume the stream
|
|
174
|
+
async for chunk in stream.async_response_gen():
|
|
175
|
+
pass
|
|
176
|
+
|
|
177
|
+
response = await stream.aget_response()
|
|
178
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
179
|
+
|
|
180
|
+
if response.response and "55" in response.response:
|
|
181
|
+
# Look for agent update events
|
|
182
|
+
agent_update_events = [
|
|
183
|
+
event
|
|
184
|
+
for event in self.captured_events
|
|
185
|
+
if event["status_type"] == AgentStatusType.AGENT_UPDATE
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
# ReAct agents should generate some agent update events during workflow
|
|
189
|
+
self.assertGreaterEqual(
|
|
190
|
+
len(agent_update_events),
|
|
191
|
+
0,
|
|
192
|
+
"ReAct workflow should generate agent update events",
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Verify structure of agent update events
|
|
196
|
+
for event in agent_update_events:
|
|
197
|
+
self.assertIn("content", event["msg"])
|
|
198
|
+
self.assertIsInstance(event["msg"]["content"], str)
|
|
199
|
+
|
|
200
|
+
async def test_react_workflow_event_ordering(self):
|
|
201
|
+
"""Test that ReAct workflow events are generated in correct order."""
|
|
202
|
+
self.captured_events.clear()
|
|
203
|
+
|
|
204
|
+
agent = Agent(
|
|
205
|
+
agent_config=react_config_anthropic,
|
|
206
|
+
tools=self.tools,
|
|
207
|
+
topic=self.topic,
|
|
208
|
+
custom_instructions=self.instructions,
|
|
209
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
with self.with_provider_fallback("Anthropic"):
|
|
213
|
+
stream = await agent.astream_chat("Multiply 9 by 4.")
|
|
214
|
+
|
|
215
|
+
# Consume the stream
|
|
216
|
+
async for chunk in stream.async_response_gen():
|
|
217
|
+
pass
|
|
218
|
+
|
|
219
|
+
response = await stream.aget_response()
|
|
220
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
221
|
+
|
|
222
|
+
if response.response and "36" in response.response:
|
|
223
|
+
# Find tool call and tool output events
|
|
224
|
+
tool_events = [
|
|
225
|
+
event
|
|
226
|
+
for event in self.captured_events
|
|
227
|
+
if event["status_type"]
|
|
228
|
+
in [AgentStatusType.TOOL_CALL, AgentStatusType.TOOL_OUTPUT]
|
|
229
|
+
]
|
|
230
|
+
|
|
231
|
+
if len(tool_events) >= 2:
|
|
232
|
+
# Group events by event_id to match calls with outputs
|
|
233
|
+
event_groups = {}
|
|
234
|
+
for event in tool_events:
|
|
235
|
+
event_id = event["event_id"]
|
|
236
|
+
if event_id not in event_groups:
|
|
237
|
+
event_groups[event_id] = []
|
|
238
|
+
event_groups[event_id].append(event)
|
|
239
|
+
|
|
240
|
+
# For each event group, tool call should come before tool output
|
|
241
|
+
for event_id, events in event_groups.items():
|
|
242
|
+
if len(events) >= 2:
|
|
243
|
+
call_events = [
|
|
244
|
+
e
|
|
245
|
+
for e in events
|
|
246
|
+
if e["status_type"] == AgentStatusType.TOOL_CALL
|
|
247
|
+
]
|
|
248
|
+
output_events = [
|
|
249
|
+
e
|
|
250
|
+
for e in events
|
|
251
|
+
if e["status_type"] == AgentStatusType.TOOL_OUTPUT
|
|
252
|
+
]
|
|
253
|
+
|
|
254
|
+
if call_events and output_events:
|
|
255
|
+
# Find indices in original event list
|
|
256
|
+
call_index = self.captured_events.index(call_events[0])
|
|
257
|
+
output_index = self.captured_events.index(
|
|
258
|
+
output_events[0]
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
self.assertLess(
|
|
262
|
+
call_index,
|
|
263
|
+
output_index,
|
|
264
|
+
"Tool call should come before tool output",
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
async def test_react_workflow_event_error_handling(self):
|
|
268
|
+
"""Test ReAct workflow event handling when tools fail."""
|
|
269
|
+
self.captured_events.clear()
|
|
270
|
+
|
|
271
|
+
def failing_tool(x: float) -> float:
|
|
272
|
+
"""A tool that fails with certain inputs."""
|
|
273
|
+
if x == 0:
|
|
274
|
+
raise ValueError("Cannot process zero")
|
|
275
|
+
return x * 10
|
|
276
|
+
|
|
277
|
+
error_tools = [ToolsFactory().create_tool(failing_tool)]
|
|
278
|
+
|
|
279
|
+
agent = Agent(
|
|
280
|
+
agent_config=react_config_anthropic,
|
|
281
|
+
tools=error_tools,
|
|
282
|
+
topic=self.topic,
|
|
283
|
+
custom_instructions=self.instructions,
|
|
284
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
with self.with_provider_fallback("Anthropic"):
|
|
288
|
+
stream = await agent.astream_chat("Use failing_tool with input 0.")
|
|
289
|
+
|
|
290
|
+
# Consume the stream
|
|
291
|
+
async for chunk in stream.async_response_gen():
|
|
292
|
+
pass
|
|
293
|
+
|
|
294
|
+
response = await stream.aget_response()
|
|
295
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
296
|
+
|
|
297
|
+
# Even with tool errors, we should still capture events
|
|
298
|
+
self.assertGreater(
|
|
299
|
+
len(self.captured_events),
|
|
300
|
+
0,
|
|
301
|
+
"Should capture events even when tools fail",
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
# Look for tool call events
|
|
305
|
+
tool_call_events = [
|
|
306
|
+
event
|
|
307
|
+
for event in self.captured_events
|
|
308
|
+
if event["status_type"] == AgentStatusType.TOOL_CALL
|
|
309
|
+
]
|
|
310
|
+
|
|
311
|
+
self.assertGreater(
|
|
312
|
+
len(tool_call_events),
|
|
313
|
+
0,
|
|
314
|
+
"Should capture tool call events even when tools fail",
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
async def test_react_workflow_event_callback_error_resilience(self):
|
|
318
|
+
"""Test that ReAct workflow continues even if progress callback raises errors."""
|
|
319
|
+
|
|
320
|
+
def failing_callback(
|
|
321
|
+
status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
|
|
322
|
+
):
|
|
323
|
+
"""A callback that always raises an error."""
|
|
324
|
+
raise RuntimeError("Callback error")
|
|
325
|
+
|
|
326
|
+
agent = Agent(
|
|
327
|
+
agent_config=react_config_anthropic,
|
|
328
|
+
tools=self.tools,
|
|
329
|
+
topic=self.topic,
|
|
330
|
+
custom_instructions=self.instructions,
|
|
331
|
+
agent_progress_callback=failing_callback,
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
with self.with_provider_fallback("Anthropic"):
|
|
335
|
+
# Even with failing callback, agent should still work
|
|
336
|
+
stream = await agent.astream_chat("Calculate 12 times 3.")
|
|
337
|
+
|
|
338
|
+
# Consume the stream
|
|
339
|
+
async for chunk in stream.async_response_gen():
|
|
340
|
+
pass
|
|
341
|
+
|
|
342
|
+
response = await stream.aget_response()
|
|
343
|
+
self.check_response_and_skip(response, "Anthropic")
|
|
344
|
+
|
|
345
|
+
if response.response and "36" in response.response:
|
|
346
|
+
# Test passed - agent worked despite callback failures
|
|
347
|
+
self.assertTrue(True)
|
|
348
|
+
|
|
349
|
+
async def test_react_workflow_event_consistency_across_providers(self):
|
|
350
|
+
"""Test that ReAct workflow events are consistent across different providers."""
|
|
351
|
+
providers_to_test = [
|
|
352
|
+
("Anthropic", react_config_anthropic),
|
|
353
|
+
("Gemini", react_config_gemini),
|
|
354
|
+
("Together AI", react_config_together),
|
|
355
|
+
]
|
|
356
|
+
|
|
357
|
+
for provider_name, config in providers_to_test:
|
|
358
|
+
with self.subTest(provider=provider_name):
|
|
359
|
+
self.captured_events.clear()
|
|
360
|
+
|
|
361
|
+
agent = Agent(
|
|
362
|
+
agent_config=config,
|
|
363
|
+
tools=self.tools,
|
|
364
|
+
topic=self.topic,
|
|
365
|
+
custom_instructions=self.instructions,
|
|
366
|
+
agent_progress_callback=self.capture_progress_callback,
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
with self.with_provider_fallback(provider_name):
|
|
370
|
+
stream = await agent.astream_chat("Calculate 6 times 8.")
|
|
371
|
+
|
|
372
|
+
# Consume the stream
|
|
373
|
+
async for chunk in stream.async_response_gen():
|
|
374
|
+
pass
|
|
375
|
+
|
|
376
|
+
response = await stream.aget_response()
|
|
377
|
+
self.check_response_and_skip(response, provider_name)
|
|
378
|
+
|
|
379
|
+
if response.response and "48" in response.response:
|
|
380
|
+
# Should have some events
|
|
381
|
+
self.assertGreater(
|
|
382
|
+
len(self.captured_events),
|
|
383
|
+
0,
|
|
384
|
+
f"{provider_name} should generate events",
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
# All events should have proper structure
|
|
388
|
+
for event in self.captured_events:
|
|
389
|
+
self.assertIn("status_type", event)
|
|
390
|
+
self.assertIn("msg", event)
|
|
391
|
+
self.assertIsInstance(event["msg"], dict)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
if __name__ == "__main__":
|
|
395
|
+
unittest.main()
|
tests/test_return_direct.py
CHANGED
tests/test_serialization.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -19,12 +20,13 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
|
19
20
|
|
|
20
21
|
ARIZE_LOCK = threading.Lock()
|
|
21
22
|
|
|
23
|
+
|
|
22
24
|
class TestAgentSerialization(unittest.TestCase):
|
|
23
25
|
|
|
24
26
|
@classmethod
|
|
25
27
|
def tearDown(cls):
|
|
26
28
|
try:
|
|
27
|
-
os.remove(
|
|
29
|
+
os.remove("ev_database.db")
|
|
28
30
|
except FileNotFoundError:
|
|
29
31
|
pass
|
|
30
32
|
|
|
@@ -34,20 +36,24 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
34
36
|
agent_type=AgentType.REACT,
|
|
35
37
|
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
36
38
|
tool_llm_provider=ModelProvider.TOGETHER,
|
|
37
|
-
observer=ObserverType.ARIZE_PHOENIX
|
|
39
|
+
observer=ObserverType.ARIZE_PHOENIX,
|
|
38
40
|
)
|
|
39
41
|
db_tools = ToolsFactory().database_tools(
|
|
40
|
-
tool_name_prefix
|
|
41
|
-
content_description
|
|
42
|
-
sql_database
|
|
42
|
+
tool_name_prefix="ev",
|
|
43
|
+
content_description="Electric Vehicles in the state of Washington and other population information",
|
|
44
|
+
sql_database=SQLDatabase(create_engine("sqlite:///ev_database.db")),
|
|
43
45
|
)
|
|
44
46
|
|
|
45
|
-
tools =
|
|
47
|
+
tools = (
|
|
48
|
+
[ToolsFactory().create_tool(mult)]
|
|
49
|
+
+ ToolsFactory().standard_tools()
|
|
50
|
+
+ db_tools
|
|
51
|
+
)
|
|
46
52
|
agent = Agent(
|
|
47
53
|
tools=tools,
|
|
48
54
|
topic=STANDARD_TEST_TOPIC,
|
|
49
55
|
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
50
|
-
agent_config=config
|
|
56
|
+
agent_config=config,
|
|
51
57
|
)
|
|
52
58
|
|
|
53
59
|
agent_reloaded = agent.loads(agent.dumps())
|
|
@@ -57,17 +63,33 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
57
63
|
self.assertEqual(agent, agent_reloaded)
|
|
58
64
|
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
59
65
|
|
|
60
|
-
self.assertEqual(
|
|
61
|
-
|
|
62
|
-
|
|
66
|
+
self.assertEqual(
|
|
67
|
+
agent.agent_config.observer, agent_reloaded.agent_config.observer
|
|
68
|
+
)
|
|
69
|
+
self.assertEqual(
|
|
70
|
+
agent.agent_config.main_llm_provider,
|
|
71
|
+
agent_reloaded.agent_config.main_llm_provider,
|
|
72
|
+
)
|
|
73
|
+
self.assertEqual(
|
|
74
|
+
agent.agent_config.tool_llm_provider,
|
|
75
|
+
agent_reloaded.agent_config.tool_llm_provider,
|
|
76
|
+
)
|
|
63
77
|
|
|
64
78
|
self.assertIsInstance(agent_reloaded, Agent)
|
|
65
79
|
self.assertEqual(agent, agent_reloaded_again)
|
|
66
80
|
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
67
81
|
|
|
68
|
-
self.assertEqual(
|
|
69
|
-
|
|
70
|
-
|
|
82
|
+
self.assertEqual(
|
|
83
|
+
agent.agent_config.observer, agent_reloaded_again.agent_config.observer
|
|
84
|
+
)
|
|
85
|
+
self.assertEqual(
|
|
86
|
+
agent.agent_config.main_llm_provider,
|
|
87
|
+
agent_reloaded_again.agent_config.main_llm_provider,
|
|
88
|
+
)
|
|
89
|
+
self.assertEqual(
|
|
90
|
+
agent.agent_config.tool_llm_provider,
|
|
91
|
+
agent_reloaded_again.agent_config.tool_llm_provider,
|
|
92
|
+
)
|
|
71
93
|
|
|
72
94
|
def test_serialization_from_corpus(self):
|
|
73
95
|
with ARIZE_LOCK:
|
|
@@ -75,7 +97,7 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
75
97
|
agent_type=AgentType.REACT,
|
|
76
98
|
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
77
99
|
tool_llm_provider=ModelProvider.TOGETHER,
|
|
78
|
-
observer=ObserverType.ARIZE_PHOENIX
|
|
100
|
+
observer=ObserverType.ARIZE_PHOENIX,
|
|
79
101
|
)
|
|
80
102
|
|
|
81
103
|
agent = Agent.from_corpus(
|
|
@@ -94,17 +116,33 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
94
116
|
self.assertEqual(agent, agent_reloaded)
|
|
95
117
|
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
96
118
|
|
|
97
|
-
self.assertEqual(
|
|
98
|
-
|
|
99
|
-
|
|
119
|
+
self.assertEqual(
|
|
120
|
+
agent.agent_config.observer, agent_reloaded.agent_config.observer
|
|
121
|
+
)
|
|
122
|
+
self.assertEqual(
|
|
123
|
+
agent.agent_config.main_llm_provider,
|
|
124
|
+
agent_reloaded.agent_config.main_llm_provider,
|
|
125
|
+
)
|
|
126
|
+
self.assertEqual(
|
|
127
|
+
agent.agent_config.tool_llm_provider,
|
|
128
|
+
agent_reloaded.agent_config.tool_llm_provider,
|
|
129
|
+
)
|
|
100
130
|
|
|
101
131
|
self.assertIsInstance(agent_reloaded, Agent)
|
|
102
132
|
self.assertEqual(agent, agent_reloaded_again)
|
|
103
133
|
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
104
134
|
|
|
105
|
-
self.assertEqual(
|
|
106
|
-
|
|
107
|
-
|
|
135
|
+
self.assertEqual(
|
|
136
|
+
agent.agent_config.observer, agent_reloaded_again.agent_config.observer
|
|
137
|
+
)
|
|
138
|
+
self.assertEqual(
|
|
139
|
+
agent.agent_config.main_llm_provider,
|
|
140
|
+
agent_reloaded_again.agent_config.main_llm_provider,
|
|
141
|
+
)
|
|
142
|
+
self.assertEqual(
|
|
143
|
+
agent.agent_config.tool_llm_provider,
|
|
144
|
+
agent_reloaded_again.agent_config.tool_llm_provider,
|
|
145
|
+
)
|
|
108
146
|
|
|
109
147
|
|
|
110
148
|
if __name__ == "__main__":
|
tests/test_together.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -9,13 +10,20 @@ from vectara_agentic.agent import Agent
|
|
|
9
10
|
from vectara_agentic.tools import ToolsFactory
|
|
10
11
|
|
|
11
12
|
import nest_asyncio
|
|
13
|
+
|
|
12
14
|
nest_asyncio.apply()
|
|
13
15
|
|
|
14
|
-
from conftest import
|
|
16
|
+
from conftest import (
|
|
17
|
+
fc_config_together,
|
|
18
|
+
mult,
|
|
19
|
+
STANDARD_TEST_TOPIC,
|
|
20
|
+
STANDARD_TEST_INSTRUCTIONS,
|
|
21
|
+
)
|
|
15
22
|
|
|
16
23
|
|
|
17
24
|
ARIZE_LOCK = threading.Lock()
|
|
18
25
|
|
|
26
|
+
|
|
19
27
|
class TestTogether(unittest.IsolatedAsyncioTestCase):
|
|
20
28
|
|
|
21
29
|
async def test_multiturn(self):
|
tests/test_tools.py
CHANGED
|
@@ -361,7 +361,9 @@ class TestToolsPackage(unittest.TestCase):
|
|
|
361
361
|
"query (str): The search query to perform, in the form of a question", doc
|
|
362
362
|
)
|
|
363
363
|
self.assertIn("foo (int): how many foos (e.g., 1, 2, 3)", doc)
|
|
364
|
-
self.assertIn(
|
|
364
|
+
self.assertIn(
|
|
365
|
+
"bar (str | None, default='baz'): what bar to use (e.g., 'x', 'y')", doc
|
|
366
|
+
)
|
|
365
367
|
self.assertIn("Returns:", doc)
|
|
366
368
|
self.assertIn("dict[str, Any]: A dictionary containing the result data.", doc)
|
|
367
369
|
|
tests/test_vectara_llms.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -20,8 +21,7 @@ class TestLLMPackage(unittest.TestCase):
|
|
|
20
21
|
|
|
21
22
|
def test_vectara_openai(self):
|
|
22
23
|
vec_factory = VectaraToolFactory(
|
|
23
|
-
vectara_corpus_key=vectara_corpus_key,
|
|
24
|
-
vectara_api_key=vectara_api_key
|
|
24
|
+
vectara_corpus_key=vectara_corpus_key, vectara_api_key=vectara_api_key
|
|
25
25
|
)
|
|
26
26
|
|
|
27
27
|
self.assertEqual(vectara_corpus_key, vec_factory.vectara_corpus_key)
|
tests/test_vhc.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -10,6 +11,7 @@ from vectara_agentic.tools import ToolsFactory
|
|
|
10
11
|
from vectara_agentic.types import ModelProvider
|
|
11
12
|
|
|
12
13
|
import nest_asyncio
|
|
14
|
+
|
|
13
15
|
nest_asyncio.apply()
|
|
14
16
|
|
|
15
17
|
statements = [
|
|
@@ -20,6 +22,8 @@ statements = [
|
|
|
20
22
|
"Chocolate is the best ice cream flavor.",
|
|
21
23
|
]
|
|
22
24
|
st_inx = 0
|
|
25
|
+
|
|
26
|
+
|
|
23
27
|
def get_statement() -> str:
|
|
24
28
|
"Generate next statement"
|
|
25
29
|
global st_inx
|
|
@@ -34,7 +38,8 @@ fc_config = AgentConfig(
|
|
|
34
38
|
tool_llm_provider=ModelProvider.OPENAI,
|
|
35
39
|
)
|
|
36
40
|
|
|
37
|
-
vectara_api_key =
|
|
41
|
+
vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
|
|
42
|
+
|
|
38
43
|
|
|
39
44
|
class TestVHC(unittest.TestCase):
|
|
40
45
|
|
|
@@ -59,7 +64,7 @@ class TestVHC(unittest.TestCase):
|
|
|
59
64
|
vhc_corrections = vhc_res.get("corrections", [])
|
|
60
65
|
self.assertTrue(
|
|
61
66
|
len(vhc_corrections) >= 0 and len(vhc_corrections) <= 2,
|
|
62
|
-
"Corrections should be between 0 and 2"
|
|
67
|
+
"Corrections should be between 0 and 2",
|
|
63
68
|
)
|
|
64
69
|
|
|
65
70
|
|