vectara-agentic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

Files changed (43)
  1. tests/__init__.py +1 -0
  2. tests/benchmark_models.py +547 -372
  3. tests/conftest.py +14 -12
  4. tests/endpoint.py +9 -5
  5. tests/run_tests.py +1 -0
  6. tests/test_agent.py +22 -9
  7. tests/test_agent_fallback_memory.py +4 -4
  8. tests/test_agent_memory_consistency.py +4 -4
  9. tests/test_agent_type.py +2 -0
  10. tests/test_api_endpoint.py +13 -13
  11. tests/test_bedrock.py +9 -1
  12. tests/test_fallback.py +18 -7
  13. tests/test_gemini.py +14 -40
  14. tests/test_groq.py +43 -1
  15. tests/test_openai.py +160 -0
  16. tests/test_private_llm.py +19 -6
  17. tests/test_react_error_handling.py +293 -0
  18. tests/test_react_memory.py +257 -0
  19. tests/test_react_streaming.py +135 -0
  20. tests/test_react_workflow_events.py +395 -0
  21. tests/test_return_direct.py +1 -0
  22. tests/test_serialization.py +58 -20
  23. tests/test_session_memory.py +11 -11
  24. tests/test_streaming.py +0 -44
  25. tests/test_together.py +75 -1
  26. tests/test_tools.py +3 -1
  27. tests/test_vectara_llms.py +2 -2
  28. tests/test_vhc.py +7 -2
  29. tests/test_workflow.py +17 -11
  30. vectara_agentic/_callback.py +79 -21
  31. vectara_agentic/_version.py +1 -1
  32. vectara_agentic/agent.py +65 -27
  33. vectara_agentic/agent_core/serialization.py +5 -9
  34. vectara_agentic/agent_core/streaming.py +245 -64
  35. vectara_agentic/agent_core/utils/schemas.py +2 -2
  36. vectara_agentic/llm_utils.py +64 -15
  37. vectara_agentic/tools.py +88 -31
  38. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/METADATA +133 -36
  39. vectara_agentic-0.4.4.dist-info/RECORD +59 -0
  40. vectara_agentic-0.4.2.dist-info/RECORD +0 -54
  41. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/WHEEL +0 -0
  42. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/licenses/LICENSE +0 -0
  43. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,395 @@
