vectara-agentic: vectara_agentic-0.4.1-py3-none-any.whl → vectara_agentic-0.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. tests/__init__.py +1 -0
  2. tests/benchmark_models.py +1120 -0
  3. tests/conftest.py +18 -16
  4. tests/endpoint.py +9 -5
  5. tests/run_tests.py +3 -0
  6. tests/test_agent.py +52 -8
  7. tests/test_agent_type.py +2 -0
  8. tests/test_api_endpoint.py +13 -13
  9. tests/test_bedrock.py +9 -1
  10. tests/test_fallback.py +19 -8
  11. tests/test_gemini.py +14 -40
  12. tests/test_groq.py +9 -1
  13. tests/test_private_llm.py +20 -7
  14. tests/test_react_error_handling.py +293 -0
  15. tests/test_react_memory.py +257 -0
  16. tests/test_react_streaming.py +135 -0
  17. tests/test_react_workflow_events.py +395 -0
  18. tests/test_return_direct.py +1 -0
  19. tests/test_serialization.py +58 -20
  20. tests/test_together.py +9 -1
  21. tests/test_tools.py +3 -1
  22. tests/test_vectara_llms.py +2 -2
  23. tests/test_vhc.py +7 -2
  24. tests/test_workflow.py +17 -11
  25. vectara_agentic/_callback.py +79 -21
  26. vectara_agentic/_observability.py +19 -0
  27. vectara_agentic/_version.py +1 -1
  28. vectara_agentic/agent.py +89 -21
  29. vectara_agentic/agent_core/factory.py +5 -6
  30. vectara_agentic/agent_core/prompts.py +3 -4
  31. vectara_agentic/agent_core/serialization.py +12 -10
  32. vectara_agentic/agent_core/streaming.py +245 -68
  33. vectara_agentic/agent_core/utils/schemas.py +2 -2
  34. vectara_agentic/llm_utils.py +6 -2
  35. vectara_agentic/sub_query_workflow.py +3 -2
  36. vectara_agentic/tools.py +0 -19
  37. {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/METADATA +156 -61
  38. vectara_agentic-0.4.3.dist-info/RECORD +58 -0
  39. vectara_agentic-0.4.1.dist-info/RECORD +0 -53
  40. {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/WHEEL +0 -0
  41. {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/licenses/LICENSE +0 -0
  42. {vectara_agentic-0.4.1.dist-info → vectara_agentic-0.4.3.dist-info}/top_level.txt +0 -0
tests/test_react_workflow_events.py ADDED
@@ -0,0 +1,395 @@
1
+ # Suppress external dependency warnings before any other imports
2
+ import warnings
3
+
4
+ warnings.simplefilter("ignore", DeprecationWarning)
5
+
6
+ import unittest
7
+ from typing import Dict, Any
8
+
9
+ from vectara_agentic.agent import Agent, AgentStatusType
10
+ from vectara_agentic.tools import ToolsFactory
11
+
12
+ import nest_asyncio
13
+
14
+ nest_asyncio.apply()
15
+
16
+ from conftest import (
17
+ AgentTestMixin,
18
+ react_config_anthropic,
19
+ react_config_gemini,
20
+ react_config_together,
21
+ mult,
22
+ add,
23
+ STANDARD_TEST_TOPIC,
24
+ STANDARD_TEST_INSTRUCTIONS,
25
+ )
26
+
27
+
28
+ class TestReActWorkflowEvents(unittest.IsolatedAsyncioTestCase, AgentTestMixin):
29
+ """Test workflow event handling and streaming for ReAct agents."""
30
+
31
+ def setUp(self):
32
+ self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
33
+ self.topic = STANDARD_TEST_TOPIC
34
+ self.instructions = STANDARD_TEST_INSTRUCTIONS
35
+ self.captured_events = []
36
+
37
+ def capture_progress_callback(
38
+ self, status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
39
+ ):
40
+ """Capture agent progress events for testing."""
41
+ self.captured_events.append(
42
+ {
43
+ "status_type": status_type,
44
+ "msg": msg,
45
+ "event_id": event_id,
46
+ }
47
+ )
48
+
49
+ async def test_react_workflow_tool_call_events(self):
50
+ """Test that ReAct workflow generates proper tool call events."""
51
+ agent = Agent(
52
+ agent_config=react_config_anthropic,
53
+ tools=self.tools,
54
+ topic=self.topic,
55
+ custom_instructions=self.instructions,
56
+ agent_progress_callback=self.capture_progress_callback,
57
+ )
58
+
59
+ with self.with_provider_fallback("Anthropic"):
60
+ stream = await agent.astream_chat("Calculate 8 times 9.")
61
+
62
+ # Consume the stream
63
+ async for chunk in stream.async_response_gen():
64
+ pass
65
+
66
+ response = await stream.aget_response()
67
+ self.check_response_and_skip(response, "Anthropic")
68
+
69
+ if response.response and "72" in response.response:
70
+ # Verify we captured tool-related events
71
+ tool_call_events = [
72
+ event
73
+ for event in self.captured_events
74
+ if event["status_type"] == AgentStatusType.TOOL_CALL
75
+ ]
76
+
77
+ tool_output_events = [
78
+ event
79
+ for event in self.captured_events
80
+ if event["status_type"] == AgentStatusType.TOOL_OUTPUT
81
+ ]
82
+
83
+ # Should have at least one tool call and one tool output
84
+ self.assertGreater(
85
+ len(tool_call_events), 0, "Should capture tool call events"
86
+ )
87
+ self.assertGreater(
88
+ len(tool_output_events), 0, "Should capture tool output events"
89
+ )
90
+
91
+ # Verify tool call event structure
92
+ if tool_call_events:
93
+ tool_call = tool_call_events[0]
94
+ self.assertIn("tool_name", tool_call["msg"])
95
+ self.assertIn("arguments", tool_call["msg"])
96
+ self.assertIsNotNone(tool_call["event_id"])
97
+
98
+ # Verify tool output event structure
99
+ if tool_output_events:
100
+ tool_output = tool_output_events[0]
101
+ self.assertIn("tool_name", tool_output["msg"])
102
+ self.assertIn("content", tool_output["msg"])
103
+ self.assertIsNotNone(tool_output["event_id"])
104
+
105
+ async def test_react_workflow_multi_step_events(self):
106
+ """Test ReAct workflow events for multi-step reasoning tasks."""
107
+ self.captured_events.clear()
108
+
109
+ agent = Agent(
110
+ agent_config=react_config_anthropic,
111
+ tools=self.tools,
112
+ topic=self.topic,
113
+ custom_instructions=self.instructions,
114
+ agent_progress_callback=self.capture_progress_callback,
115
+ )
116
+
117
+ with self.with_provider_fallback("Anthropic"):
118
+ stream = await agent.astream_chat(
119
+ "First multiply 7 by 6, then add 15 to that result."
120
+ )
121
+
122
+ # Consume the stream
123
+ async for chunk in stream.async_response_gen():
124
+ pass
125
+
126
+ response = await stream.aget_response()
127
+ self.check_response_and_skip(response, "Anthropic")
128
+
129
+ # Should be (7*6)+15 = 42+15 = 57
130
+ if response.response and "57" in response.response:
131
+ # Should have multiple tool calls (multiplication and addition)
132
+ tool_call_events = [
133
+ event
134
+ for event in self.captured_events
135
+ if event["status_type"] == AgentStatusType.TOOL_CALL
136
+ ]
137
+
138
+ # Ideally two tool calls (mult, then add), but only one is required by the assertion below
139
+ self.assertGreaterEqual(
140
+ len(tool_call_events),
141
+ 1,
142
+ "Should have tool call events for multi-step task",
143
+ )
144
+
145
+ # Verify event IDs are present for events that have them
146
+ # With simplified logic, events without proper IDs are skipped
147
+ events_with_ids = [
148
+ event
149
+ for event in self.captured_events
150
+ if event["event_id"]
151
+ ]
152
+
153
+ # At least some events should have IDs
154
+ self.assertGreater(
155
+ len(events_with_ids), 0, "Should have events with proper IDs"
156
+ )
157
+
158
+ async def test_react_workflow_agent_update_events(self):
159
+ """Test that ReAct workflow generates agent update events."""
160
+ self.captured_events.clear()
161
+
162
+ agent = Agent(
163
+ agent_config=react_config_anthropic,
164
+ tools=self.tools,
165
+ topic=self.topic,
166
+ custom_instructions=self.instructions,
167
+ agent_progress_callback=self.capture_progress_callback,
168
+ )
169
+
170
+ with self.with_provider_fallback("Anthropic"):
171
+ stream = await agent.astream_chat("Calculate 5 times 11.")
172
+
173
+ # Consume the stream
174
+ async for chunk in stream.async_response_gen():
175
+ pass
176
+
177
+ response = await stream.aget_response()
178
+ self.check_response_and_skip(response, "Anthropic")
179
+
180
+ if response.response and "55" in response.response:
181
+ # Look for agent update events
182
+ agent_update_events = [
183
+ event
184
+ for event in self.captured_events
185
+ if event["status_type"] == AgentStatusType.AGENT_UPDATE
186
+ ]
187
+
188
+ # Agent update events are optional: the >= 0 check below always passes, so only the structure of any events found is verified
189
+ self.assertGreaterEqual(
190
+ len(agent_update_events),
191
+ 0,
192
+ "ReAct workflow should generate agent update events",
193
+ )
194
+
195
+ # Verify structure of agent update events
196
+ for event in agent_update_events:
197
+ self.assertIn("content", event["msg"])
198
+ self.assertIsInstance(event["msg"]["content"], str)
199
+
200
+ async def test_react_workflow_event_ordering(self):
201
+ """Test that ReAct workflow events are generated in correct order."""
202
+ self.captured_events.clear()
203
+
204
+ agent = Agent(
205
+ agent_config=react_config_anthropic,
206
+ tools=self.tools,
207
+ topic=self.topic,
208
+ custom_instructions=self.instructions,
209
+ agent_progress_callback=self.capture_progress_callback,
210
+ )
211
+
212
+ with self.with_provider_fallback("Anthropic"):
213
+ stream = await agent.astream_chat("Multiply 9 by 4.")
214
+
215
+ # Consume the stream
216
+ async for chunk in stream.async_response_gen():
217
+ pass
218
+
219
+ response = await stream.aget_response()
220
+ self.check_response_and_skip(response, "Anthropic")
221
+
222
+ if response.response and "36" in response.response:
223
+ # Find tool call and tool output events
224
+ tool_events = [
225
+ event
226
+ for event in self.captured_events
227
+ if event["status_type"]
228
+ in [AgentStatusType.TOOL_CALL, AgentStatusType.TOOL_OUTPUT]
229
+ ]
230
+
231
+ if len(tool_events) >= 2:
232
+ # Group events by event_id to match calls with outputs
233
+ event_groups = {}
234
+ for event in tool_events:
235
+ event_id = event["event_id"]
236
+ if event_id not in event_groups:
237
+ event_groups[event_id] = []
238
+ event_groups[event_id].append(event)
239
+
240
+ # For each event group, tool call should come before tool output
241
+ for event_id, events in event_groups.items():
242
+ if len(events) >= 2:
243
+ call_events = [
244
+ e
245
+ for e in events
246
+ if e["status_type"] == AgentStatusType.TOOL_CALL
247
+ ]
248
+ output_events = [
249
+ e
250
+ for e in events
251
+ if e["status_type"] == AgentStatusType.TOOL_OUTPUT
252
+ ]
253
+
254
+ if call_events and output_events:
255
+ # Find indices in original event list
256
+ call_index = self.captured_events.index(call_events[0])
257
+ output_index = self.captured_events.index(
258
+ output_events[0]
259
+ )
260
+
261
+ self.assertLess(
262
+ call_index,
263
+ output_index,
264
+ "Tool call should come before tool output",
265
+ )
266
+
267
+ async def test_react_workflow_event_error_handling(self):
268
+ """Test ReAct workflow event handling when tools fail."""
269
+ self.captured_events.clear()
270
+
271
+ def failing_tool(x: float) -> float:
272
+ """A tool that fails with certain inputs."""
273
+ if x == 0:
274
+ raise ValueError("Cannot process zero")
275
+ return x * 10
276
+
277
+ error_tools = [ToolsFactory().create_tool(failing_tool)]
278
+
279
+ agent = Agent(
280
+ agent_config=react_config_anthropic,
281
+ tools=error_tools,
282
+ topic=self.topic,
283
+ custom_instructions=self.instructions,
284
+ agent_progress_callback=self.capture_progress_callback,
285
+ )
286
+
287
+ with self.with_provider_fallback("Anthropic"):
288
+ stream = await agent.astream_chat("Use failing_tool with input 0.")
289
+
290
+ # Consume the stream
291
+ async for chunk in stream.async_response_gen():
292
+ pass
293
+
294
+ response = await stream.aget_response()
295
+ self.check_response_and_skip(response, "Anthropic")
296
+
297
+ # Even with tool errors, we should still capture events
298
+ self.assertGreater(
299
+ len(self.captured_events),
300
+ 0,
301
+ "Should capture events even when tools fail",
302
+ )
303
+
304
+ # Look for tool call events
305
+ tool_call_events = [
306
+ event
307
+ for event in self.captured_events
308
+ if event["status_type"] == AgentStatusType.TOOL_CALL
309
+ ]
310
+
311
+ self.assertGreater(
312
+ len(tool_call_events),
313
+ 0,
314
+ "Should capture tool call events even when tools fail",
315
+ )
316
+
317
+ async def test_react_workflow_event_callback_error_resilience(self):
318
+ """Test that ReAct workflow continues even if progress callback raises errors."""
319
+
320
+ def failing_callback(
321
+ status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
322
+ ):
323
+ """A callback that always raises an error."""
324
+ raise RuntimeError("Callback error")
325
+
326
+ agent = Agent(
327
+ agent_config=react_config_anthropic,
328
+ tools=self.tools,
329
+ topic=self.topic,
330
+ custom_instructions=self.instructions,
331
+ agent_progress_callback=failing_callback,
332
+ )
333
+
334
+ with self.with_provider_fallback("Anthropic"):
335
+ # Even with failing callback, agent should still work
336
+ stream = await agent.astream_chat("Calculate 12 times 3.")
337
+
338
+ # Consume the stream
339
+ async for chunk in stream.async_response_gen():
340
+ pass
341
+
342
+ response = await stream.aget_response()
343
+ self.check_response_and_skip(response, "Anthropic")
344
+
345
+ if response.response and "36" in response.response:
346
+ # Test passed - agent worked despite callback failures
347
+ self.assertTrue(True)
348
+
349
+ async def test_react_workflow_event_consistency_across_providers(self):
350
+ """Test that ReAct workflow events are consistent across different providers."""
351
+ providers_to_test = [
352
+ ("Anthropic", react_config_anthropic),
353
+ ("Gemini", react_config_gemini),
354
+ ("Together AI", react_config_together),
355
+ ]
356
+
357
+ for provider_name, config in providers_to_test:
358
+ with self.subTest(provider=provider_name):
359
+ self.captured_events.clear()
360
+
361
+ agent = Agent(
362
+ agent_config=config,
363
+ tools=self.tools,
364
+ topic=self.topic,
365
+ custom_instructions=self.instructions,
366
+ agent_progress_callback=self.capture_progress_callback,
367
+ )
368
+
369
+ with self.with_provider_fallback(provider_name):
370
+ stream = await agent.astream_chat("Calculate 6 times 8.")
371
+
372
+ # Consume the stream
373
+ async for chunk in stream.async_response_gen():
374
+ pass
375
+
376
+ response = await stream.aget_response()
377
+ self.check_response_and_skip(response, provider_name)
378
+
379
+ if response.response and "48" in response.response:
380
+ # Should have some events
381
+ self.assertGreater(
382
+ len(self.captured_events),
383
+ 0,
384
+ f"{provider_name} should generate events",
385
+ )
386
+
387
+ # All events should have proper structure
388
+ for event in self.captured_events:
389
+ self.assertIn("status_type", event)
390
+ self.assertIn("msg", event)
391
+ self.assertIsInstance(event["msg"], dict)
392
+
393
+
394
+ if __name__ == "__main__":
395
+ unittest.main()
tests/test_return_direct.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
tests/test_serialization.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -19,12 +20,13 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
19
20
 
20
21
  ARIZE_LOCK = threading.Lock()
21
22
 
23
+
22
24
  class TestAgentSerialization(unittest.TestCase):
23
25
 
24
26
  @classmethod
25
27
  def tearDown(cls):
26
28
  try:
27
- os.remove('ev_database.db')
29
+ os.remove("ev_database.db")
28
30
  except FileNotFoundError:
29
31
  pass
30
32
 
@@ -34,20 +36,24 @@ class TestAgentSerialization(unittest.TestCase):
34
36
  agent_type=AgentType.REACT,
35
37
  main_llm_provider=ModelProvider.ANTHROPIC,
36
38
  tool_llm_provider=ModelProvider.TOGETHER,
37
- observer=ObserverType.ARIZE_PHOENIX
39
+ observer=ObserverType.ARIZE_PHOENIX,
38
40
  )
39
41
  db_tools = ToolsFactory().database_tools(
40
- tool_name_prefix = "ev",
41
- content_description = 'Electric Vehicles in the state of Washington and other population information',
42
- sql_database = SQLDatabase(create_engine('sqlite:///ev_database.db')),
42
+ tool_name_prefix="ev",
43
+ content_description="Electric Vehicles in the state of Washington and other population information",
44
+ sql_database=SQLDatabase(create_engine("sqlite:///ev_database.db")),
43
45
  )
44
46
 
45
- tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
47
+ tools = (
48
+ [ToolsFactory().create_tool(mult)]
49
+ + ToolsFactory().standard_tools()
50
+ + db_tools
51
+ )
46
52
  agent = Agent(
47
53
  tools=tools,
48
54
  topic=STANDARD_TEST_TOPIC,
49
55
  custom_instructions=STANDARD_TEST_INSTRUCTIONS,
50
- agent_config=config
56
+ agent_config=config,
51
57
  )
52
58
 
53
59
  agent_reloaded = agent.loads(agent.dumps())
@@ -57,17 +63,33 @@ class TestAgentSerialization(unittest.TestCase):
57
63
  self.assertEqual(agent, agent_reloaded)
58
64
  self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
59
65
 
60
- self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
61
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
62
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
66
+ self.assertEqual(
67
+ agent.agent_config.observer, agent_reloaded.agent_config.observer
68
+ )
69
+ self.assertEqual(
70
+ agent.agent_config.main_llm_provider,
71
+ agent_reloaded.agent_config.main_llm_provider,
72
+ )
73
+ self.assertEqual(
74
+ agent.agent_config.tool_llm_provider,
75
+ agent_reloaded.agent_config.tool_llm_provider,
76
+ )
63
77
 
64
78
  self.assertIsInstance(agent_reloaded, Agent)
65
79
  self.assertEqual(agent, agent_reloaded_again)
66
80
  self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
67
81
 
68
- self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
69
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
70
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
82
+ self.assertEqual(
83
+ agent.agent_config.observer, agent_reloaded_again.agent_config.observer
84
+ )
85
+ self.assertEqual(
86
+ agent.agent_config.main_llm_provider,
87
+ agent_reloaded_again.agent_config.main_llm_provider,
88
+ )
89
+ self.assertEqual(
90
+ agent.agent_config.tool_llm_provider,
91
+ agent_reloaded_again.agent_config.tool_llm_provider,
92
+ )
71
93
 
72
94
  def test_serialization_from_corpus(self):
73
95
  with ARIZE_LOCK:
@@ -75,7 +97,7 @@ class TestAgentSerialization(unittest.TestCase):
75
97
  agent_type=AgentType.REACT,
76
98
  main_llm_provider=ModelProvider.ANTHROPIC,
77
99
  tool_llm_provider=ModelProvider.TOGETHER,
78
- observer=ObserverType.ARIZE_PHOENIX
100
+ observer=ObserverType.ARIZE_PHOENIX,
79
101
  )
80
102
 
81
103
  agent = Agent.from_corpus(
@@ -94,17 +116,33 @@ class TestAgentSerialization(unittest.TestCase):
94
116
  self.assertEqual(agent, agent_reloaded)
95
117
  self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
96
118
 
97
- self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
98
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
99
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
119
+ self.assertEqual(
120
+ agent.agent_config.observer, agent_reloaded.agent_config.observer
121
+ )
122
+ self.assertEqual(
123
+ agent.agent_config.main_llm_provider,
124
+ agent_reloaded.agent_config.main_llm_provider,
125
+ )
126
+ self.assertEqual(
127
+ agent.agent_config.tool_llm_provider,
128
+ agent_reloaded.agent_config.tool_llm_provider,
129
+ )
100
130
 
101
131
  self.assertIsInstance(agent_reloaded, Agent)
102
132
  self.assertEqual(agent, agent_reloaded_again)
103
133
  self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
104
134
 
105
- self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
106
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
107
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
135
+ self.assertEqual(
136
+ agent.agent_config.observer, agent_reloaded_again.agent_config.observer
137
+ )
138
+ self.assertEqual(
139
+ agent.agent_config.main_llm_provider,
140
+ agent_reloaded_again.agent_config.main_llm_provider,
141
+ )
142
+ self.assertEqual(
143
+ agent.agent_config.tool_llm_provider,
144
+ agent_reloaded_again.agent_config.tool_llm_provider,
145
+ )
108
146
 
109
147
 
110
148
  if __name__ == "__main__":
tests/test_together.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -9,13 +10,20 @@ from vectara_agentic.agent import Agent
9
10
  from vectara_agentic.tools import ToolsFactory
10
11
 
11
12
  import nest_asyncio
13
+
12
14
  nest_asyncio.apply()
13
15
 
14
- from conftest import fc_config_together, mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
16
+ from conftest import (
17
+ fc_config_together,
18
+ mult,
19
+ STANDARD_TEST_TOPIC,
20
+ STANDARD_TEST_INSTRUCTIONS,
21
+ )
15
22
 
16
23
 
17
24
  ARIZE_LOCK = threading.Lock()
18
25
 
26
+
19
27
  class TestTogether(unittest.IsolatedAsyncioTestCase):
20
28
 
21
29
  async def test_multiturn(self):
tests/test_tools.py CHANGED
@@ -361,7 +361,9 @@ class TestToolsPackage(unittest.TestCase):
361
361
  "query (str): The search query to perform, in the form of a question", doc
362
362
  )
363
363
  self.assertIn("foo (int): how many foos (e.g., 1, 2, 3)", doc)
364
- self.assertIn("bar (str | None, default='baz'): what bar to use (e.g., 'x', 'y')", doc)
364
+ self.assertIn(
365
+ "bar (str | None, default='baz'): what bar to use (e.g., 'x', 'y')", doc
366
+ )
365
367
  self.assertIn("Returns:", doc)
366
368
  self.assertIn("dict[str, Any]: A dictionary containing the result data.", doc)
367
369
 
tests/test_vectara_llms.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -20,8 +21,7 @@ class TestLLMPackage(unittest.TestCase):
20
21
 
21
22
  def test_vectara_openai(self):
22
23
  vec_factory = VectaraToolFactory(
23
- vectara_corpus_key=vectara_corpus_key,
24
- vectara_api_key=vectara_api_key
24
+ vectara_corpus_key=vectara_corpus_key, vectara_api_key=vectara_api_key
25
25
  )
26
26
 
27
27
  self.assertEqual(vectara_corpus_key, vec_factory.vectara_corpus_key)
tests/test_vhc.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -10,6 +11,7 @@ from vectara_agentic.tools import ToolsFactory
10
11
  from vectara_agentic.types import ModelProvider
11
12
 
12
13
  import nest_asyncio
14
+
13
15
  nest_asyncio.apply()
14
16
 
15
17
  statements = [
@@ -20,6 +22,8 @@ statements = [
20
22
  "Chocolate is the best ice cream flavor.",
21
23
  ]
22
24
  st_inx = 0
25
+
26
+
23
27
  def get_statement() -> str:
24
28
  "Generate next statement"
25
29
  global st_inx
@@ -34,7 +38,8 @@ fc_config = AgentConfig(
34
38
  tool_llm_provider=ModelProvider.OPENAI,
35
39
  )
36
40
 
37
- vectara_api_key = 'zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA'
41
+ vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
42
+
38
43
 
39
44
  class TestVHC(unittest.TestCase):
40
45
 
@@ -59,7 +64,7 @@ class TestVHC(unittest.TestCase):
59
64
  vhc_corrections = vhc_res.get("corrections", [])
60
65
  self.assertTrue(
61
66
  len(vhc_corrections) >= 0 and len(vhc_corrections) <= 2,
62
- "Corrections should be between 0 and 2"
67
+ "Corrections should be between 0 and 2",
63
68
  )
64
69
 
65
70