vectara-agentic 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/conftest.py +5 -1
- tests/run_tests.py +1 -0
- tests/test_agent.py +26 -29
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +4 -0
- tests/test_bedrock.py +46 -31
- tests/test_gemini.py +7 -22
- tests/test_groq.py +46 -31
- tests/test_serialization.py +3 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +58 -37
- tests/test_together.py +62 -0
- tests/test_vhc.py +3 -2
- tests/test_workflow.py +9 -28
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +212 -33
- vectara_agentic/agent_core/factory.py +30 -148
- vectara_agentic/agent_core/prompts.py +20 -13
- vectara_agentic/agent_core/serialization.py +3 -0
- vectara_agentic/agent_core/streaming.py +22 -34
- vectara_agentic/agent_core/utils/__init__.py +0 -5
- vectara_agentic/agent_core/utils/hallucination.py +54 -99
- vectara_agentic/llm_utils.py +1 -1
- vectara_agentic/types.py +9 -3
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +49 -8
- vectara_agentic-0.4.1.dist-info/RECORD +53 -0
- vectara_agentic/agent_core/utils/prompt_formatting.py +0 -56
- vectara_agentic-0.4.0.dist-info/RECORD +0 -50
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.0.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_bedrock.py
CHANGED
|
@@ -3,43 +3,58 @@ import warnings
|
|
|
3
3
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
|
+
import threading
|
|
6
7
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
8
|
+
from vectara_agentic.agent import Agent
|
|
9
9
|
from vectara_agentic.tools import ToolsFactory
|
|
10
|
-
from vectara_agentic.types import ModelProvider
|
|
11
10
|
|
|
12
11
|
import nest_asyncio
|
|
13
12
|
nest_asyncio.apply()
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
14
|
+
from conftest import mult, fc_config_bedrock, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
ARIZE_LOCK = threading.Lock()
|
|
17
|
+
|
|
18
|
+
class TestBedrock(unittest.IsolatedAsyncioTestCase):
|
|
19
|
+
|
|
20
|
+
async def test_multiturn(self):
|
|
21
|
+
with ARIZE_LOCK:
|
|
22
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
23
|
+
agent = Agent(
|
|
24
|
+
tools=tools,
|
|
25
|
+
topic=STANDARD_TEST_TOPIC,
|
|
26
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
27
|
+
agent_config=fc_config_bedrock,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# First calculation: 5 * 10 = 50
|
|
31
|
+
stream1 = await agent.astream_chat(
|
|
32
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
33
|
+
)
|
|
34
|
+
# Consume the stream
|
|
35
|
+
async for chunk in stream1.async_response_gen():
|
|
36
|
+
pass
|
|
37
|
+
_ = await stream1.aget_response()
|
|
38
|
+
|
|
39
|
+
# Second calculation: 3 * 7 = 21
|
|
40
|
+
stream2 = await agent.astream_chat(
|
|
41
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
42
|
+
)
|
|
43
|
+
# Consume the stream
|
|
44
|
+
async for chunk in stream2.async_response_gen():
|
|
45
|
+
pass
|
|
46
|
+
_ = await stream2.aget_response()
|
|
47
|
+
|
|
48
|
+
# Final calculation: 50 * 21 = 1050
|
|
49
|
+
stream3 = await agent.astream_chat(
|
|
50
|
+
"multiply the results of the last two questions. Output only the answer."
|
|
51
|
+
)
|
|
52
|
+
# Consume the stream
|
|
53
|
+
async for chunk in stream3.async_response_gen():
|
|
54
|
+
pass
|
|
55
|
+
response3 = await stream3.aget_response()
|
|
56
|
+
|
|
57
|
+
self.assertEqual(response3.response, "1050")
|
|
43
58
|
|
|
44
59
|
|
|
45
60
|
if __name__ == "__main__":
|
tests/test_gemini.py
CHANGED
|
@@ -4,15 +4,15 @@ warnings.simplefilter("ignore", DeprecationWarning)
|
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
6
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
9
|
-
from vectara_agentic.types import ModelProvider
|
|
7
|
+
from vectara_agentic.agent import Agent
|
|
10
8
|
from vectara_agentic.tools import ToolsFactory
|
|
11
9
|
|
|
12
10
|
|
|
13
11
|
import nest_asyncio
|
|
14
12
|
nest_asyncio.apply()
|
|
15
13
|
|
|
14
|
+
from conftest import mult, fc_config_gemini, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
16
|
tickers = {
|
|
17
17
|
"C": "Citigroup",
|
|
18
18
|
"COF": "Capital One",
|
|
@@ -34,11 +34,6 @@ tickers = {
|
|
|
34
34
|
years = list(range(2015, 2025))
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
def mult(x: float, y: float) -> float:
|
|
38
|
-
"Multiply two numbers"
|
|
39
|
-
return x * y
|
|
40
|
-
|
|
41
|
-
|
|
42
37
|
def get_company_info() -> list[str]:
|
|
43
38
|
"""
|
|
44
39
|
Returns a dictionary of companies you can query about. Always check this before using any other tool.
|
|
@@ -56,23 +51,15 @@ def get_valid_years() -> list[str]:
|
|
|
56
51
|
return years
|
|
57
52
|
|
|
58
53
|
|
|
59
|
-
fc_config_gemini = AgentConfig(
|
|
60
|
-
agent_type=AgentType.FUNCTION_CALLING,
|
|
61
|
-
main_llm_provider=ModelProvider.GEMINI,
|
|
62
|
-
tool_llm_provider=ModelProvider.GEMINI,
|
|
63
|
-
)
|
|
64
|
-
|
|
65
54
|
class TestGEMINI(unittest.TestCase):
|
|
66
55
|
def test_gemini(self):
|
|
67
56
|
tools = [ToolsFactory().create_tool(mult)]
|
|
68
|
-
topic = "AI topic"
|
|
69
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
70
57
|
|
|
71
58
|
agent = Agent(
|
|
72
59
|
agent_config=fc_config_gemini,
|
|
73
60
|
tools=tools,
|
|
74
|
-
topic=topic,
|
|
75
|
-
custom_instructions=instructions,
|
|
61
|
+
topic=STANDARD_TEST_TOPIC,
|
|
62
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
76
63
|
)
|
|
77
64
|
_ = agent.chat("What is 5 times 10. Only give the answer, nothing else")
|
|
78
65
|
_ = agent.chat("what is 3 times 7. Only give the answer, nothing else")
|
|
@@ -81,14 +68,12 @@ class TestGEMINI(unittest.TestCase):
|
|
|
81
68
|
|
|
82
69
|
def test_gemini_single_prompt(self):
|
|
83
70
|
tools = [ToolsFactory().create_tool(mult)]
|
|
84
|
-
topic = "AI topic"
|
|
85
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
86
71
|
|
|
87
72
|
agent = Agent(
|
|
88
73
|
agent_config=fc_config_gemini,
|
|
89
74
|
tools=tools,
|
|
90
|
-
topic=topic,
|
|
91
|
-
custom_instructions=instructions,
|
|
75
|
+
topic=STANDARD_TEST_TOPIC,
|
|
76
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
92
77
|
)
|
|
93
78
|
res = agent.chat("First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations.")
|
|
94
79
|
self.assertIn("1050", res.response)
|
tests/test_groq.py
CHANGED
|
@@ -3,43 +3,58 @@ import warnings
|
|
|
3
3
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
4
|
|
|
5
5
|
import unittest
|
|
6
|
+
import threading
|
|
6
7
|
|
|
7
|
-
from vectara_agentic.agent import Agent
|
|
8
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
8
|
+
from vectara_agentic.agent import Agent
|
|
9
9
|
from vectara_agentic.tools import ToolsFactory
|
|
10
|
-
from vectara_agentic.types import ModelProvider
|
|
11
10
|
|
|
12
11
|
import nest_asyncio
|
|
13
12
|
nest_asyncio.apply()
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
14
|
+
from conftest import mult, fc_config_groq, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
15
|
+
|
|
16
|
+
ARIZE_LOCK = threading.Lock()
|
|
17
|
+
|
|
18
|
+
class TestGROQ(unittest.IsolatedAsyncioTestCase):
|
|
19
|
+
|
|
20
|
+
async def test_multiturn(self):
|
|
21
|
+
with ARIZE_LOCK:
|
|
22
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
23
|
+
agent = Agent(
|
|
24
|
+
tools=tools,
|
|
25
|
+
topic=STANDARD_TEST_TOPIC,
|
|
26
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
27
|
+
agent_config=fc_config_groq,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# First calculation: 5 * 10 = 50
|
|
31
|
+
stream1 = await agent.astream_chat(
|
|
32
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
33
|
+
)
|
|
34
|
+
# Consume the stream
|
|
35
|
+
async for chunk in stream1.async_response_gen():
|
|
36
|
+
pass
|
|
37
|
+
_ = await stream1.aget_response()
|
|
38
|
+
|
|
39
|
+
# Second calculation: 3 * 7 = 21
|
|
40
|
+
stream2 = await agent.astream_chat(
|
|
41
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
42
|
+
)
|
|
43
|
+
# Consume the stream
|
|
44
|
+
async for chunk in stream2.async_response_gen():
|
|
45
|
+
pass
|
|
46
|
+
_ = await stream2.aget_response()
|
|
47
|
+
|
|
48
|
+
# Final calculation: 50 * 21 = 1050
|
|
49
|
+
stream3 = await agent.astream_chat(
|
|
50
|
+
"multiply the results of the last two questions. Output only the answer."
|
|
51
|
+
)
|
|
52
|
+
# Consume the stream
|
|
53
|
+
async for chunk in stream3.async_response_gen():
|
|
54
|
+
pass
|
|
55
|
+
response3 = await stream3.aget_response()
|
|
56
|
+
|
|
57
|
+
self.assertEqual(response3.response, "1050")
|
|
43
58
|
|
|
44
59
|
|
|
45
60
|
if __name__ == "__main__":
|
tests/test_serialization.py
CHANGED
|
@@ -14,8 +14,7 @@ from vectara_agentic.tools import ToolsFactory
|
|
|
14
14
|
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
15
15
|
from sqlalchemy import create_engine
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
return x * y
|
|
17
|
+
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
ARIZE_LOCK = threading.Lock()
|
|
@@ -44,12 +43,10 @@ class TestAgentSerialization(unittest.TestCase):
|
|
|
44
43
|
)
|
|
45
44
|
|
|
46
45
|
tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
|
|
47
|
-
topic = "AI topic"
|
|
48
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
49
46
|
agent = Agent(
|
|
50
47
|
tools=tools,
|
|
51
|
-
topic=topic,
|
|
52
|
-
custom_instructions=instructions,
|
|
48
|
+
topic=STANDARD_TEST_TOPIC,
|
|
49
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
53
50
|
agent_config=config
|
|
54
51
|
)
|
|
55
52
|
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Suppress external dependency warnings before any other imports
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
warnings.simplefilter("ignore", DeprecationWarning)
|
|
5
|
+
|
|
6
|
+
import unittest
|
|
7
|
+
import threading
|
|
8
|
+
from datetime import date
|
|
9
|
+
|
|
10
|
+
from vectara_agentic.agent import Agent
|
|
11
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
12
|
+
from vectara_agentic.types import ModelProvider
|
|
13
|
+
from vectara_agentic.tools import ToolsFactory
|
|
14
|
+
from llama_index.core.llms import ChatMessage, MessageRole
|
|
15
|
+
from conftest import mult, add
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
ARIZE_LOCK = threading.Lock()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TestSessionMemoryManagement(unittest.TestCase):
|
|
22
|
+
"""Test session_id parameter and memory management functionality"""
|
|
23
|
+
|
|
24
|
+
def setUp(self):
|
|
25
|
+
"""Set up test fixtures"""
|
|
26
|
+
self.tools = [ToolsFactory().create_tool(mult), ToolsFactory().create_tool(add)]
|
|
27
|
+
self.topic = "Mathematics"
|
|
28
|
+
self.custom_instructions = "You are a helpful math assistant."
|
|
29
|
+
self.config = AgentConfig(main_llm_provider=ModelProvider.ANTHROPIC)
|
|
30
|
+
|
|
31
|
+
def test_agent_init_with_session_id(self):
|
|
32
|
+
"""Test Agent initialization with custom session_id"""
|
|
33
|
+
custom_session_id = "test-session-123"
|
|
34
|
+
|
|
35
|
+
agent = Agent(
|
|
36
|
+
tools=self.tools,
|
|
37
|
+
topic=self.topic,
|
|
38
|
+
custom_instructions=self.custom_instructions,
|
|
39
|
+
agent_config=self.config,
|
|
40
|
+
session_id=custom_session_id,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Verify the agent uses the provided session_id
|
|
44
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
45
|
+
|
|
46
|
+
# Verify memory uses the same session_id
|
|
47
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
48
|
+
|
|
49
|
+
def test_agent_init_without_session_id(self):
|
|
50
|
+
"""Test Agent initialization without session_id (auto-generation)"""
|
|
51
|
+
agent = Agent(
|
|
52
|
+
tools=self.tools,
|
|
53
|
+
topic=self.topic,
|
|
54
|
+
custom_instructions=self.custom_instructions,
|
|
55
|
+
agent_config=self.config,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Verify auto-generated session_id follows expected pattern
|
|
59
|
+
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
60
|
+
self.assertEqual(agent.session_id, expected_pattern)
|
|
61
|
+
|
|
62
|
+
# Verify memory uses the same session_id
|
|
63
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
64
|
+
|
|
65
|
+
def test_from_tools_with_session_id(self):
|
|
66
|
+
"""Test Agent.from_tools() with custom session_id"""
|
|
67
|
+
custom_session_id = "from-tools-session-456"
|
|
68
|
+
|
|
69
|
+
agent = Agent.from_tools(
|
|
70
|
+
tools=self.tools,
|
|
71
|
+
topic=self.topic,
|
|
72
|
+
custom_instructions=self.custom_instructions,
|
|
73
|
+
agent_config=self.config,
|
|
74
|
+
session_id=custom_session_id,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
# Verify the agent uses the provided session_id
|
|
78
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
79
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
80
|
+
|
|
81
|
+
def test_from_tools_without_session_id(self):
|
|
82
|
+
"""Test Agent.from_tools() without session_id (auto-generation)"""
|
|
83
|
+
agent = Agent.from_tools(
|
|
84
|
+
tools=self.tools,
|
|
85
|
+
topic=self.topic,
|
|
86
|
+
custom_instructions=self.custom_instructions,
|
|
87
|
+
agent_config=self.config,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Verify auto-generated session_id
|
|
91
|
+
expected_pattern = f"{self.topic}:{date.today().isoformat()}"
|
|
92
|
+
self.assertEqual(agent.session_id, expected_pattern)
|
|
93
|
+
self.assertEqual(agent.memory.session_id, expected_pattern)
|
|
94
|
+
|
|
95
|
+
def test_session_id_consistency_across_agents(self):
|
|
96
|
+
"""Test that agents with same session_id have consistent session_id attributes"""
|
|
97
|
+
shared_session_id = "shared-session-id-test"
|
|
98
|
+
|
|
99
|
+
# Create two agents with the same session_id
|
|
100
|
+
agent1 = Agent(
|
|
101
|
+
tools=self.tools,
|
|
102
|
+
topic=self.topic,
|
|
103
|
+
custom_instructions=self.custom_instructions,
|
|
104
|
+
agent_config=self.config,
|
|
105
|
+
session_id=shared_session_id,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
agent2 = Agent(
|
|
109
|
+
tools=self.tools,
|
|
110
|
+
topic=self.topic,
|
|
111
|
+
custom_instructions=self.custom_instructions,
|
|
112
|
+
agent_config=self.config,
|
|
113
|
+
session_id=shared_session_id,
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
# Verify both agents have the same session_id
|
|
117
|
+
self.assertEqual(agent1.session_id, shared_session_id)
|
|
118
|
+
self.assertEqual(agent2.session_id, shared_session_id)
|
|
119
|
+
self.assertEqual(agent1.session_id, agent2.session_id)
|
|
120
|
+
|
|
121
|
+
# Verify their memory instances also have the correct session_id
|
|
122
|
+
self.assertEqual(agent1.memory.session_id, shared_session_id)
|
|
123
|
+
self.assertEqual(agent2.memory.session_id, shared_session_id)
|
|
124
|
+
|
|
125
|
+
# Note: Each agent gets its own Memory instance (this is expected behavior)
|
|
126
|
+
# In production, memory persistence happens through serialization/deserialization
|
|
127
|
+
|
|
128
|
+
def test_memory_isolation_different_sessions(self):
|
|
129
|
+
"""Test that agents with different session_ids have isolated memory"""
|
|
130
|
+
session_id_1 = "isolated-session-1"
|
|
131
|
+
session_id_2 = "isolated-session-2"
|
|
132
|
+
|
|
133
|
+
# Create two agents with different session_ids
|
|
134
|
+
agent1 = Agent(
|
|
135
|
+
tools=self.tools,
|
|
136
|
+
topic=self.topic,
|
|
137
|
+
custom_instructions=self.custom_instructions,
|
|
138
|
+
agent_config=self.config,
|
|
139
|
+
session_id=session_id_1,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
agent2 = Agent(
|
|
143
|
+
tools=self.tools,
|
|
144
|
+
topic=self.topic,
|
|
145
|
+
custom_instructions=self.custom_instructions,
|
|
146
|
+
agent_config=self.config,
|
|
147
|
+
session_id=session_id_2,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
# Add messages to agent1's memory
|
|
151
|
+
agent1_messages = [
|
|
152
|
+
ChatMessage(role=MessageRole.USER, content="Agent 1 question"),
|
|
153
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Agent 1 response"),
|
|
154
|
+
]
|
|
155
|
+
agent1.memory.put_messages(agent1_messages)
|
|
156
|
+
|
|
157
|
+
# Add different messages to agent2's memory
|
|
158
|
+
agent2_messages = [
|
|
159
|
+
ChatMessage(role=MessageRole.USER, content="Agent 2 question"),
|
|
160
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Agent 2 response"),
|
|
161
|
+
]
|
|
162
|
+
agent2.memory.put_messages(agent2_messages)
|
|
163
|
+
|
|
164
|
+
# Verify memory isolation
|
|
165
|
+
retrieved_agent1_messages = agent1.memory.get()
|
|
166
|
+
retrieved_agent2_messages = agent2.memory.get()
|
|
167
|
+
|
|
168
|
+
self.assertEqual(len(retrieved_agent1_messages), 2)
|
|
169
|
+
self.assertEqual(len(retrieved_agent2_messages), 2)
|
|
170
|
+
|
|
171
|
+
# Verify agent1 only has its own messages
|
|
172
|
+
self.assertEqual(retrieved_agent1_messages[0].content, "Agent 1 question")
|
|
173
|
+
self.assertEqual(retrieved_agent1_messages[1].content, "Agent 1 response")
|
|
174
|
+
|
|
175
|
+
# Verify agent2 only has its own messages
|
|
176
|
+
self.assertEqual(retrieved_agent2_messages[0].content, "Agent 2 question")
|
|
177
|
+
self.assertEqual(retrieved_agent2_messages[1].content, "Agent 2 response")
|
|
178
|
+
|
|
179
|
+
def test_serialization_preserves_session_id(self):
|
|
180
|
+
"""Test that agent serialization preserves custom session_id"""
|
|
181
|
+
custom_session_id = "serialization-test-session"
|
|
182
|
+
|
|
183
|
+
# Create agent with custom session_id
|
|
184
|
+
original_agent = Agent(
|
|
185
|
+
tools=self.tools,
|
|
186
|
+
topic=self.topic,
|
|
187
|
+
custom_instructions=self.custom_instructions,
|
|
188
|
+
agent_config=self.config,
|
|
189
|
+
session_id=custom_session_id,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
# Add some memory
|
|
193
|
+
test_messages = [
|
|
194
|
+
ChatMessage(role=MessageRole.USER, content="Test question"),
|
|
195
|
+
ChatMessage(role=MessageRole.ASSISTANT, content="Test answer"),
|
|
196
|
+
]
|
|
197
|
+
original_agent.memory.put_messages(test_messages)
|
|
198
|
+
|
|
199
|
+
# Serialize the agent
|
|
200
|
+
serialized_data = original_agent.dumps()
|
|
201
|
+
|
|
202
|
+
# Deserialize the agent
|
|
203
|
+
restored_agent = Agent.loads(serialized_data)
|
|
204
|
+
|
|
205
|
+
# Verify session_id is preserved
|
|
206
|
+
self.assertEqual(restored_agent.session_id, custom_session_id)
|
|
207
|
+
self.assertEqual(restored_agent.memory.session_id, custom_session_id)
|
|
208
|
+
|
|
209
|
+
# Verify memory is preserved
|
|
210
|
+
restored_messages = restored_agent.memory.get()
|
|
211
|
+
self.assertEqual(len(restored_messages), 2)
|
|
212
|
+
self.assertEqual(restored_messages[0].content, "Test question")
|
|
213
|
+
self.assertEqual(restored_messages[1].content, "Test answer")
|
|
214
|
+
|
|
215
|
+
def test_chat_history_initialization_with_session_id(self):
|
|
216
|
+
"""Test Agent initialization with chat_history and custom session_id"""
|
|
217
|
+
custom_session_id = "chat-history-session"
|
|
218
|
+
chat_history = [
|
|
219
|
+
("Hello", "Hi there!"),
|
|
220
|
+
("How are you?", "I'm doing well, thank you!"),
|
|
221
|
+
]
|
|
222
|
+
|
|
223
|
+
agent = Agent(
|
|
224
|
+
tools=self.tools,
|
|
225
|
+
topic=self.topic,
|
|
226
|
+
custom_instructions=self.custom_instructions,
|
|
227
|
+
agent_config=self.config,
|
|
228
|
+
session_id=custom_session_id,
|
|
229
|
+
chat_history=chat_history,
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Verify session_id is correct
|
|
233
|
+
self.assertEqual(agent.session_id, custom_session_id)
|
|
234
|
+
self.assertEqual(agent.memory.session_id, custom_session_id)
|
|
235
|
+
|
|
236
|
+
# Verify chat history was loaded into memory
|
|
237
|
+
messages = agent.memory.get()
|
|
238
|
+
self.assertEqual(len(messages), 4) # 2 user + 2 assistant messages
|
|
239
|
+
|
|
240
|
+
# Verify message content and roles
|
|
241
|
+
self.assertEqual(messages[0].role, MessageRole.USER)
|
|
242
|
+
self.assertEqual(messages[0].content, "Hello")
|
|
243
|
+
self.assertEqual(messages[1].role, MessageRole.ASSISTANT)
|
|
244
|
+
self.assertEqual(messages[1].content, "Hi there!")
|
|
245
|
+
self.assertEqual(messages[2].role, MessageRole.USER)
|
|
246
|
+
self.assertEqual(messages[2].content, "How are you?")
|
|
247
|
+
self.assertEqual(messages[3].role, MessageRole.ASSISTANT)
|
|
248
|
+
self.assertEqual(messages[3].content, "I'm doing well, thank you!")
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
if __name__ == "__main__":
|
|
252
|
+
unittest.main()
|
tests/test_streaming.py
CHANGED
|
@@ -1,77 +1,98 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
6
7
|
import asyncio
|
|
7
8
|
|
|
8
|
-
from vectara_agentic.agent import Agent
|
|
9
|
-
from vectara_agentic.agent_config import AgentConfig
|
|
9
|
+
from vectara_agentic.agent import Agent
|
|
10
10
|
from vectara_agentic.tools import ToolsFactory
|
|
11
|
-
from vectara_agentic.types import ModelProvider
|
|
12
11
|
|
|
13
12
|
import nest_asyncio
|
|
13
|
+
|
|
14
14
|
nest_asyncio.apply()
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
from conftest import (
|
|
17
|
+
fc_config_openai,
|
|
18
|
+
fc_config_anthropic,
|
|
19
|
+
mult,
|
|
20
|
+
STANDARD_TEST_TOPIC,
|
|
21
|
+
STANDARD_TEST_INSTRUCTIONS,
|
|
22
|
+
)
|
|
19
23
|
|
|
20
24
|
|
|
21
|
-
|
|
22
|
-
agent_type=AgentType.FUNCTION_CALLING,
|
|
23
|
-
main_llm_provider=ModelProvider.OPENAI,
|
|
24
|
-
tool_llm_provider=ModelProvider.OPENAI,
|
|
25
|
-
)
|
|
25
|
+
class TestAgentStreaming(unittest.IsolatedAsyncioTestCase):
|
|
26
26
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
async def test_anthropic(self):
|
|
28
|
+
tools = [ToolsFactory().create_tool(mult)]
|
|
29
|
+
agent = Agent(
|
|
30
|
+
agent_config=fc_config_anthropic,
|
|
31
|
+
tools=tools,
|
|
32
|
+
topic=STANDARD_TEST_TOPIC,
|
|
33
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
34
|
+
)
|
|
32
35
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
36
|
+
# First calculation: 5 * 10 = 50
|
|
37
|
+
stream1 = await agent.astream_chat(
|
|
38
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
39
|
+
)
|
|
40
|
+
# Consume the stream
|
|
41
|
+
async for chunk in stream1.async_response_gen():
|
|
42
|
+
pass
|
|
43
|
+
_ = await stream1.aget_response()
|
|
38
44
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
45
|
+
# Second calculation: 3 * 7 = 21
|
|
46
|
+
stream2 = await agent.astream_chat(
|
|
47
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
48
|
+
)
|
|
49
|
+
# Consume the stream
|
|
50
|
+
async for chunk in stream2.async_response_gen():
|
|
51
|
+
pass
|
|
52
|
+
_ = await stream2.aget_response()
|
|
44
53
|
|
|
54
|
+
# Final calculation: 50 * 21 = 1050
|
|
55
|
+
stream3 = await agent.astream_chat(
|
|
56
|
+
"multiply the results of the last two multiplications. Only give the answer, nothing else."
|
|
57
|
+
)
|
|
58
|
+
# Consume the stream
|
|
59
|
+
async for chunk in stream3.async_response_gen():
|
|
60
|
+
pass
|
|
61
|
+
response3 = await stream3.aget_response()
|
|
45
62
|
|
|
46
|
-
|
|
63
|
+
self.assertIn("1050", response3.response)
|
|
47
64
|
|
|
48
|
-
async def
|
|
65
|
+
async def test_openai(self):
|
|
49
66
|
tools = [ToolsFactory().create_tool(mult)]
|
|
50
|
-
topic = "AI topic"
|
|
51
|
-
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
52
67
|
agent = Agent(
|
|
53
|
-
agent_config=
|
|
68
|
+
agent_config=fc_config_openai,
|
|
54
69
|
tools=tools,
|
|
55
|
-
topic=topic,
|
|
56
|
-
custom_instructions=instructions,
|
|
70
|
+
topic=STANDARD_TEST_TOPIC,
|
|
71
|
+
custom_instructions=STANDARD_TEST_INSTRUCTIONS,
|
|
57
72
|
)
|
|
58
73
|
|
|
59
74
|
# First calculation: 5 * 10 = 50
|
|
60
|
-
stream1 = await agent.astream_chat(
|
|
75
|
+
stream1 = await agent.astream_chat(
|
|
76
|
+
"What is 5 times 10. Only give the answer, nothing else"
|
|
77
|
+
)
|
|
61
78
|
# Consume the stream
|
|
62
79
|
async for chunk in stream1.async_response_gen():
|
|
63
80
|
pass
|
|
64
81
|
_ = await stream1.aget_response()
|
|
65
82
|
|
|
66
83
|
# Second calculation: 3 * 7 = 21
|
|
67
|
-
stream2 = await agent.astream_chat(
|
|
84
|
+
stream2 = await agent.astream_chat(
|
|
85
|
+
"what is 3 times 7. Only give the answer, nothing else"
|
|
86
|
+
)
|
|
68
87
|
# Consume the stream
|
|
69
88
|
async for chunk in stream2.async_response_gen():
|
|
70
89
|
pass
|
|
71
90
|
_ = await stream2.aget_response()
|
|
72
91
|
|
|
73
92
|
# Final calculation: 50 * 21 = 1050
|
|
74
|
-
stream3 = await agent.astream_chat(
|
|
93
|
+
stream3 = await agent.astream_chat(
|
|
94
|
+
"multiply the results of the last two multiplications. Only give the answer, nothing else."
|
|
95
|
+
)
|
|
75
96
|
# Consume the stream
|
|
76
97
|
async for chunk in stream3.async_response_gen():
|
|
77
98
|
pass
|
|
@@ -81,7 +102,7 @@ class TestAgentStreaming(unittest.TestCase):
|
|
|
81
102
|
|
|
82
103
|
def test_openai_sync(self):
|
|
83
104
|
"""Synchronous wrapper for the async test"""
|
|
84
|
-
asyncio.run(self.
|
|
105
|
+
asyncio.run(self.test_openai())
|
|
85
106
|
|
|
86
107
|
|
|
87
108
|
if __name__ == "__main__":
|