1
+ # Suppress external dependency warnings before any other imports
2
+ import warnings
3
+
4
+ warnings.simplefilter("ignore", DeprecationWarning)
5
+
6
+ import unittest
7
+ from typing import Dict, Any
8
+
9
+ from vectara_agentic.agent import Agent, AgentStatusType
10
+ from vectara_agentic.tools import ToolsFactory
11
+
12
+ import nest_asyncio
13
+
14
+ nest_asyncio.apply()
15
+
16
+ from conftest import (
17
+ AgentTestMixin,
18
+ react_config_anthropic,
19
+ react_config_gemini,
20
+ react_config_together,
21
+ mult,
22
+ add,
23
+ STANDARD_TEST_TOPIC,
24
+ STANDARD_TEST_INSTRUCTIONS,
25
+ )
26
+
27
+
28
class TestReActWorkflowEvents(unittest.IsolatedAsyncioTestCase, AgentTestMixin):
    """Test workflow event handling and streaming for ReAct agents.

    Each test builds an ``Agent`` whose ``agent_progress_callback`` appends a
    ``{"status_type", "msg", "event_id"}`` record to ``self.captured_events``,
    drains a streaming chat to completion, and then inspects those records.

    Most event assertions only run when the final answer contains the expected
    number, so an incorrect LLM answer does not by itself fail these tests.
    """

    def setUp(self):
        self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
        self.topic = STANDARD_TEST_TOPIC
        self.instructions = STANDARD_TEST_INSTRUCTIONS
        # setUp runs before every test method, so each test starts empty.
        self.captured_events = []

    def capture_progress_callback(
        self, status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
    ):
        """Capture agent progress events for testing."""
        self.captured_events.append(
            {
                "status_type": status_type,
                "msg": msg,
                "event_id": event_id,
            }
        )

    # ------------------------------------------------------------------ helpers

    def _build_react_agent(self, config=None, tools=None, progress_callback=None):
        """Build an Agent using this test's topic and instructions.

        Args:
            config: Agent config; defaults to ``react_config_anthropic``.
            tools: Tool list; defaults to ``self.tools`` (mult and add).
            progress_callback: Progress callback; defaults to
                ``self.capture_progress_callback``.

        Returns:
            The constructed ``Agent``.
        """
        return Agent(
            agent_config=react_config_anthropic if config is None else config,
            tools=self.tools if tools is None else tools,
            topic=self.topic,
            custom_instructions=self.instructions,
            agent_progress_callback=(
                self.capture_progress_callback
                if progress_callback is None
                else progress_callback
            ),
        )

    async def _run_streaming_chat(self, agent, prompt, provider_name="Anthropic"):
        """Stream ``prompt`` through ``agent``, drain it, and return the response.

        The whole exchange runs inside ``with_provider_fallback(provider_name)``
        and finishes with ``check_response_and_skip``, both from
        ``AgentTestMixin``.
        """
        with self.with_provider_fallback(provider_name):
            stream = await agent.astream_chat(prompt)
            # Drain the stream; individual chunks are not inspected.
            async for _ in stream.async_response_gen():
                pass
            response = await stream.aget_response()
            self.check_response_and_skip(response, provider_name)
        return response

    def _captured_of_type(self, *status_types):
        """Return captured events whose status type is in ``status_types``, in order."""
        return [
            event
            for event in self.captured_events
            if event["status_type"] in status_types
        ]

    # -------------------------------------------------------------------- tests

    async def test_react_workflow_tool_call_events(self):
        """Test that ReAct workflow generates proper tool call events."""
        agent = self._build_react_agent()
        response = await self._run_streaming_chat(agent, "Calculate 8 times 9.")

        if not (response.response and "72" in response.response):
            return

        tool_call_events = self._captured_of_type(AgentStatusType.TOOL_CALL)
        tool_output_events = self._captured_of_type(AgentStatusType.TOOL_OUTPUT)

        # Should have at least one tool call and one tool output
        self.assertGreater(len(tool_call_events), 0, "Should capture tool call events")
        self.assertGreater(
            len(tool_output_events), 0, "Should capture tool output events"
        )

        # Both lists are non-empty at this point, so index 0 always exists.
        tool_call = tool_call_events[0]
        self.assertIn("tool_name", tool_call["msg"])
        self.assertIn("arguments", tool_call["msg"])
        self.assertIsNotNone(tool_call["event_id"])

        tool_output = tool_output_events[0]
        self.assertIn("tool_name", tool_output["msg"])
        self.assertIn("content", tool_output["msg"])
        self.assertIsNotNone(tool_output["event_id"])

    async def test_react_workflow_multi_step_events(self):
        """Test ReAct workflow events for multi-step reasoning tasks."""
        agent = self._build_react_agent()
        response = await self._run_streaming_chat(
            agent, "First multiply 7 by 6, then add 15 to that result."
        )

        # Should be (7*6)+15 = 42+15 = 57
        if not (response.response and "57" in response.response):
            return

        # The task implies two calls (mult, then add), but only one is required.
        tool_call_events = self._captured_of_type(AgentStatusType.TOOL_CALL)
        self.assertGreaterEqual(
            len(tool_call_events),
            1,
            "Should have tool call events for multi-step task",
        )

        # Events without an ID are tolerated, but at least some must carry one.
        events_with_ids = [e for e in self.captured_events if e["event_id"]]
        self.assertGreater(
            len(events_with_ids), 0, "Should have events with proper IDs"
        )

    async def test_react_workflow_agent_update_events(self):
        """Test that any agent update events from the ReAct workflow are well-formed."""
        agent = self._build_react_agent()
        response = await self._run_streaming_chat(agent, "Calculate 5 times 11.")

        if not (response.response and "55" in response.response):
            return

        agent_update_events = self._captured_of_type(AgentStatusType.AGENT_UPDATE)

        # The number of AGENT_UPDATE events is not asserted; only the shape of
        # any emitted events is checked.
        # TODO(review): tighten to assertGreater(..., 0) once ReAct is confirmed
        # to always emit AGENT_UPDATE events.
        for event in agent_update_events:
            self.assertIn("content", event["msg"])
            self.assertIsInstance(event["msg"]["content"], str)

    async def test_react_workflow_event_ordering(self):
        """Test that ReAct workflow events are generated in correct order."""
        agent = self._build_react_agent()
        response = await self._run_streaming_chat(agent, "Multiply 9 by 4.")

        if not (response.response and "36" in response.response):
            return

        # Position in self.captured_events of the first TOOL_CALL and the first
        # TOOL_OUTPUT for each event_id. Events lacking an ID cannot be paired
        # with their counterpart, so they are skipped rather than lumped together.
        first_call_pos = {}
        first_output_pos = {}
        for position, event in enumerate(self.captured_events):
            event_id = event["event_id"]
            if not event_id:
                continue
            if event["status_type"] == AgentStatusType.TOOL_CALL:
                first_call_pos.setdefault(event_id, position)
            elif event["status_type"] == AgentStatusType.TOOL_OUTPUT:
                first_output_pos.setdefault(event_id, position)

        # For each event_id seen as both a call and an output, the call must
        # come first.
        for event_id, call_pos in first_call_pos.items():
            if event_id not in first_output_pos:
                continue
            self.assertLess(
                call_pos,
                first_output_pos[event_id],
                f"Tool call should come before tool output (event_id={event_id!r})",
            )

    async def test_react_workflow_event_error_handling(self):
        """Test ReAct workflow event handling when tools fail."""

        def failing_tool(x: float) -> float:
            """A tool that fails with certain inputs."""
            if x == 0:
                raise ValueError("Cannot process zero")
            return x * 10

        agent = self._build_react_agent(
            tools=[ToolsFactory().create_tool(failing_tool)]
        )
        await self._run_streaming_chat(agent, "Use failing_tool with input 0.")

        # Even with tool errors, we should still capture events
        self.assertGreater(
            len(self.captured_events),
            0,
            "Should capture events even when tools fail",
        )

        tool_call_events = self._captured_of_type(AgentStatusType.TOOL_CALL)
        self.assertGreater(
            len(tool_call_events),
            0,
            "Should capture tool call events even when tools fail",
        )

    async def test_react_workflow_event_callback_error_resilience(self):
        """Test that ReAct workflow continues even if progress callback raises errors."""

        def failing_callback(
            status_type: AgentStatusType, msg: Dict[str, Any], event_id: str = None
        ):
            """A callback that always raises an error."""
            raise RuntimeError("Callback error")

        agent = self._build_react_agent(progress_callback=failing_callback)

        # The check here is that the streaming chat runs to completion despite
        # the callback raising on every event; no assertion on the answer.
        await self._run_streaming_chat(agent, "Calculate 12 times 3.")

    async def test_react_workflow_event_consistency_across_providers(self):
        """Test that ReAct workflow events are consistent across different providers."""
        providers_to_test = [
            ("Anthropic", react_config_anthropic),
            ("Gemini", react_config_gemini),
            ("Together AI", react_config_together),
        ]

        for provider_name, config in providers_to_test:
            with self.subTest(provider=provider_name):
                # Subtests share this test's state, so reset between providers.
                self.captured_events.clear()

                agent = self._build_react_agent(config=config)
                response = await self._run_streaming_chat(
                    agent, "Calculate 6 times 8.", provider_name
                )

                if not (response.response and "48" in response.response):
                    continue

                self.assertGreater(
                    len(self.captured_events),
                    0,
                    f"{provider_name} should generate events",
                )

                # All events should have proper structure
                for event in self.captured_events:
                    self.assertIn("status_type", event)
                    self.assertIn("msg", event)
                    self.assertIsInstance(event["msg"], dict)


