vectara-agentic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of vectara-agentic might be problematic.

Files changed (43)
  1. tests/__init__.py +1 -0
  2. tests/benchmark_models.py +547 -372
  3. tests/conftest.py +14 -12
  4. tests/endpoint.py +9 -5
  5. tests/run_tests.py +1 -0
  6. tests/test_agent.py +22 -9
  7. tests/test_agent_fallback_memory.py +4 -4
  8. tests/test_agent_memory_consistency.py +4 -4
  9. tests/test_agent_type.py +2 -0
  10. tests/test_api_endpoint.py +13 -13
  11. tests/test_bedrock.py +9 -1
  12. tests/test_fallback.py +18 -7
  13. tests/test_gemini.py +14 -40
  14. tests/test_groq.py +43 -1
  15. tests/test_openai.py +160 -0
  16. tests/test_private_llm.py +19 -6
  17. tests/test_react_error_handling.py +293 -0
  18. tests/test_react_memory.py +257 -0
  19. tests/test_react_streaming.py +135 -0
  20. tests/test_react_workflow_events.py +395 -0
  21. tests/test_return_direct.py +1 -0
  22. tests/test_serialization.py +58 -20
  23. tests/test_session_memory.py +11 -11
  24. tests/test_streaming.py +0 -44
  25. tests/test_together.py +75 -1
  26. tests/test_tools.py +3 -1
  27. tests/test_vectara_llms.py +2 -2
  28. tests/test_vhc.py +7 -2
  29. tests/test_workflow.py +17 -11
  30. vectara_agentic/_callback.py +79 -21
  31. vectara_agentic/_version.py +1 -1
  32. vectara_agentic/agent.py +65 -27
  33. vectara_agentic/agent_core/serialization.py +5 -9
  34. vectara_agentic/agent_core/streaming.py +245 -64
  35. vectara_agentic/agent_core/utils/schemas.py +2 -2
  36. vectara_agentic/llm_utils.py +64 -15
  37. vectara_agentic/tools.py +88 -31
  38. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/METADATA +133 -36
  39. vectara_agentic-0.4.4.dist-info/RECORD +59 -0
  40. vectara_agentic-0.4.2.dist-info/RECORD +0 -54
  41. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/WHEEL +0 -0
  42. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/licenses/LICENSE +0 -0
  43. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/top_level.txt +0 -0
tests/conftest.py CHANGED
@@ -101,9 +101,9 @@ react_config_anthropic = AgentConfig(
  react_config_gemini = AgentConfig(
  agent_type=AgentType.REACT,
  main_llm_provider=ModelProvider.GEMINI,
- main_llm_model_name="models/gemini-2.5-flash",
+ main_llm_model_name="models/gemini-2.5-flash-lite",
  tool_llm_provider=ModelProvider.GEMINI,
- tool_llm_model_name="models/gemini-2.5-flash",
+ tool_llm_model_name="models/gemini-2.5-flash-lite",
  )

  react_config_together = AgentConfig(
@@ -112,12 +112,19 @@ react_config_together = AgentConfig(
  tool_llm_provider=ModelProvider.TOGETHER,
  )

+ react_config_openai = AgentConfig(
+ agent_type=AgentType.REACT,
+ main_llm_provider=ModelProvider.OPENAI,
+ tool_llm_provider=ModelProvider.OPENAI,
+ )
+
  react_config_groq = AgentConfig(
  agent_type=AgentType.REACT,
  main_llm_provider=ModelProvider.GROQ,
  tool_llm_provider=ModelProvider.GROQ,
  )

+
  # Private LLM configurations
  private_llm_react_config = AgentConfig(
  agent_type=AgentType.REACT,
@@ -161,14 +168,6 @@ def is_rate_limited(response_text: str) -> bool:
  "rate limit",
  "quota exceeded",
  "usage limit",
- # GROQ-specific
- "tokens per day",
- "TPD",
- "service tier",
- "on_demand",
- "deepseek-r1-distill-llama-70b",
- "Upgrade to Dev Tier",
- "console.groq.com/settings/billing",
  # OpenAI-specific
  "requests per minute",
  "RPM",
@@ -188,6 +187,9 @@ def is_rate_limited(response_text: str) -> bool:
  # Additional rate limit patterns
  "Limit.*Used.*Requested",
  "Need more tokens",
+ # Provider failure patterns
+ "failure can't be resolved after",
+ "Got empty message",
  ]

  response_lower = response_text.lower()
@@ -279,10 +281,10 @@ class AgentTestMixin:
  provider: Provider name for error messages

  Usage:
- with self.with_provider_fallback("GROQ"):
+ with self.with_provider_fallback("OpenAI"):
  response = agent.chat("test")

- with self.with_provider_fallback("GROQ"):
+ with self.with_provider_fallback("OpenAI"):
  async for chunk in agent.astream_chat("test"):
  pass

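For orientation, a minimal sketch of how these conftest configs are consumed by the tests that follow. The constructor arguments mirror the Agent calls visible in test_agent.py and test_fallback.py below; the mult stand-in and the topic/instruction strings are placeholders, not the suite's actual fixtures:

    from vectara_agentic.agent import Agent
    from vectara_agentic.agent_config import AgentConfig
    from vectara_agentic.tools import ToolsFactory
    from vectara_agentic.types import AgentType, ModelProvider

    # Same shape as the react_config_openai added above.
    config = AgentConfig(
        agent_type=AgentType.REACT,
        main_llm_provider=ModelProvider.OPENAI,
        tool_llm_provider=ModelProvider.OPENAI,
    )

    def mult(x: float, y: float) -> float:
        """Multiply two numbers (stand-in for the conftest `mult` fixture)."""
        return x * y

    agent = Agent(
        tools=[ToolsFactory().create_tool(mult)],
        topic="arithmetic",                       # placeholder topic
        custom_instructions="Answer concisely.",  # placeholder instructions
        agent_config=config,
    )
    print(agent.chat("What is 5 times 10?").response)  # requires OPENAI_API_KEY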
tests/endpoint.py CHANGED
@@ -10,12 +10,13 @@ app = Flask(__name__)
  # Configure logging
  logging.basicConfig(level=logging.INFO)
  app.logger.setLevel(logging.INFO)
- werkzeug_log = logging.getLogger('werkzeug')
+ werkzeug_log = logging.getLogger("werkzeug")
  werkzeug_log.setLevel(logging.ERROR)

  # Load expected API key from environment (fallback for testing)
  EXPECTED_API_KEY = "TEST_API_KEY"

+
  # Authentication decorator
  def require_api_key(f):
  @wraps(f)
@@ -27,12 +28,15 @@ def require_api_key(f):
  if api_key != EXPECTED_API_KEY:
  return jsonify({"error": "Unauthorized"}), 401
  return f(*args, **kwargs)
+
  return decorated_function

+
  @app.before_request
  def log_request_info():
  app.logger.info("%s %s", request.method, request.path)

+
  @app.route("/v1/chat/completions", methods=["POST"])
  @require_api_key
  def chat_completions():
@@ -46,7 +50,7 @@ def chat_completions():
  return jsonify({"error": "Invalid JSON payload"}), 400

  client = OpenAI()
- is_stream = data.get('stream', False)
+ is_stream = data.get("stream", False)

  if is_stream:
  # Stream each chunk to the client as Server-Sent Events
@@ -62,9 +66,9 @@ def chat_completions():
  yield f"data: {error_msg}\n\n"

  headers = {
- 'Content-Type': 'text/event-stream',
- 'Cache-Control': 'no-cache',
- 'Connection': 'keep-alive'
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
  }
  return Response(generate(), headers=headers)

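The streaming branch above follows the usual Flask Server-Sent-Events pattern. A condensed, self-contained sketch of that pattern (placeholder chunks stand in for real OpenAI output):

    from flask import Flask, Response, jsonify, request

    app = Flask(__name__)

    @app.route("/v1/chat/completions", methods=["POST"])
    def chat_completions():
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({"error": "Invalid JSON payload"}), 400

        def generate():
            # Frame each chunk as an SSE "data:" event, as in the hunk above.
            for chunk in ["Hello", " world"]:  # placeholder chunks
                yield f"data: {chunk}\n\n"

        headers = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        return Response(generate(), headers=headers)

    if __name__ == "__main__":
        app.run()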
tests/run_tests.py CHANGED
@@ -66,6 +66,7 @@ def main():

  # Add tests directory to Python path for relative imports
  import os
+
  sys.path.insert(0, os.path.abspath("tests"))

  # Discover and run tests
tests/test_agent.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import unittest
@@ -18,6 +19,7 @@ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS

  ARIZE_LOCK = threading.Lock()

+
  class TestAgentPackage(unittest.TestCase):
  def setUp(self):
  self.agents_to_cleanup = []
@@ -27,7 +29,7 @@ class TestAgentPackage(unittest.TestCase):
  import asyncio

  for agent in self.agents_to_cleanup:
- if hasattr(agent, 'cleanup'):
+ if hasattr(agent, "cleanup"):
  agent.cleanup()

  # Force garbage collection to clean up any remaining references
@@ -53,7 +55,10 @@ class TestAgentPackage(unittest.TestCase):
  + " with Always do as your mother tells you!"
  )
  self.assertEqual(
- format_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
+ format_prompt(
+ prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions
+ ),
+ expected_output,
  )

  def test_agent_init(self):
@@ -81,22 +86,26 @@ class TestAgentPackage(unittest.TestCase):
  main_llm_model_name="claude-sonnet-4-20250514",
  tool_llm_provider=ModelProvider.TOGETHER,
  tool_llm_model_name="moonshotai/Kimi-K2-Instruct",
- observer=ObserverType.ARIZE_PHOENIX
+ observer=ObserverType.ARIZE_PHOENIX,
  )

  agent = Agent(
  tools=tools,
  topic=STANDARD_TEST_TOPIC,
  custom_instructions=STANDARD_TEST_INSTRUCTIONS,
- agent_config=config
+ agent_config=config,
  )
  self.agents_to_cleanup.append(agent)
  self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
  self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)
  self.assertEqual(agent.agent_type, AgentType.REACT)
  self.assertEqual(agent.agent_config.observer, ObserverType.ARIZE_PHOENIX)
- self.assertEqual(agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC)
- self.assertEqual(agent.agent_config.tool_llm_provider, ModelProvider.TOGETHER)
+ self.assertEqual(
+ agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC
+ )
+ self.assertEqual(
+ agent.agent_config.tool_llm_provider, ModelProvider.TOGETHER
+ )

  # To run this test, you must have ANTHROPIC_API_KEY and TOGETHER_API_KEY in your environment
  self.assertEqual(
@@ -120,7 +129,9 @@ class TestAgentPackage(unittest.TestCase):

  agent.chat("What is 5 times 10. Only give the answer, nothing else")
  agent.chat("what is 3 times 7. Only give the answer, nothing else")
- res = agent.chat("multiply the results of the last two questions. Output only the answer.")
+ res = agent.chat(
+ "multiply the results of the last two questions. Output only the answer."
+ )
  self.assertEqual(res.response, "1050")

  def test_from_corpus(self):
@@ -144,7 +155,7 @@ class TestAgentPackage(unittest.TestCase):
  tools=tools,
  topic=topic,
  custom_instructions=instructions,
- chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")]
+ chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")],
  )
  self.agents_to_cleanup.append(agent)

@@ -152,7 +163,9 @@ class TestAgentPackage(unittest.TestCase):
  clone = Agent.loads(data)
  assert clone.memory.get() == agent.memory.get()

- res = agent.chat("multiply the results of the last two questions. Output only the answer.")
+ res = agent.chat(
+ "multiply the results of the last two questions. Output only the answer."
+ )
  self.assertEqual(res.response, "1050")

  def test_custom_general_instruction(self):
tests/test_agent_fallback_memory.py CHANGED
@@ -70,7 +70,7 @@ class TestAgentFallbackMemoryConsistency(unittest.TestCase):

  # Verify session_id consistency
  # Memory is managed by the main Agent class
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  def test_memory_sync_during_agent_switching(self):
  """Test that memory remains consistent when switching between main and fallback agents"""
@@ -219,13 +219,13 @@ class TestAgentFallbackMemoryConsistency(unittest.TestCase):

  # Verify main agent session_id consistency
  self.assertEqual(agent.session_id, self.session_id)
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  # Verify session_id consistency across all agents
  # Memory is managed by the main Agent class
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)
  self.assertEqual(
- agent.memory.chat_store_key, self.session_id
+ agent.memory.session_id, self.session_id
  )  # Both access same memory

  def test_agent_recreation_on_switch(self):
tests/test_agent_memory_consistency.py CHANGED
@@ -172,21 +172,21 @@ class TestAgentMemoryConsistency(unittest.TestCase):

  # Verify initial session_id
  self.assertEqual(agent.session_id, self.session_id)
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  # Switch configurations multiple times
  agent._switch_agent_config()
  self.assertEqual(agent.session_id, self.session_id)
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  agent._switch_agent_config()
  self.assertEqual(agent.session_id, self.session_id)
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  # Clear memory
  agent.clear_memory()
  self.assertEqual(agent.session_id, self.session_id)
- self.assertEqual(agent.memory.chat_store_key, self.session_id)
+ self.assertEqual(agent.memory.session_id, self.session_id)

  def test_serialization_preserves_consistency(self):
  """Test that serialization/deserialization preserves memory consistency behavior"""
tests/test_agent_type.py CHANGED
@@ -7,6 +7,7 @@ import unittest

  import sys
  import os
+
  sys.path.insert(0, os.path.dirname(__file__))

  from conftest import (
@@ -25,6 +26,7 @@ from conftest import (
  from vectara_agentic.agent import Agent
  from vectara_agentic.tools import ToolsFactory

+
  class TestAgentType(unittest.TestCase, AgentTestMixin):

  def setUp(self):
tests/test_api_endpoint.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import unittest
@@ -21,6 +22,7 @@ class DummyAgent(Agent):
  def chat(self, message: str) -> str:
  return f"Echo: {message}"

+
  class APITestCase(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
@@ -42,7 +44,9 @@ class APITestCase(unittest.TestCase):
  self.assertIn("No message provided", r.json()["detail"])

  def test_chat_unauthorized(self):
- r = self.client.get("/chat", params={"message": "hello"}, headers={"X-API-Key": "bad"})
+ r = self.client.get(
+ "/chat", params={"message": "hello"}, headers={"X-API-Key": "bad"}
+ )
  self.assertEqual(r.status_code, 403)

  def test_completions_success(self):
@@ -69,14 +73,13 @@ class APITestCase(unittest.TestCase):

  def test_completions_unauthorized(self):
  payload = {"model": "m1", "prompt": "hi"}
- r = self.client.post("/v1/completions", json=payload, headers={"X-API-Key": "bad"})
+ r = self.client.post(
+ "/v1/completions", json=payload, headers={"X-API-Key": "bad"}
+ )
  self.assertEqual(r.status_code, 403)

  def test_chat_completion_success(self):
- payload = {
- "model": "m1",
- "messages": [{"role": "user", "content": "hello"}]
- }
+ payload = {"model": "m1", "messages": [{"role": "user", "content": "hello"}]}
  r = self.client.post("/v1/chat", json=payload, headers=self.headers)
  self.assertEqual(r.status_code, 200)
  data = r.json()
@@ -99,8 +102,8 @@ class APITestCase(unittest.TestCase):
  {"role": "system", "content": "ignore me"},
  {"role": "user", "content": "foo"},
  {"role": "assistant", "content": "pong"},
- {"role": "user", "content": "bar"}
- ]
+ {"role": "user", "content": "bar"},
+ ],
  }
  r = self.client.post("/v1/chat", json=payload, headers=self.headers)
  self.assertEqual(r.status_code, 200)
@@ -108,7 +111,7 @@ class APITestCase(unittest.TestCase):

  # Should concatenate only user messages: "foo bar"
  self.assertEqual(data["choices"][0]["message"]["content"], "Echo: foo bar")
- self.assertEqual(data["usage"]["prompt_tokens"], 2) # "foo","bar"
+ self.assertEqual(data["usage"]["prompt_tokens"], 2)  # "foo","bar"
  self.assertEqual(data["usage"]["completion_tokens"], 3) # "Echo:","foo","bar"

  def test_chat_completion_no_messages(self):
@@ -118,10 +121,7 @@ class APITestCase(unittest.TestCase):
  self.assertIn("`messages` is required", r.json()["detail"])

  def test_chat_completion_unauthorized(self):
- payload = {
- "model": "m1",
- "messages": [{"role": "user", "content": "oops"}]
- }
+ payload = {"model": "m1", "messages": [{"role": "user", "content": "oops"}]}
  r = self.client.post("/v1/chat", json=payload, headers={"X-API-Key": "bad"})
  self.assertEqual(r.status_code, 403)

tests/test_bedrock.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import unittest
@@ -9,12 +10,19 @@ from vectara_agentic.agent import Agent
  from vectara_agentic.tools import ToolsFactory

  import nest_asyncio
+
  nest_asyncio.apply()

- from conftest import mult, fc_config_bedrock, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
+ from conftest import (
+ mult,
+ fc_config_bedrock,
+ STANDARD_TEST_TOPIC,
+ STANDARD_TEST_INSTRUCTIONS,
+ )

  ARIZE_LOCK = threading.Lock()

+
  class TestBedrock(unittest.IsolatedAsyncioTestCase):

  async def test_multiturn(self):
tests/test_fallback.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import os
@@ -16,19 +17,25 @@ from vectara_agentic.tools import ToolsFactory

  FLASK_PORT = 5002

+
  class TestFallback(unittest.TestCase):

  @classmethod
  def setUp(cls):
  # Start the Flask server as a subprocess
  cls.flask_process = subprocess.Popen(
- ['flask', 'run', f'--port={FLASK_PORT}'],
- env={**os.environ, 'FLASK_APP': 'tests.endpoint:app', 'FLASK_ENV': 'development'},
- stdout=None, stderr=None,
+ ["flask", "run", f"--port={FLASK_PORT}"],
+ env={
+ **os.environ,
+ "FLASK_APP": "tests.endpoint:app",
+ "FLASK_ENV": "development",
+ },
+ stdout=None,
+ stderr=None,
  )
  # Wait for the server to start
  timeout = 10
- url = f'http://127.0.0.1:{FLASK_PORT}/'
+ url = f"http://127.0.0.1:{FLASK_PORT}/"
  for _ in range(timeout):
  try:
  requests.get(url)
@@ -62,9 +69,13 @@ class TestFallback(unittest.TestCase):
  # Set fallback agent config to OpenAI agent
  fallback_config = AgentConfig()

- agent = Agent(agent_config=config, tools=tools, topic=topic,
- custom_instructions=custom_instructions,
- fallback_agent_config=fallback_config)
+ agent = Agent(
+ agent_config=config,
+ tools=tools,
+ topic=topic,
+ custom_instructions=custom_instructions,
+ fallback_agent_config=fallback_config,
+ )

  # To run this test, you must have OPENAI_API_KEY in your environment
  res = agent.chat(
tests/test_gemini.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import unittest
@@ -9,46 +10,15 @@ from vectara_agentic.tools import ToolsFactory


  import nest_asyncio
- nest_asyncio.apply()
-
- from conftest import mult, fc_config_gemini, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
-
- tickers = {
- "C": "Citigroup",
- "COF": "Capital One",
- "JPM": "JPMorgan Chase",
- "AAPL": "Apple Computer",
- "GOOG": "Google",
- "AMZN": "Amazon",
- "SNOW": "Snowflake",
- "TEAM": "Atlassian",
- "TSLA": "Tesla",
- "NVDA": "Nvidia",
- "MSFT": "Microsoft",
- "AMD": "Advanced Micro Devices",
- "INTC": "Intel",
- "NFLX": "Netflix",
- "STT": "State Street",
- "BK": "Bank of New York Mellon",
- }
- years = list(range(2015, 2025))
-
-
- def get_company_info() -> list[str]:
- """
- Returns a dictionary of companies you can query about. Always check this before using any other tool.
- The output is a dictionary of valid ticker symbols mapped to company names.
- You can use this to identify the companies you can query about, and their ticker information.
- """
- return tickers

+ nest_asyncio.apply()

- def get_valid_years() -> list[str]:
- """
- Returns a list of the years for which financial reports are available.
- Always check this before using any other tool.
- """
- return years
+ from conftest import (
+ mult,
+ fc_config_gemini,
+ STANDARD_TEST_TOPIC,
+ STANDARD_TEST_INSTRUCTIONS,
+ )


  class TestGEMINI(unittest.TestCase):
@@ -63,7 +33,9 @@ class TestGEMINI(unittest.TestCase):
  )
  _ = agent.chat("What is 5 times 10. Only give the answer, nothing else")
  _ = agent.chat("what is 3 times 7. Only give the answer, nothing else")
- res = agent.chat("what is the result of multiplying the results of the last two multiplications. Only give the answer, nothing else.")
+ res = agent.chat(
+ "what is the result of multiplying the results of the last two multiplications. Only give the answer, nothing else."
+ )
  self.assertIn("1050", res.response)

  def test_gemini_single_prompt(self):
@@ -75,7 +47,9 @@ class TestGEMINI(unittest.TestCase):
  topic=STANDARD_TEST_TOPIC,
  custom_instructions=STANDARD_TEST_INSTRUCTIONS,
  )
- res = agent.chat("First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations.")
+ res = agent.chat(
+ "First, multiply 5 by 10. Then, multiply 3 by 7. Finally, multiply the results of the first two calculations."
+ )
  self.assertIn("1050", res.response)

tests/test_groq.py CHANGED
@@ -1,5 +1,6 @@
  # Suppress external dependency warnings before any other imports
  import warnings
+
  warnings.simplefilter("ignore", DeprecationWarning)

  import unittest
@@ -7,14 +8,23 @@ import threading

  from vectara_agentic.agent import Agent
  from vectara_agentic.tools import ToolsFactory
+ from vectara_agentic.agent_config import AgentConfig
+ from vectara_agentic.types import AgentType, ModelProvider

  import nest_asyncio
+
  nest_asyncio.apply()

- from conftest import mult, fc_config_groq, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
+ from conftest import (
+ mult,
+ fc_config_groq,
+ STANDARD_TEST_TOPIC,
+ STANDARD_TEST_INSTRUCTIONS,
+ )

  ARIZE_LOCK = threading.Lock()

+
  class TestGROQ(unittest.IsolatedAsyncioTestCase):

  async def test_multiturn(self):
@@ -56,6 +66,38 @@ class TestGROQ(unittest.IsolatedAsyncioTestCase):

  self.assertEqual(response3.response, "1050")

+ async def test_gpt_oss_120b(self):
+ """Test GPT-OSS-120B model with GROQ provider."""
+ with ARIZE_LOCK:
+ # Create config specifically for GPT-OSS-120B via GROQ
+ gpt_oss_config = AgentConfig(
+ agent_type=AgentType.FUNCTION_CALLING,
+ main_llm_provider=ModelProvider.GROQ,
+ main_llm_model_name="openai/gpt-oss-120b",
+ tool_llm_provider=ModelProvider.GROQ,
+ tool_llm_model_name="openai/gpt-oss-120b",
+ )
+
+ tools = [ToolsFactory().create_tool(mult)]
+ agent = Agent(
+ agent_config=gpt_oss_config,
+ tools=tools,
+ topic=STANDARD_TEST_TOPIC,
+ custom_instructions=STANDARD_TEST_INSTRUCTIONS,
+ )
+
+ # Test simple multiplication: 8 * 6 = 48
+ stream = await agent.astream_chat(
+ "What is 8 times 6? Only give the answer, nothing else"
+ )
+ # Consume the stream
+ async for chunk in stream.async_response_gen():
+ pass
+ response = await stream.aget_response()
+
+ # Verify the response contains the correct answer
+ self.assertIn("48", response.response)
+

  if __name__ == "__main__":
  unittest.main()
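The new test_gpt_oss_120b case also illustrates the streaming consumption pattern: drain the async generator, then fetch the consolidated response. A standalone sketch of that pattern using the same config values as the hunk above (the mult tool and the topic/instruction strings are placeholders for the conftest fixtures; requires GROQ_API_KEY):

    import asyncio

    from vectara_agentic.agent import Agent
    from vectara_agentic.agent_config import AgentConfig
    from vectara_agentic.tools import ToolsFactory
    from vectara_agentic.types import AgentType, ModelProvider

    def mult(x: float, y: float) -> float:
        """Multiply two numbers (stand-in for the conftest `mult` tool)."""
        return x * y

    async def main():
        config = AgentConfig(
            agent_type=AgentType.FUNCTION_CALLING,
            main_llm_provider=ModelProvider.GROQ,
            main_llm_model_name="openai/gpt-oss-120b",
            tool_llm_provider=ModelProvider.GROQ,
            tool_llm_model_name="openai/gpt-oss-120b",
        )
        agent = Agent(
            agent_config=config,
            tools=[ToolsFactory().create_tool(mult)],
            topic="arithmetic",                                   # placeholder topic
            custom_instructions="Answer only with the number.",   # placeholder instructions
        )
        stream = await agent.astream_chat("What is 8 times 6?")
        async for _chunk in stream.async_response_gen():  # drain the token stream
            pass
        response = await stream.aget_response()           # final consolidated answer
        print(response.response)

    asyncio.run(main())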