vectara-agentic 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. (The original page linked to further details here; that link was not preserved in this copy.)

tests/test_bedrock.py CHANGED
@@ -3,43 +3,58 @@ import warnings
3
3
  warnings.simplefilter("ignore", DeprecationWarning)
4
4
 
5
5
  import unittest
6
+ import threading
6
7
 
7
- from vectara_agentic.agent import Agent, AgentType
8
- from vectara_agentic.agent_config import AgentConfig
8
+ from vectara_agentic.agent import Agent
9
9
  from vectara_agentic.tools import ToolsFactory
10
- from vectara_agentic.types import ModelProvider
11
10
 
12
11
  import nest_asyncio
13
12
  nest_asyncio.apply()
14
13
 
15
-
16
- def mult(x: float, y: float) -> float:
17
- "Multiply two numbers"
18
- return x * y
19
-
20
-
21
- fc_config_bedrock = AgentConfig(
22
- agent_type=AgentType.FUNCTION_CALLING,
23
- main_llm_provider=ModelProvider.BEDROCK,
24
- tool_llm_provider=ModelProvider.BEDROCK,
25
- )
26
-
27
- class TestBedrock(unittest.TestCase):
28
-
29
- def test_multiturn(self):
30
- tools = [ToolsFactory().create_tool(mult)]
31
- topic = "AI topic"
32
- instructions = "Always do as your father tells you, if your mother agrees!"
33
- agent = Agent(
34
- tools=tools,
35
- topic=topic,
36
- custom_instructions=instructions,
37
- )
38
-
39
- agent.chat("What is 5 times 10. Only give the answer, nothing else")
40
- agent.chat("what is 3 times 7. Only give the answer, nothing else")
41
- res = agent.chat("multiply the results of the last two questions. Output only the answer.")
42
- self.assertEqual(res.response, "1050")
14
+ from conftest import mult, fc_config_bedrock, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
15
+
16
+ ARIZE_LOCK = threading.Lock()
17
+
18
+ class TestBedrock(unittest.IsolatedAsyncioTestCase):
19
+
20
+ async def test_multiturn(self):
21
+ with ARIZE_LOCK:
22
+ tools = [ToolsFactory().create_tool(mult)]
23
+ agent = Agent(
24
+ tools=tools,
25
+ topic=STANDARD_TEST_TOPIC,
26
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
27
+ agent_config=fc_config_bedrock,
28
+ )
29
+
30
+ # First calculation: 5 * 10 = 50
31
+ stream1 = await agent.astream_chat(
32
+ "What is 5 times 10. Only give the answer, nothing else"
33
+ )
34
+ # Consume the stream
35
+ async for chunk in stream1.async_response_gen():
36
+ pass
37
+ _ = await stream1.aget_response()
38
+
39
+ # Second calculation: 3 * 7 = 21
40
+ stream2 = await agent.astream_chat(
41
+ "what is 3 times 7. Only give the answer, nothing else"
42
+ )
43
+ # Consume the stream
44
+ async for chunk in stream2.async_response_gen():
45
+ pass
46
+ _ = await stream2.aget_response()
47
+
48
+ # Final calculation: 50 * 21 = 1050
49
+ stream3 = await agent.astream_chat(
50
+ "multiply the results of the last two questions. Output only the answer."
51
+ )
52
+ # Consume the stream
53
+ async for chunk in stream3.async_response_gen():
54
+ pass
55
+ response3 = await stream3.aget_response()
56
+
57
+ self.assertEqual(response3.response, "1050")
43
58
 
44
59
 
45
60
  if __name__ == "__main__":
tests/test_gemini.py CHANGED
@@ -4,15 +4,15 @@ warnings.simplefilter("ignore", DeprecationWarning)
4
4
 
5
5
  import unittest
6
6
 
7
- from vectara_agentic.agent import Agent, AgentType
8
- from vectara_agentic.agent_config import AgentConfig
9
- from vectara_agentic.types import ModelProvider
7
+ from vectara_agentic.agent import Agent
10
8
  from vectara_agentic.tools import ToolsFactory