if __name__ == "__main__":
    unittest.main()
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -19,12 +20,13 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
19
20
 
20
21
  ARIZE_LOCK = threading.Lock()
21
22
 
23
+
22
24
  class TestAgentSerialization(unittest.TestCase):
23
25
 
24
26
  @classmethod
25
27
  def tearDown(cls):
26
28
  try:
27
- os.remove('ev_database.db')
29
+ os.remove("ev_database.db")
28
30
  except FileNotFoundError:
29
31
  pass
30
32
 
@@ -34,20 +36,24 @@ class TestAgentSerialization(unittest.TestCase):
34
36
  agent_type=AgentType.REACT,
35
37
  main_llm_provider=ModelProvider.ANTHROPIC,
36
38
  tool_llm_provider=ModelProvider.TOGETHER,
37
- observer=ObserverType.ARIZE_PHOENIX
39
+ observer=ObserverType.ARIZE_PHOENIX,
38
40
  )
39
41
  db_tools = ToolsFactory().database_tools(
40
- tool_name_prefix = "ev",
41
- content_description = 'Electric Vehicles in the state of Washington and other population information',
42
- sql_database = SQLDatabase(create_engine('sqlite:///ev_database.db')),
42
+ tool_name_prefix="ev",
43
+ content_description="Electric Vehicles in the state of Washington and other population information",
44
+ sql_database=SQLDatabase(create_engine("sqlite:///ev_database.db")),
43
45
  )
