vectara-agentic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic.
- tests/__init__.py +1 -0
- tests/benchmark_models.py +547 -372
- tests/conftest.py +14 -12
- tests/endpoint.py +9 -5
- tests/run_tests.py +1 -0
- tests/test_agent.py +22 -9
- tests/test_agent_fallback_memory.py +4 -4
- tests/test_agent_memory_consistency.py +4 -4
- tests/test_agent_type.py +2 -0
- tests/test_api_endpoint.py +13 -13
- tests/test_bedrock.py +9 -1
- tests/test_fallback.py +18 -7
- tests/test_gemini.py +14 -40
- tests/test_groq.py +43 -1
- tests/test_openai.py +160 -0
- tests/test_private_llm.py +19 -6
- tests/test_react_error_handling.py +293 -0
- tests/test_react_memory.py +257 -0
- tests/test_react_streaming.py +135 -0
- tests/test_react_workflow_events.py +395 -0
- tests/test_return_direct.py +1 -0
- tests/test_serialization.py +58 -20
- tests/test_session_memory.py +11 -11
- tests/test_streaming.py +0 -44
- tests/test_together.py +75 -1
- tests/test_tools.py +3 -1
- tests/test_vectara_llms.py +2 -2
- tests/test_vhc.py +7 -2
- tests/test_workflow.py +17 -11
- vectara_agentic/_callback.py +79 -21
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +65 -27
- vectara_agentic/agent_core/serialization.py +5 -9
- vectara_agentic/agent_core/streaming.py +245 -64
- vectara_agentic/agent_core/utils/schemas.py +2 -2
- vectara_agentic/llm_utils.py +64 -15
- vectara_agentic/tools.py +88 -31
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/METADATA +133 -36
- vectara_agentic-0.4.4.dist-info/RECORD +59 -0
- vectara_agentic-0.4.2.dist-info/RECORD +0 -54
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/top_level.txt +0 -0
tests/conftest.py
CHANGED
@@ -101,9 +101,9 @@ react_config_anthropic = AgentConfig(
 react_config_gemini = AgentConfig(
     agent_type=AgentType.REACT,
     main_llm_provider=ModelProvider.GEMINI,
-    main_llm_model_name="models/gemini-2.5-flash",
+    main_llm_model_name="models/gemini-2.5-flash-lite",
     tool_llm_provider=ModelProvider.GEMINI,
-    tool_llm_model_name="models/gemini-2.5-flash",
+    tool_llm_model_name="models/gemini-2.5-flash-lite",
 )

 react_config_together = AgentConfig(
@@ -112,12 +112,19 @@ react_config_together = AgentConfig(
     tool_llm_provider=ModelProvider.TOGETHER,
 )

+react_config_openai = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.OPENAI,
+    tool_llm_provider=ModelProvider.OPENAI,
+)
+
 react_config_groq = AgentConfig(
     agent_type=AgentType.REACT,
     main_llm_provider=ModelProvider.GROQ,
     tool_llm_provider=ModelProvider.GROQ,
 )

+
 # Private LLM configurations
 private_llm_react_config = AgentConfig(
     agent_type=AgentType.REACT,
@@ -161,14 +168,6 @@ def is_rate_limited(response_text: str) -> bool:
         "rate limit",
         "quota exceeded",
         "usage limit",
-        # GROQ-specific
-        "tokens per day",
-        "TPD",
-        "service tier",
-        "on_demand",
-        "deepseek-r1-distill-llama-70b",
-        "Upgrade to Dev Tier",
-        "console.groq.com/settings/billing",
         # OpenAI-specific
         "requests per minute",
         "RPM",
@@ -188,6 +187,9 @@ def is_rate_limited(response_text: str) -> bool:
         # Additional rate limit patterns
         "Limit.*Used.*Requested",
         "Need more tokens",
+        # Provider failure patterns
+        "failure can't be resolved after",
+        "Got empty message",
     ]

     response_lower = response_text.lower()
@@ -279,10 +281,10 @@ class AgentTestMixin:
             provider: Provider name for error messages

         Usage:
-            with self.with_provider_fallback("
+            with self.with_provider_fallback("OpenAI"):
                 response = agent.chat("test")

-            with self.with_provider_fallback("
+            with self.with_provider_fallback("OpenAI"):
                 async for chunk in agent.astream_chat("test"):
                     pass

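For context, the new react_config_openai fixture and the AgentTestMixin.with_provider_fallback helper documented above are intended to be combined in provider tests. Below is a minimal, illustrative sketch (not part of the diff) of such a test, assuming conftest.py exports these names as shown in this release and that with_provider_fallback behaves as its docstring describes:

    # Illustrative only: mirrors the usage documented in AgentTestMixin above.
    import unittest

    from vectara_agentic.agent import Agent
    from vectara_agentic.tools import ToolsFactory

    from conftest import (
        AgentTestMixin,
        react_config_openai,  # added to conftest.py in this release
        mult,
        STANDARD_TEST_TOPIC,
        STANDARD_TEST_INSTRUCTIONS,
    )


    class TestOpenAIReAct(unittest.TestCase, AgentTestMixin):
        def test_multiplication(self):
            agent = Agent(
                agent_config=react_config_openai,
                tools=[ToolsFactory().create_tool(mult)],
                topic=STANDARD_TEST_TOPIC,
                custom_instructions=STANDARD_TEST_INSTRUCTIONS,
            )
            # Per its docstring, the context manager shields the test from
            # provider rate limits / transient failures (behavior assumed).
            with self.with_provider_fallback("OpenAI"):
                res = agent.chat("What is 5 times 10? Only give the answer.")
            self.assertIn("50", res.response)


    if __name__ == "__main__":
        unittest.main()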
tests/endpoint.py
CHANGED
@@ -10,12 +10,13 @@ app = Flask(__name__)
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 app.logger.setLevel(logging.INFO)
-werkzeug_log = logging.getLogger(
+werkzeug_log = logging.getLogger("werkzeug")
 werkzeug_log.setLevel(logging.ERROR)

 # Load expected API key from environment (fallback for testing)
 EXPECTED_API_KEY = "TEST_API_KEY"

+
 # Authentication decorator
 def require_api_key(f):
     @wraps(f)
@@ -27,12 +28,15 @@ def require_api_key(f):
         if api_key != EXPECTED_API_KEY:
             return jsonify({"error": "Unauthorized"}), 401
         return f(*args, **kwargs)
+
     return decorated_function

+
 @app.before_request
 def log_request_info():
     app.logger.info("%s %s", request.method, request.path)

+
 @app.route("/v1/chat/completions", methods=["POST"])
 @require_api_key
 def chat_completions():
@@ -46,7 +50,7 @@ def chat_completions():
         return jsonify({"error": "Invalid JSON payload"}), 400

     client = OpenAI()
-    is_stream = data.get(
+    is_stream = data.get("stream", False)

     if is_stream:
         # Stream each chunk to the client as Server-Sent Events
@@ -62,9 +66,9 @@ def chat_completions():
                 yield f"data: {error_msg}\n\n"

         headers = {
-
-
-
+            "Content-Type": "text/event-stream",
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
         }
         return Response(generate(), headers=headers)

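The hunk above restores the Server-Sent Events response headers in the mock OpenAI-compatible endpoint used by the tests. For orientation, here is a condensed, illustrative sketch of such a streaming branch, not the package's actual endpoint: the default model name is an assumption, and the real file's error handling and non-streaming details are omitted.

    # Minimal SSE streaming sketch in the spirit of tests/endpoint.py (illustrative only).
    from flask import Flask, Response, jsonify, request
    from openai import OpenAI

    app = Flask(__name__)


    @app.route("/v1/chat/completions", methods=["POST"])
    def chat_completions():
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({"error": "Invalid JSON payload"}), 400

        client = OpenAI()
        if data.get("stream", False):
            def generate():
                # Forward each upstream chunk to the client as an SSE "data:" event.
                stream = client.chat.completions.create(
                    model=data.get("model", "gpt-4o-mini"),  # assumed default
                    messages=data.get("messages", []),
                    stream=True,
                )
                for chunk in stream:
                    yield f"data: {chunk.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"

            headers = {
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            }
            return Response(generate(), headers=headers)

        # Non-streaming branch: return the completion as plain JSON.
        completion = client.chat.completions.create(
            model=data.get("model", "gpt-4o-mini"),
            messages=data.get("messages", []),
        )
        return jsonify(completion.model_dump())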
tests/run_tests.py
CHANGED
tests/test_agent.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import unittest
@@ -18,6 +19,7 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS

 ARIZE_LOCK = threading.Lock()

+
 class TestAgentPackage(unittest.TestCase):
     def setUp(self):
         self.agents_to_cleanup = []
@@ -27,7 +29,7 @@ class TestAgentPackage(unittest.TestCase):
         import asyncio

         for agent in self.agents_to_cleanup:
-            if hasattr(agent,
+            if hasattr(agent, "cleanup"):
                 agent.cleanup()

         # Force garbage collection to clean up any remaining references
@@ -53,7 +55,10 @@ class TestAgentPackage(unittest.TestCase):
             + " with Always do as your mother tells you!"
         )
         self.assertEqual(
-            format_prompt(
+            format_prompt(
+                prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions
+            ),
+            expected_output,
         )

     def test_agent_init(self):
@@ -81,22 +86,26 @@ class TestAgentPackage(unittest.TestCase):
             main_llm_model_name="claude-sonnet-4-20250514",
             tool_llm_provider=ModelProvider.TOGETHER,
             tool_llm_model_name="moonshotai/Kimi-K2-Instruct",
-            observer=ObserverType.ARIZE_PHOENIX
+            observer=ObserverType.ARIZE_PHOENIX,
         )

         agent = Agent(
             tools=tools,
             topic=STANDARD_TEST_TOPIC,
             custom_instructions=STANDARD_TEST_INSTRUCTIONS,
-            agent_config=config
+            agent_config=config,
         )
         self.agents_to_cleanup.append(agent)
         self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
         self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)
         self.assertEqual(agent.agent_type, AgentType.REACT)
         self.assertEqual(agent.agent_config.observer, ObserverType.ARIZE_PHOENIX)
-        self.assertEqual(
-
+        self.assertEqual(
+            agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC
+        )
+        self.assertEqual(
+            agent.agent_config.tool_llm_provider, ModelProvider.TOGETHER
+        )

         # To run this test, you must have ANTHROPIC_API_KEY and TOGETHER_API_KEY in your environment
         self.assertEqual(
@@ -120,7 +129,9 @@ class TestAgentPackage(unittest.TestCase):

         agent.chat("What is 5 times 10. Only give the answer, nothing else")
         agent.chat("what is 3 times 7. Only give the answer, nothing else")
-        res = agent.chat(
+        res = agent.chat(
+            "multiply the results of the last two questions. Output only the answer."
+        )
         self.assertEqual(res.response, "1050")

     def test_from_corpus(self):
@@ -144,7 +155,7 @@ class TestAgentPackage(unittest.TestCase):
             tools=tools,
             topic=topic,
             custom_instructions=instructions,
-            chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")]
+            chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")],
         )
         self.agents_to_cleanup.append(agent)

@@ -152,7 +163,9 @@ class TestAgentPackage(unittest.TestCase):
         clone = Agent.loads(data)
         assert clone.memory.get() == agent.memory.get()

-        res = agent.chat(
+        res = agent.chat(
+            "multiply the results of the last two questions. Output only the answer."
+        )
         self.assertEqual(res.response, "1050")

     def test_custom_general_instruction(self):
tests/test_agent_fallback_memory.py
CHANGED
@@ -70,7 +70,7 @@ class TestAgentFallbackMemoryConsistency(unittest.TestCase):

         # Verify session_id consistency
         # Memory is managed by the main Agent class
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

     def test_memory_sync_during_agent_switching(self):
         """Test that memory remains consistent when switching between main and fallback agents"""
@@ -219,13 +219,13 @@ class TestAgentFallbackMemoryConsistency(unittest.TestCase):

         # Verify main agent session_id consistency
         self.assertEqual(agent.session_id, self.session_id)
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

         # Verify session_id consistency across all agents
         # Memory is managed by the main Agent class
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)
         self.assertEqual(
-            agent.memory.
+            agent.memory.session_id, self.session_id
         ) # Both access same memory

     def test_agent_recreation_on_switch(self):
tests/test_agent_memory_consistency.py
CHANGED
@@ -172,21 +172,21 @@ class TestAgentMemoryConsistency(unittest.TestCase):

         # Verify initial session_id
         self.assertEqual(agent.session_id, self.session_id)
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

         # Switch configurations multiple times
         agent._switch_agent_config()
         self.assertEqual(agent.session_id, self.session_id)
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

         agent._switch_agent_config()
         self.assertEqual(agent.session_id, self.session_id)
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

         # Clear memory
         agent.clear_memory()
         self.assertEqual(agent.session_id, self.session_id)
-        self.assertEqual(agent.memory.
+        self.assertEqual(agent.memory.session_id, self.session_id)

     def test_serialization_preserves_consistency(self):
         """Test that serialization/deserialization preserves memory consistency behavior"""
tests/test_agent_type.py
CHANGED
@@ -7,6 +7,7 @@ import unittest

 import sys
 import os
+
 sys.path.insert(0, os.path.dirname(__file__))

 from conftest import (
@@ -25,6 +26,7 @@ from conftest import (
 from vectara_agentic.agent import Agent
 from vectara_agentic.tools import ToolsFactory

+
 class TestAgentType(unittest.TestCase, AgentTestMixin):

     def setUp(self):
tests/test_api_endpoint.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import unittest
@@ -21,6 +22,7 @@ class DummyAgent(Agent):
     def chat(self, message: str) -> str:
         return f"Echo: {message}"

+
 class APITestCase(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -42,7 +44,9 @@ class APITestCase(unittest.TestCase):
         self.assertIn("No message provided", r.json()["detail"])

     def test_chat_unauthorized(self):
-        r = self.client.get(
+        r = self.client.get(
+            "/chat", params={"message": "hello"}, headers={"X-API-Key": "bad"}
+        )
         self.assertEqual(r.status_code, 403)

     def test_completions_success(self):
@@ -69,14 +73,13 @@ class APITestCase(unittest.TestCase):

     def test_completions_unauthorized(self):
         payload = {"model": "m1", "prompt": "hi"}
-        r = self.client.post(
+        r = self.client.post(
+            "/v1/completions", json=payload, headers={"X-API-Key": "bad"}
+        )
         self.assertEqual(r.status_code, 403)

     def test_chat_completion_success(self):
-        payload = {
-            "model": "m1",
-            "messages": [{"role": "user", "content": "hello"}]
-        }
+        payload = {"model": "m1", "messages": [{"role": "user", "content": "hello"}]}
         r = self.client.post("/v1/chat", json=payload, headers=self.headers)
         self.assertEqual(r.status_code, 200)
         data = r.json()
@@ -99,8 +102,8 @@ class APITestCase(unittest.TestCase):
                 {"role": "system", "content": "ignore me"},
                 {"role": "user", "content": "foo"},
                 {"role": "assistant", "content": "pong"},
-                {"role": "user", "content": "bar"}
-            ]
+                {"role": "user", "content": "bar"},
+            ],
         }
         r = self.client.post("/v1/chat", json=payload, headers=self.headers)
         self.assertEqual(r.status_code, 200)
@@ -108,7 +111,7 @@ class APITestCase(unittest.TestCase):

         # Should concatenate only user messages: "foo bar"
         self.assertEqual(data["choices"][0]["message"]["content"], "Echo: foo bar")
-        self.assertEqual(data["usage"]["prompt_tokens"], 2)
+        self.assertEqual(data["usage"]["prompt_tokens"], 2) # "foo","bar"
         self.assertEqual(data["usage"]["completion_tokens"], 3) # "Echo:","foo","bar"

     def test_chat_completion_no_messages(self):
@@ -118,10 +121,7 @@ class APITestCase(unittest.TestCase):
         self.assertIn("`messages` is required", r.json()["detail"])

     def test_chat_completion_unauthorized(self):
-        payload = {
-            "model": "m1",
-            "messages": [{"role": "user", "content": "oops"}]
-        }
+        payload = {"model": "m1", "messages": [{"role": "user", "content": "oops"}]}
         r = self.client.post("/v1/chat", json=payload, headers={"X-API-Key": "bad"})
         self.assertEqual(r.status_code, 403)

tests/test_bedrock.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import unittest
@@ -9,12 +10,19 @@ from vectara_agentic.agent import Agent
 from vectara_agentic.tools import ToolsFactory

 import nest_asyncio
+
 nest_asyncio.apply()

-from conftest import
+from conftest import (
+    mult,
+    fc_config_bedrock,
+    STANDARD_TEST_TOPIC,
+    STANDARD_TEST_INSTRUCTIONS,
+)

 ARIZE_LOCK = threading.Lock()

+
 class TestBedrock(unittest.IsolatedAsyncioTestCase):

     async def test_multiturn(self):
tests/test_fallback.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import os
@@ -16,19 +17,25 @@ from vectara_agentic.tools import ToolsFactory

 FLASK_PORT = 5002

+
 class TestFallback(unittest.TestCase):

     @classmethod
     def setUp(cls):
         # Start the Flask server as a subprocess
         cls.flask_process = subprocess.Popen(
-            [
-            env={
-
+            ["flask", "run", f"--port={FLASK_PORT}"],
+            env={
+                **os.environ,
+                "FLASK_APP": "tests.endpoint:app",
+                "FLASK_ENV": "development",
+            },
+            stdout=None,
+            stderr=None,
         )
         # Wait for the server to start
         timeout = 10
-        url = f
+        url = f"http://127.0.0.1:{FLASK_PORT}/"
         for _ in range(timeout):
             try:
                 requests.get(url)
@@ -62,9 +69,13 @@ class TestFallback(unittest.TestCase):
         # Set fallback agent config to OpenAI agent
         fallback_config = AgentConfig()

-        agent = Agent(
-
-
+        agent = Agent(
+            agent_config=config,
+            tools=tools,
+            topic=topic,
+            custom_instructions=custom_instructions,
+            fallback_agent_config=fallback_config,
+        )

         # To run this test, you must have OPENAI_API_KEY in your environment
         res = agent.chat(
tests/test_gemini.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import unittest
@@ -9,46 +10,15 @@ from vectara_agentic.tools import ToolsFactory


 import nest_asyncio
-nest_asyncio.apply()
-
-from conftest import mult, fc_config_gemini, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
-
-tickers = {
-    "C": "Citigroup",
-    "COF": "Capital One",
-    "JPM": "JPMorgan Chase",
-    "AAPL": "Apple Computer",
-    "GOOG": "Google",
-    "AMZN": "Amazon",
-    "SNOW": "Snowflake",
-    "TEAM": "Atlassian",
-    "TSLA": "Tesla",
-    "NVDA": "Nvidia",
-    "MSFT": "Microsoft",
-    "AMD": "Advanced Micro Devices",
-    "INTC": "Intel",
-    "NFLX": "Netflix",
-    "STT": "State Street",
-    "BK": "Bank of New York Mellon",
-}
-years = list(range(2015, 2025))
-
-
-def get_company_info() -> list[str]:
-    """
-    Returns a dictionary of companies you can query about. Always check this before using any other tool.
-    The output is a dictionary of valid ticker symbols mapped to company names.
-    You can use this to identify the companies you can query about, and their ticker information.
-    """
-    return tickers

+nest_asyncio.apply()

-
-
-
-
-
-
+from conftest import (
+    mult,
+    fc_config_gemini,
+    STANDARD_TEST_TOPIC,
+    STANDARD_TEST_INSTRUCTIONS,
+)


 class TestGEMINI(unittest.TestCase):
@@ -63,7 +33,9 @@ class TestGEMINI(unittest.TestCase):
         )
         _ = agent.chat("What is 5 times 10. Only give the answer, nothing else")
         _ = agent.chat("what is 3 times 7. Only give the answer, nothing else")
-        res = agent.chat(
+        res = agent.chat(
+            "what is the result of multiplying the results of the last two multiplications. Only give the answer, nothing else."
+        )
         self.assertIn("1050", res.response)

     def test_gemini_single_prompt(self):
@@ -75,7 +47,9 @@ class TestGEMINI(unittest.TestCase):
             topic=STANDARD_TEST_TOPIC,
             custom_instructions=STANDARD_TEST_INSTRUCTIONS,
         )
-        res = agent.chat(
+        res = agent.chat(
+            "First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations."
+        )
         self.assertIn("1050", res.response)

tests/test_groq.py
CHANGED
@@ -1,5 +1,6 @@
 # Suppress external dependency warnings before any other imports
 import warnings
+
 warnings.simplefilter("ignore", DeprecationWarning)

 import unittest
@@ -7,14 +8,23 @@ import threading

 from vectara_agentic.agent import Agent
 from vectara_agentic.tools import ToolsFactory
+from vectara_agentic.agent_config import AgentConfig
+from vectara_agentic.types import AgentType, ModelProvider

 import nest_asyncio
+
 nest_asyncio.apply()

-from conftest import
+from conftest import (
+    mult,
+    fc_config_groq,
+    STANDARD_TEST_TOPIC,
+    STANDARD_TEST_INSTRUCTIONS,
+)

 ARIZE_LOCK = threading.Lock()

+
 class TestGROQ(unittest.IsolatedAsyncioTestCase):

     async def test_multiturn(self):
@@ -56,6 +66,38 @@ class TestGROQ(unittest.IsolatedAsyncioTestCase):

         self.assertEqual(response3.response, "1050")

+    async def test_gpt_oss_120b(self):
+        """Test GPT-OSS-120B model with GROQ provider."""
+        with ARIZE_LOCK:
+            # Create config specifically for GPT-OSS-120B via GROQ
+            gpt_oss_config = AgentConfig(
+                agent_type=AgentType.FUNCTION_CALLING,
+                main_llm_provider=ModelProvider.GROQ,
+                main_llm_model_name="openai/gpt-oss-120b",
+                tool_llm_provider=ModelProvider.GROQ,
+                tool_llm_model_name="openai/gpt-oss-120b",
+            )
+
+            tools = [ToolsFactory().create_tool(mult)]
+            agent = Agent(
+                agent_config=gpt_oss_config,
+                tools=tools,
+                topic=STANDARD_TEST_TOPIC,
+                custom_instructions=STANDARD_TEST_INSTRUCTIONS,
+            )
+
+            # Test simple multiplication: 8 * 6 = 48
+            stream = await agent.astream_chat(
+                "What is 8 times 6? Only give the answer, nothing else"
+            )
+            # Consume the stream
+            async for chunk in stream.async_response_gen():
+                pass
+            response = await stream.aget_response()
+
+            # Verify the response contains the correct answer
+            self.assertIn("48", response.response)
+

 if __name__ == "__main__":
     unittest.main()