11
9
 
12
10
 
13
11
  import nest_asyncio
14
12
  nest_asyncio.apply()
15
13
 
14
+ from conftest import mult, fc_config_gemini, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
15
+
16
16
  tickers = {
17
17
  "C": "Citigroup",
18
18
  "COF": "Capital One",
@@ -34,11 +34,6 @@ tickers = {
34
34
  years = list(range(2015, 2025))
35
35
 
36
36
 
37
- def mult(x: float, y: float) -> float:
38
- "Multiply two numbers"
39
- return x * y
40
-
41
-
42
37
  def get_company_info() -> list[str]:
43
38
  """
44
39
  Returns a dictionary of companies you can query about. Always check this before using any other tool.
@@ -56,23 +51,15 @@ def get_valid_years() -> list[str]:
56
51
  return years
57
52
 
58
53
 
59
- fc_config_gemini = AgentConfig(
60
- agent_type=AgentType.FUNCTION_CALLING,
61
- main_llm_provider=ModelProvider.GEMINI,
62
- tool_llm_provider=ModelProvider.GEMINI,
63
- )
64
-
65
54
  class TestGEMINI(unittest.TestCase):
66
55
  def test_gemini(self):
67
56
  tools = [ToolsFactory().create_tool(mult)]
68
- topic = "AI topic"
69
- instructions = "Always do as your father tells you, if your mother agrees!"
70
57
 
71
58
  agent = Agent(
72
59
  agent_config=fc_config_gemini,
73
60
  tools=tools,
74
- topic=topic,
75
- custom_instructions=instructions,
61
+ topic=STANDARD_TEST_TOPIC,
62
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
76
63
  )
77
64
  _ = agent.chat("What is 5 times 10. Only give the answer, nothing else")
78
65
  _ = agent.chat("what is 3 times 7. Only give the answer, nothing else")
@@ -81,14 +68,12 @@ class TestGEMINI(unittest.TestCase):
81
68
 
82
69
  def test_gemini_single_prompt(self):
83
70
  tools = [ToolsFactory().create_tool(mult)]
84
- topic = "AI topic"
85
- instructions = "Always do as your father tells you, if your mother agrees!"
86
71
 
87
72
  agent = Agent(
88
73
  agent_config=fc_config_gemini,
89
74
  tools=tools,
90
- topic=topic,
91
- custom_instructions=instructions,
75
+ topic=STANDARD_TEST_TOPIC,
76
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
92
77
  )
93
78
  res = agent.chat("First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations.")
94
79
  self.assertIn("1050", res.response)
tests/test_groq.py CHANGED
@@ -3,43 +3,58 @@ import warnings
3
3
  warnings.simplefilter("ignore", DeprecationWarning)
4
4
 
5
5
  import unittest
6
+ import threading
6
7
 
7
- from vectara_agentic.agent import Agent, AgentType
8
- from vectara_agentic.agent_config import AgentConfig
8
+ from vectara_agentic.agent import Agent
9
9
  from vectara_agentic.tools import ToolsFactory
10
- from vectara_agentic.types import ModelProvider
11
10
 
12
11
  import nest_asyncio
13
12
  nest_asyncio.apply()
14
13
 
15
- def mult(x: float, y: float) -> float:
16
- "Multiply two numbers"
17
- return x * y
18
-
19
-
20
- fc_config_groq = AgentConfig(
21
- agent_type=AgentType.FUNCTION_CALLING,
22
- main_llm_provider=ModelProvider.GROQ,
23
- tool_llm_provider=ModelProvider.GROQ,
24
- )
25
-
26
-
27
- class TestGROQ(unittest.TestCase):
28
-
29
- def test_multiturn(self):
30
- tools = [ToolsFactory().create_tool(mult)]
31
- topic = "AI topic"
32
- instructions = "Always do as your father tells you, if your mother agrees!"
33
- agent = Agent(
34
- tools=tools,
35
- topic=topic,
36
- custom_instructions=instructions,
37
- )
38
-
39
- agent.chat("What is 5 times 10. Only give the answer, nothing else")
40
- agent.chat("what is 3 times 7. Only give the answer, nothing else")
41
- res = agent.chat("multiply the results of the last two questions. Output only the answer.")
42
- self.assertEqual(res.response, "1050")
14
+ from conftest import mult, fc_config_groq, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
15
+
16
+ ARIZE_LOCK = threading.Lock()
17
+
18
+ class TestGROQ(unittest.IsolatedAsyncioTestCase):
19
+
20
+ async def test_multiturn(self):
21
+ with ARIZE_LOCK:
22
+ tools = [ToolsFactory().create_tool(mult)]
23
+ agent = Agent(
24
+ tools=tools,
25
+ topic=STANDARD_TEST_TOPIC,
26
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
27
+ agent_config=fc_config_groq,
28
+ )
29
+
30
+ # First calculation: 5 * 10 = 50
31
+ stream1 = await agent.astream_chat(
32
+ "What is 5 times 10. Only give the answer, nothing else"
33
+ )
34
+ # Consume the stream
35
+ async for chunk in stream1.async_response_gen():
36
+ pass
37
+ _ = await stream1.aget_response()
38
+
39
+ # Second calculation: 3 * 7 = 21
40
+ stream2 = await agent.astream_chat(
41
+ "what is 3 times 7. Only give the answer, nothing else"
42
+ )
43
+ # Consume the stream
44
+ async for chunk in stream2.async_response_gen():
45
+ pass
46
+ _ = await stream2.aget_response()
47
+
48
+ # Final calculation: 50 * 21 = 1050
49
+ stream3 = await agent.astream_chat(
50
+ "multiply the results of the last two questions. Output only the answer."
51
+ )
52
+ # Consume the stream
53
+ async for chunk in stream3.async_response_gen():
54
+ pass
55
+ response3 = await stream3.aget_response()
56
+
57
+ self.assertEqual(response3.response, "1050")
43
58
 
44
59
 
45
60
  if __name__ == "__main__":
@@ -14,8 +14,7 @@ from vectara_agentic.tools import ToolsFactory
14
14
  from llama_index.core.utilities.sql_wrapper import SQLDatabase
15
15
  from sqlalchemy import create_engine
16
16
 
17
- def mult(x: float, y: float) -> float:
18
- return x * y
17
+ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
19
18
 
20
19
 
21
20
  ARIZE_LOCK = threading.Lock()
@@ -44,12 +43,10 @@ class TestAgentSerialization(unittest.TestCase):
44
43
  )
45
44
 
46
45
  tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
47
- topic = "AI topic"
48
- instructions = "Always do as your father tells you, if your mother agrees!"
49
46
  agent = Agent(
50
47
  tools=tools,
51
- topic=topic,
52
- custom_instructions=instructions,
48
+ topic=STANDARD_TEST_TOPIC,
49
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
53
50
  agent_config=config
54
51
  )
55
52
 
@@ -0,0 +1,252 @@
1
+ # Suppress external dependency warnings before any other imports
2
+ import warnings
3
+
4
+ warnings.simplefilter("ignore", DeprecationWarning)
5
+
6
+ import unittest
7
+ import threading
8
+ from datetime import date
9
+
10
+ from vectara_agentic.agent import Agent
11
+ from vectara_agentic.agent_config import AgentConfig
12
+ from vectara_agentic.types import ModelProvider
13
+ from vectara_agentic.tools import ToolsFactory
14
+ from llama_index.core.llms import ChatMessage, MessageRole
15
+ from conftest import mult, add
16
+
17
+
18
+ ARIZE_LOCK = threading.Lock()
19
+
20
+
21
+ class TestSessionMemoryManagement(unittest.TestCase):
22
+ """Test session_id parameter and memory management functionality"""
23
+
24
+ def setUp(self):
25
+ """Set up test fixtures"""
26
+ self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
27
+ self.topic = "Mathematics"
28
+ self.custom_instructions = "You are a helpful math assistant."
29
+ self.config = AgentConfig(main_llm_provider=ModelProvider.ANTHROPIC)
30
+
31
+ def test_agent_init_with_session_id(self):
32
+ """Test Agent initialization with custom session_id"""
33
+ custom_session_id = "test-session-123"
34
+
35
+ agent = Agent(
36
+ tools=self.tools,
37
+ topic=self.topic,
38
+ custom_instructions=self.custom_instructions,
39
+ agent_config=self.config,
40
+ session_id=custom_session_id,
41
+ )
42
+
43
+ # Verify the agent uses the provided session_id
44
+ self.assertEqual(agent.session_id, custom_session_id)
45
+
46
+ # Verify memory uses the same session_id
47
+ self.assertEqual(agent.memory.session_id, custom_session_id)
48
+
49
+ def test_agent_init_without_session_id(self):
50
+ """Test Agent initialization without session_id (auto-generation)"""
51
+ agent = Agent(
52
+ tools=self.tools,
53
+ topic=self.topic,
54
+ custom_instructions=self.custom_instructions,
55
+ agent_config=self.config,
56
+ )
57
+
58
+ # Verify auto-generated session_id follows expected pattern
59
+ expected_pattern = f"{self.topic}:{date.today().isoformat()}"
60
+ self.assertEqual(agent.session_id, expected_pattern)
61
+
62
+ # Verify memory uses the same session_id
63
+ self.assertEqual(agent.memory.session_id, expected_pattern)
64
+
65
+ def test_from_tools_with_session_id(self):
66
+ """Test Agent.from_tools() with custom session_id"""
67
+ custom_session_id = "from-tools-session-456"
68
+
69
+ agent = Agent.from_tools(
70
+ tools=self.tools,
71
+ topic=self.topic,
72
+ custom_instructions=self.custom_instructions,
73
+ agent_config=self.config,
74
+ session_id=custom_session_id,
75
+ )
76
+
77
+ # Verify the agent uses the provided session_id
78
+ self.assertEqual(agent.session_id, custom_session_id)
79
+ self.assertEqual(agent.memory.session_id, custom_session_id)
80
+
81
+ def test_from_tools_without_session_id(self):
82
+ """Test Agent.from_tools() without session_id (auto-generation)"""
83
+ agent = Agent.from_tools(
84
+ tools=self.tools,
85
+ topic=self.topic,
86
+ custom_instructions=self.custom_instructions,
87
+ agent_config=self.config,
88
+ )
89
+
90
+ # Verify auto-generated session_id
91
+ expected_pattern = f"{self.topic}:{date.today().isoformat()}"
92
+ self.assertEqual(agent.session_id, expected_pattern)
93
+ self.assertEqual(agent.memory.session_id, expected_pattern)
94
+
95
+ def test_session_id_consistency_across_agents(self):
96
+ """Test that agents with same session_id have consistent session_id attributes"""
97
+ shared_session_id = "shared-session-id-test"
98
+
99
+ # Create two agents with the same session_id
100
+ agent1 = Agent(
101
+ tools=self.tools,
102
+ topic=self.topic,
103
+ custom_instructions=self.custom_instructions,
104
+ agent_config=self.config,
105
+ session_id=shared_session_id,
106
+ )
107
+
108
+ agent2 = Agent(
109
+ tools=self.tools,
110
+ topic=self.topic,
111
+ custom_instructions=self.custom_instructions,
112
+ agent_config=self.config,
113
+ session_id=shared_session_id,
114
+ )
115
+
116
+ # Verify both agents have the same session_id
117
+ self.assertEqual(agent1.session_id, shared_session_id)
118
+ self.assertEqual(agent2.session_id, shared_session_id)
119
+ self.assertEqual(agent1.session_id, agent2.session_id)
120
+
121
+ # Verify their memory instances also have the correct session_id
122
+ self.assertEqual(agent1.memory.session_id, shared_session_id)
123
+ self.assertEqual(agent2.memory.session_id, shared_session_id)
124
+
125
+ # Note: Each agent gets its own Memory instance (this is expected behavior)
126
+ # In production, memory persistence happens through serialization/deserialization
127
+
128
+ def test_memory_isolation_different_sessions(self):
129
+ """Test that agents with different session_ids have isolated memory"""
130
+ session_id_1 = "isolated-session-1"
131
+ session_id_2 = "isolated-session-2"
132
+
133
+ # Create two agents with different session_ids
134
+ agent1 = Agent(
135
+ tools=self.tools,
136
+ topic=self.topic,
137
+ custom_instructions=self.custom_instructions,
138
+ agent_config=self.config,
139
+ session_id=session_id_1,
140
+ )
141
+
142
+ agent2 = Agent(
143
+ tools=self.tools,
144
+ topic=self.topic,
145
+ custom_instructions=self.custom_instructions,
146
+ agent_config=self.config,
147
+ session_id=session_id_2,
148
+ )
149
+
150
+ # Add messages to agent1's memory
151
+ agent1_messages = [
152
+ ChatMessage(role=MessageRole.USER, content="Agent 1 question"),
153
+ ChatMessage(role=MessageRole.ASSISTANT, content="Agent 1 response"),
154
+ ]
155
+ agent1.memory.put_messages(agent1_messages)
156
+
157
+ # Add different messages to agent2's memory
158
+ agent2_messages = [
159
+ ChatMessage(role=MessageRole.USER, content="Agent 2 question"),
160
+ ChatMessage(role=MessageRole.ASSISTANT, content="Agent 2 response"),
161
+ ]
162
+ agent2.memory.put_messages(agent2_messages)
163
+
164
+ # Verify memory isolation
165
+ retrieved_agent1_messages = agent1.memory.get()
166
+ retrieved_agent2_messages = agent2.memory.get()
167
+
168
+ self.assertEqual(len(retrieved_agent1_messages), 2)
169
+ self.assertEqual(len(retrieved_agent2_messages), 2)
170
+
171
+ # Verify agent1 only has its own messages
172
+ self.assertEqual(retrieved_agent1_messages[0].content, "Agent 1 question")
173
+ self.assertEqual(retrieved_agent1_messages[1].content, "Agent 1 response")
174
+
175
+ # Verify agent2 only has its own messages
176
+ self.assertEqual(retrieved_agent2_messages[0].content, "Agent 2 question")
177
+ self.assertEqual(retrieved_agent2_messages[1].content, "Agent 2 response")
178
+
179
+ def test_serialization_preserves_session_id(self):
180
+ """Test that agent serialization preserves custom session_id"""
181
+ custom_session_id = "serialization-test-session"
182
+
183
+ # Create agent with custom session_id
184
+ original_agent = Agent(
185
+ tools=self.tools,
186
+ topic=self.topic,
187
+ custom_instructions=self.custom_instructions,
188
+ agent_config=self.config,
189
+ session_id=custom_session_id,
190
+ )
191
+
192
+ # Add some memory
193
+ test_messages = [
194
+ ChatMessage(role=MessageRole.USER, content="Test question"),
195
+ ChatMessage(role=MessageRole.ASSISTANT, content="Test answer"),
196
+ ]
197
+ original_agent.memory.put_messages(test_messages)
198
+
199
+ # Serialize the agent
200
+ serialized_data = original_agent.dumps()
201
+
202
+ # Deserialize the agent
203
+ restored_agent = Agent.loads(serialized_data)
204
+
205
+ # Verify session_id is preserved
206
+ self.assertEqual(restored_agent.session_id, custom_session_id)
207
+ self.assertEqual(restored_agent.memory.session_id, custom_session_id)
208
+
209
+ # Verify memory is preserved
210
+ restored_messages = restored_agent.memory.get()
211
+ self.assertEqual(len(restored_messages), 2)
212
+ self.assertEqual(restored_messages[0].content, "Test question")
213
+ self.assertEqual(restored_messages[1].content, "Test answer")
214
+
215
+ def test_chat_history_initialization_with_session_id(self):
216
+ """Test Agent initialization with chat_history and custom session_id"""
217
+ custom_session_id = "chat-history-session"
218
+ chat_history = [
219
+ ("Hello", "Hi there!"),
220
+ ("How are you?", "I'm doing well, thank you!"),
221
+ ]
222
+
223
+ agent = Agent(
224
+ tools=self.tools,
225
+ topic=self.topic,
226
+ custom_instructions=self.custom_instructions,
227
+ agent_config=self.config,
228
+ session_id=custom_session_id,
229
+ chat_history=chat_history,
230
+ )
231
+
232
+ # Verify session_id is correct
233
+ self.assertEqual(agent.session_id, custom_session_id)
234
+ self.assertEqual(agent.memory.session_id, custom_session_id)
235
+
236
+ # Verify chat history was loaded into memory
237
+ messages = agent.memory.get()
238
+ self.assertEqual(len(messages), 4) # 2 user + 2 assistant messages
239
+
240
+ # Verify message content and roles
241
+ self.assertEqual(messages[0].role, MessageRole.USER)
242
+ self.assertEqual(messages[0].content, "Hello")
243
+ self.assertEqual(messages[1].role, MessageRole.ASSISTANT)
244
+ self.assertEqual(messages[1].content, "Hi there!")
245
+ self.assertEqual(messages[2].role, MessageRole.USER)
246
+ self.assertEqual(messages[2].content, "How are you?")
247
+ self.assertEqual(messages[3].role, MessageRole.ASSISTANT)
248
+ self.assertEqual(messages[3].content, "I'm doing well, thank you!")
249
+
250
+
251
+ if __name__ == "__main__":
252
+ unittest.main()
tests/test_streaming.py CHANGED
@@ -1,77 +1,98 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
6
7
  import asyncio
7
8
 
8
- from vectara_agentic.agent import Agent, AgentType
9
- from vectara_agentic.agent_config import AgentConfig
9
+ from vectara_agentic.agent import Agent
10
10
  from vectara_agentic.tools import ToolsFactory
11
- from vectara_agentic.types import ModelProvider
12
11
 
13
12
  import nest_asyncio
13
+
14
14
  nest_asyncio.apply()
15
15
 
16
- def mult(x: float, y: float) -> float:
17
- "Multiply two numbers"
18
- return x * y
16
+ from conftest import (
17
+ fc_config_openai,
18
+ fc_config_anthropic,
19
+ mult,
20
+ STANDARD_TEST_TOPIC,
21
+ STANDARD_TEST_INSTRUCTIONS,
22
+ )
19
23
 
20
24
 
21
- config_function_calling_openai = AgentConfig(
22
- agent_type=AgentType.FUNCTION_CALLING,
23
- main_llm_provider=ModelProvider.OPENAI,
24
- tool_llm_provider=ModelProvider.OPENAI,
25
- )
25
+ class TestAgentStreaming(unittest.IsolatedAsyncioTestCase):
26
26
 
27
- fc_config_anthropic = AgentConfig(
28
- agent_type=AgentType.FUNCTION_CALLING,
29
- main_llm_provider=ModelProvider.ANTHROPIC,
30
- tool_llm_provider=ModelProvider.ANTHROPIC,
31
- )
27
+ async def test_anthropic(self):
28
+ tools = [ToolsFactory().create_tool(mult)]
29
+ agent = Agent(
30
+ agent_config=fc_config_anthropic,
31
+ tools=tools,
32
+ topic=STANDARD_TEST_TOPIC,
33
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
34
+ )
32
35
 
33
- fc_config_gemini = AgentConfig(
34
- agent_type=AgentType.FUNCTION_CALLING,
35
- main_llm_provider=ModelProvider.GEMINI,
36
- tool_llm_provider=ModelProvider.GEMINI,
37
- )
36
+ # First calculation: 5 * 10 = 50
37
+ stream1 = await agent.astream_chat(
38
+ "What is 5 times 10. Only give the answer, nothing else"
39
+ )
40
+ # Consume the stream
41
+ async for chunk in stream1.async_response_gen():
42
+ pass
43
+ _ = await stream1.aget_response()
38
44
 
39
- fc_config_together = AgentConfig(
40
- agent_type=AgentType.FUNCTION_CALLING,
41
- main_llm_provider=ModelProvider.TOGETHER,
42
- tool_llm_provider=ModelProvider.TOGETHER,
43
- )
45
+ # Second calculation: 3 * 7 = 21
46
+ stream2 = await agent.astream_chat(
47
+ "what is 3 times 7. Only give the answer, nothing else"
48
+ )
49
+ # Consume the stream
50
+ async for chunk in stream2.async_response_gen():
51
+ pass
52
+ _ = await stream2.aget_response()
44
53
 
54
+ # Final calculation: 50 * 21 = 1050
55
+ stream3 = await agent.astream_chat(
56
+ "multiply the results of the last two multiplications. Only give the answer, nothing else."
57
+ )
58
+ # Consume the stream
59
+ async for chunk in stream3.async_response_gen():
60
+ pass
61
+ response3 = await stream3.aget_response()
45
62
 
46
- class TestAgentStreaming(unittest.TestCase):
63
+ self.assertIn("1050", response3.response)
47
64
 
48
- async def test_anthropic(self):
65
+ async def test_openai(self):
49
66
  tools = [ToolsFactory().create_tool(mult)]
50
- topic = "AI topic"
51
- instructions = "Always do as your father tells you, if your mother agrees!"
52
67
  agent = Agent(
53
- agent_config=fc_config_anthropic, # Use function calling which has better streaming
68
+ agent_config=fc_config_openai,
54
69
  tools=tools,
55
- topic=topic,
56
- custom_instructions=instructions,
70
+ topic=STANDARD_TEST_TOPIC,
71
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
57
72
  )
58
73
 
59
74
  # First calculation: 5 * 10 = 50
60
- stream1 = await agent.astream_chat("What is 5 times 10. Only give the answer, nothing else")
75
+ stream1 = await agent.astream_chat(
76
+ "What is 5 times 10. Only give the answer, nothing else"
77
+ )
61
78
  # Consume the stream
62
79
  async for chunk in stream1.async_response_gen():
63
80
  pass
64
81
  _ = await stream1.aget_response()
65
82
 
66
83
  # Second calculation: 3 * 7 = 21
67
- stream2 = await agent.astream_chat("what is 3 times 7. Only give the answer, nothing else")
84
+ stream2 = await agent.astream_chat(
85
+ "what is 3 times 7. Only give the answer, nothing else"
86
+ )
68
87
  # Consume the stream
69
88
  async for chunk in stream2.async_response_gen():
70
89
  pass
71
90
  _ = await stream2.aget_response()
72
91
 
73
92
  # Final calculation: 50 * 21 = 1050
74
- stream3 = await agent.astream_chat("multiply the results of the last two multiplications. Only give the answer, nothing else.")
93
+ stream3 = await agent.astream_chat(
94
+ "multiply the results of the last two multiplications. Only give the answer, nothing else."
95
+ )
75
96
  # Consume the stream
76
97
  async for chunk in stream3.async_response_gen():
77
98
  pass
@@ -81,7 +102,7 @@ class TestAgentStreaming(unittest.TestCase):
81
102
 
82
103
  def test_openai_sync(self):
83
104
  """Synchronous wrapper for the async test"""
84
- asyncio.run(self.test_anthropic())
105
+ asyncio.run(self.test_openai())
85
106
 
86
107
 
87
108
  if __name__ == "__main__":