44
46
 
45
- tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
47
+ tools = (
48
+ [ToolsFactory().create_tool(mult)]
49
+ + ToolsFactory().standard_tools()
50
+ + db_tools
51
+ )
46
52
  agent = Agent(
47
53
  tools=tools,
48
54
  topic=STANDARD_TEST_TOPIC,
49
55
  custom_instructions=STANDARD_TEST_INSTRUCTIONS,
50
- agent_config=config
56
+ agent_config=config,
51
57
  )
52
58
 
53
59
  agent_reloaded = agent.loads(agent.dumps())
@@ -57,17 +63,33 @@ class TestAgentSerialization(unittest.TestCase):
57
63
  self.assertEqual(agent, agent_reloaded)
58
64
  self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
59
65
 
60
- self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
61
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
62
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
66
+ self.assertEqual(
67
+ agent.agent_config.observer, agent_reloaded.agent_config.observer
68
+ )
69
+ self.assertEqual(
70
+ agent.agent_config.main_llm_provider,
71
+ agent_reloaded.agent_config.main_llm_provider,
72
+ )
73
+ self.assertEqual(
74
+ agent.agent_config.tool_llm_provider,
75
+ agent_reloaded.agent_config.tool_llm_provider,
76
+ )
63
77
 
64
78
  self.assertIsInstance(agent_reloaded, Agent)
65
79
  self.assertEqual(agent, agent_reloaded_again)
66
80
  self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
67
81
 
68
- self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
69
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
70
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
82
+ self.assertEqual(
83
+ agent.agent_config.observer, agent_reloaded_again.agent_config.observer
84
+ )
85
+ self.assertEqual(
86
+ agent.agent_config.main_llm_provider,
87
+ agent_reloaded_again.agent_config.main_llm_provider,
88
+ )
89
+ self.assertEqual(
90
+ agent.agent_config.tool_llm_provider,
91
+ agent_reloaded_again.agent_config.tool_llm_provider,
92
+ )
71
93
 
72
94
  def test_serialization_from_corpus(self):
73
95
  with ARIZE_LOCK:
@@ -75,7 +97,7 @@ class TestAgentSerialization(unittest.TestCase):
75
97
  agent_type=AgentType.REACT,
76
98
  main_llm_provider=ModelProvider.ANTHROPIC,
77
99
  tool_llm_provider=ModelProvider.TOGETHER,
78
- observer=ObserverType.ARIZE_PHOENIX
100
+ observer=ObserverType.ARIZE_PHOENIX,
79
101
  )
80
102
 
81
103
  agent = Agent.from_corpus(
@@ -94,17 +116,33 @@ class TestAgentSerialization(unittest.TestCase):
94
116
  self.assertEqual(agent, agent_reloaded)
95
117
  self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
96
118
 
97
- self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
98
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
99
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
119
+ self.assertEqual(
120
+ agent.agent_config.observer, agent_reloaded.agent_config.observer
121
+ )
122
+ self.assertEqual(
123
+ agent.agent_config.main_llm_provider,
124
+ agent_reloaded.agent_config.main_llm_provider,
125
+ )
126
+ self.assertEqual(
127
+ agent.agent_config.tool_llm_provider,
128
+ agent_reloaded.agent_config.tool_llm_provider,
129
+ )
100
130
 
101
131
  self.assertIsInstance(agent_reloaded, Agent)
102
132
  self.assertEqual(agent, agent_reloaded_again)
103
133
  self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
104
134
 
105
- self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
106
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
107
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
135
+ self.assertEqual(
136
+ agent.agent_config.observer, agent_reloaded_again.agent_config.observer
137
+ )
138
+ self.assertEqual(
139
+ agent.agent_config.main_llm_provider,
140
+ agent_reloaded_again.agent_config.main_llm_provider,
141
+ )
142
+ self.assertEqual(
143
+ agent.agent_config.tool_llm_provider,
144
+ agent_reloaded_again.agent_config.tool_llm_provider,
145
+ )
108
146
 
109
147
 
110
148
  if __name__ == "__main__":
@@ -43,8 +43,8 @@ class TestSessionMemoryManagement(unittest.TestCase):
43
43
  # Verify the agent uses the provided session_id
44
44
  self.assertEqual(agent.session_id, custom_session_id)
45
45
 
46
- # Verify memory uses the same session_id (via chat_store_key)
47
- self.assertEqual(agent.memory.chat_store_key, custom_session_id)
46
+ # Verify memory uses the same session_id
47
+ self.assertEqual(agent.memory.session_id, custom_session_id)
48
48
 
49
49
  def test_agent_init_without_session_id(self):
50
50
  """Test Agent initialization without session_id (auto-generation)"""
@@ -59,8 +59,8 @@ class TestSessionMemoryManagement(unittest.TestCase):
59
59
  expected_pattern = f"{self.topic}:{date.today().isoformat()}"
60
60
  self.assertEqual(agent.session_id, expected_pattern)
61
61
 
62
- # Verify memory uses the same session_id (via chat_store_key)
63
- self.assertEqual(agent.memory.chat_store_key, expected_pattern)
62
+ # Verify memory uses the same session_id
63
+ self.assertEqual(agent.memory.session_id, expected_pattern)
64
64
 
65
65
  def test_from_tools_with_session_id(self):
66
66
  """Test Agent.from_tools() with custom session_id"""
@@ -76,7 +76,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
76
76
 
77
77
  # Verify the agent uses the provided session_id
78
78
  self.assertEqual(agent.session_id, custom_session_id)
79
- self.assertEqual(agent.memory.chat_store_key, custom_session_id)
79
+ self.assertEqual(agent.memory.session_id, custom_session_id)
80
80
 
81
81
  def test_from_tools_without_session_id(self):
82
82
  """Test Agent.from_tools() without session_id (auto-generation)"""
@@ -90,7 +90,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
90
90
  # Verify auto-generated session_id
91
91
  expected_pattern = f"{self.topic}:{date.today().isoformat()}"
92
92
  self.assertEqual(agent.session_id, expected_pattern)
93
- self.assertEqual(agent.memory.chat_store_key, expected_pattern)
93
+ self.assertEqual(agent.memory.session_id, expected_pattern)
94
94
 
95
95
  def test_session_id_consistency_across_agents(self):
96
96
  """Test that agents with same session_id have consistent session_id attributes"""
@@ -118,9 +118,9 @@ class TestSessionMemoryManagement(unittest.TestCase):
118
118
  self.assertEqual(agent2.session_id, shared_session_id)
119
119
  self.assertEqual(agent1.session_id, agent2.session_id)
120
120
 
121
- # Verify their memory instances also have the correct session_id (via chat_store_key)
122
- self.assertEqual(agent1.memory.chat_store_key, shared_session_id)
123
- self.assertEqual(agent2.memory.chat_store_key, shared_session_id)
121
+ # Verify their memory instances also have the correct session_id
122
+ self.assertEqual(agent1.memory.session_id, shared_session_id)
123
+ self.assertEqual(agent2.memory.session_id, shared_session_id)
124
124
 
125
125
  # Note: Each agent gets its own Memory instance (this is expected behavior)
126
126
  # In production, memory persistence happens through serialization/deserialization
@@ -204,7 +204,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
204
204
 
205
205
  # Verify session_id is preserved
206
206
  self.assertEqual(restored_agent.session_id, custom_session_id)
207
- self.assertEqual(restored_agent.memory.chat_store_key, custom_session_id)
207
+ self.assertEqual(restored_agent.memory.session_id, custom_session_id)
208
208
 
209
209
  # Verify memory is preserved
210
210
  restored_messages = restored_agent.memory.get()
@@ -231,7 +231,7 @@ class TestSessionMemoryManagement(unittest.TestCase):
231
231
 
232
232
  # Verify session_id is correct
233
233
  self.assertEqual(agent.session_id, custom_session_id)
234
- self.assertEqual(agent.memory.chat_store_key, custom_session_id)
234
+ self.assertEqual(agent.memory.session_id, custom_session_id)
235
235
 
236
236
  # Verify chat history was loaded into memory
237
237
  messages = agent.memory.get()
tests/test_streaming.py CHANGED
@@ -4,7 +4,6 @@ import warnings
4
4
  warnings.simplefilter("ignore", DeprecationWarning)
5
5
 
6
6
  import unittest
7
- import asyncio
8
7
 
9
8
  from vectara_agentic.agent import Agent
10
9
  from vectara_agentic.tools import ToolsFactory
@@ -14,7 +13,6 @@ import nest_asyncio
14
13
  nest_asyncio.apply()
15
14
 
16
15
  from conftest import (
17
- fc_config_openai,
18
16
  fc_config_anthropic,
19
17
  mult,
20
18
  STANDARD_TEST_TOPIC,
@@ -62,48 +60,6 @@ class TestAgentStreaming(unittest.IsolatedAsyncioTestCase):
62
60
 
63
61
  self.assertIn("1050", response3.response)
64
62
 
65
- async def test_openai(self):
66
- tools = [ToolsFactory().create_tool(mult)]
67
- agent = Agent(
68
- agent_config=fc_config_openai,
69
- tools=tools,
70
- topic=STANDARD_TEST_TOPIC,
71
- custom_instructions=STANDARD_TEST_INSTRUCTIONS,
72
- )
73
-
74
- # First calculation: 5 * 10 = 50
75
- stream1 = await agent.astream_chat(
76
- "What is 5 times 10. Only give the answer, nothing else"
77
- )
78
- # Consume the stream
79
- async for chunk in stream1.async_response_gen():
80
- pass
81
- _ = await stream1.aget_response()
82
-
83
- # Second calculation: 3 * 7 = 21
84
- stream2 = await agent.astream_chat(
85
- "what is 3 times 7. Only give the answer, nothing else"
86
- )
87
- # Consume the stream
88
- async for chunk in stream2.async_response_gen():
89
- pass
90
- _ = await stream2.aget_response()
91
-
92
- # Final calculation: 50 * 21 = 1050
93
- stream3 = await agent.astream_chat(
94
- "multiply the results of the last two multiplications. Only give the answer, nothing else."
95
- )
96
- # Consume the stream
97
- async for chunk in stream3.async_response_gen():
98
- pass
99
- response3 = await stream3.aget_response()
100
-
101
- self.assertIn("1050", response3.response)
102
-
103
- def test_openai_sync(self):
104
- """Synchronous wrapper for the async test"""
105
- asyncio.run(self.test_openai())
106
-
107
63
 
108
64
  if __name__ == "__main__":
109
65
  unittest.main()