vectara-agentic 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of vectara-agentic might be problematic (see the registry page for more details).

Files changed (53)
  1. tests/__init__.py +7 -0
  2. tests/conftest.py +312 -0
  3. tests/endpoint.py +54 -17
  4. tests/run_tests.py +111 -0
  5. tests/test_agent.py +10 -5
  6. tests/test_agent_type.py +82 -143
  7. tests/test_api_endpoint.py +4 -0
  8. tests/test_bedrock.py +4 -0
  9. tests/test_fallback.py +4 -0
  10. tests/test_gemini.py +28 -45
  11. tests/test_groq.py +4 -0
  12. tests/test_private_llm.py +11 -2
  13. tests/test_return_direct.py +6 -2
  14. tests/test_serialization.py +4 -0
  15. tests/test_streaming.py +88 -0
  16. tests/test_tools.py +10 -82
  17. tests/test_vectara_llms.py +4 -0
  18. tests/test_vhc.py +66 -0
  19. tests/test_workflow.py +4 -0
  20. vectara_agentic/__init__.py +27 -4
  21. vectara_agentic/_callback.py +65 -67
  22. vectara_agentic/_observability.py +30 -30
  23. vectara_agentic/_version.py +1 -1
  24. vectara_agentic/agent.py +375 -848
  25. vectara_agentic/agent_config.py +15 -14
  26. vectara_agentic/agent_core/__init__.py +22 -0
  27. vectara_agentic/agent_core/factory.py +501 -0
  28. vectara_agentic/{_prompts.py → agent_core/prompts.py} +3 -35
  29. vectara_agentic/agent_core/serialization.py +345 -0
  30. vectara_agentic/agent_core/streaming.py +495 -0
  31. vectara_agentic/agent_core/utils/__init__.py +34 -0
  32. vectara_agentic/agent_core/utils/hallucination.py +202 -0
  33. vectara_agentic/agent_core/utils/logging.py +52 -0
  34. vectara_agentic/agent_core/utils/prompt_formatting.py +56 -0
  35. vectara_agentic/agent_core/utils/schemas.py +87 -0
  36. vectara_agentic/agent_core/utils/tools.py +125 -0
  37. vectara_agentic/agent_endpoint.py +4 -6
  38. vectara_agentic/db_tools.py +37 -12
  39. vectara_agentic/llm_utils.py +41 -42
  40. vectara_agentic/sub_query_workflow.py +9 -14
  41. vectara_agentic/tool_utils.py +138 -83
  42. vectara_agentic/tools.py +43 -21
  43. vectara_agentic/tools_catalog.py +16 -16
  44. vectara_agentic/types.py +98 -6
  45. {vectara_agentic-0.3.2.dist-info → vectara_agentic-0.4.0.dist-info}/METADATA +69 -30
  46. vectara_agentic-0.4.0.dist-info/RECORD +50 -0
  47. tests/test_agent_planning.py +0 -64
  48. tests/test_hhem.py +0 -100
  49. vectara_agentic/hhem.py +0 -82
  50. vectara_agentic-0.3.2.dist-info/RECORD +0 -39
  51. {vectara_agentic-0.3.2.dist-info → vectara_agentic-0.4.0.dist-info}/WHEEL +0 -0
  52. {vectara_agentic-0.3.2.dist-info → vectara_agentic-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {vectara_agentic-0.3.2.dist-info → vectara_agentic-0.4.0.dist-info}/top_level.txt +0 -0
tests/__init__.py CHANGED
@@ -0,0 +1,7 @@
+"""
+Tests package for vectara_agentic.
+"""
+
+# Suppress external dependency warnings globally for all tests
+import warnings
+warnings.simplefilter("ignore", DeprecationWarning)
tests/conftest.py ADDED
@@ -0,0 +1,312 @@
+# Suppress external dependency warnings before any other imports
+import warnings
+
+warnings.simplefilter("ignore", DeprecationWarning)
+
+"""
+Common test utilities, configurations, and fixtures for the vectara-agentic test suite.
+"""
+
+import unittest
+from contextlib import contextmanager
+from typing import Any
+
+from vectara_agentic.agent_config import AgentConfig
+from vectara_agentic.types import AgentType, ModelProvider
+
+
+# ========================================
+# Common Test Functions
+# ========================================
+
+def mult(x: float, y: float) -> float:
+    """Multiply two numbers - common test function used across multiple test files."""
+    return x * y
+
+
+def add(x: float, y: float) -> float:
+    """Add two numbers - common test function used in workflow tests."""
+    return x + y
+
+
+# ========================================
+# Common Test Data
+# ========================================
+
+# Standard test topic used across most tests
+STANDARD_TEST_TOPIC = "AI topic"
+
+# Standard test instructions used across most tests
+STANDARD_TEST_INSTRUCTIONS = "Always do as your father tells you, if your mother agrees!"
+
+# Alternative instructions for specific tests
+WORKFLOW_TEST_INSTRUCTIONS = "You are a helpful AI assistant."
+MATH_AGENT_INSTRUCTIONS = "you are an agent specializing in math, assisting a user."
+
+
+# ========================================
+# Agent Configuration Objects
+# ========================================
+
+# Default configurations
+default_config = AgentConfig()
+
+# Function Calling configurations for all providers
+fc_config_anthropic = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.ANTHROPIC,
+    tool_llm_provider=ModelProvider.ANTHROPIC,
+)
+
+fc_config_gemini = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.GEMINI,
+    tool_llm_provider=ModelProvider.GEMINI,
+)
+
+fc_config_together = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.TOGETHER,
+    tool_llm_provider=ModelProvider.TOGETHER,
+)
+
+fc_config_openai = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.OPENAI,
+    tool_llm_provider=ModelProvider.OPENAI,
+)
+
+fc_config_groq = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.GROQ,
+    tool_llm_provider=ModelProvider.GROQ,
+)
+
+fc_config_bedrock = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.BEDROCK,
+    tool_llm_provider=ModelProvider.BEDROCK,
+)
+
+# ReAct configurations for all providers
+react_config_anthropic = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.ANTHROPIC,
+    tool_llm_provider=ModelProvider.ANTHROPIC,
+)
+
+react_config_gemini = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.GEMINI,
+    main_llm_model_name="models/gemini-2.5-flash",
+    tool_llm_provider=ModelProvider.GEMINI,
+    tool_llm_model_name="models/gemini-2.5-flash",
+)
+
+react_config_together = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.TOGETHER,
+    tool_llm_provider=ModelProvider.TOGETHER,
+)
+
+react_config_groq = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.GROQ,
+    tool_llm_provider=ModelProvider.GROQ,
+)
+
+# Private LLM configurations
+private_llm_react_config = AgentConfig(
+    agent_type=AgentType.REACT,
+    main_llm_provider=ModelProvider.PRIVATE,
+    main_llm_model_name="gpt-4o",
+    private_llm_api_base="http://localhost:8000/v1",
+    tool_llm_provider=ModelProvider.PRIVATE,
+    tool_llm_model_name="gpt-4o",
+)
+
+private_llm_fc_config = AgentConfig(
+    agent_type=AgentType.FUNCTION_CALLING,
+    main_llm_provider=ModelProvider.PRIVATE,
+    main_llm_model_name="gpt-4.1",
+    private_llm_api_base="http://localhost:8000/v1",
+    tool_llm_provider=ModelProvider.PRIVATE,
+    tool_llm_model_name="gpt-4.1",
+)
+
+
+# ========================================
+# Error Detection and Testing Utilities
+# ========================================
+
+def is_rate_limited(response_text: str) -> bool:
+    """
+    Check if a response indicates a rate limit error from any LLM provider.
+
+    Args:
+        response_text: The response text from the agent
+
+    Returns:
+        bool: True if the response indicates rate limiting
+    """
+    rate_limit_indicators = [
+        # Generic indicators
+        "Error code: 429",
+        "rate_limit_exceeded",
+        "Rate limit reached",
+        "rate limit",
+        "quota exceeded",
+        "usage limit",
+        # GROQ-specific
+        "tokens per day",
+        "TPD",
+        "service tier",
+        "on_demand",
+        "deepseek-r1-distill-llama-70b",
+        "Upgrade to Dev Tier",
+        "console.groq.com/settings/billing",
+        # OpenAI-specific
+        "requests per minute",
+        "RPM",
+        "tokens per minute",
+        "TPM",
+        # Anthropic-specific
+        "overloaded_error",
+        "Overloaded",
+        "APIStatusError",
+        "anthropic.APIStatusError",
+        "usage_limit_exceeded",
+        # General API limit indicators
+        "try again in",
+        "Please wait",
+        "Too many requests",
+        "throttled",
+        # Additional rate limit patterns
+        "Limit.*Used.*Requested",
+        "Need more tokens",
+    ]
+
+    response_lower = response_text.lower()
+    return any(
+        indicator.lower() in response_lower for indicator in rate_limit_indicators
+    )
+
+
+def is_api_key_error(response_text: str) -> bool:
+    """
+    Check if a response indicates an API key authentication error.
+
+    Args:
+        response_text: The response text from the agent
+
+    Returns:
+        bool: True if the response indicates API key issues
+    """
+    api_key_indicators = [
+        "Error code: 401",
+        "Invalid API Key",
+        "authentication",
+        "unauthorized",
+        "invalid_api_key",
+        "missing api key",
+        "api key not found",
+    ]
+
+    response_lower = response_text.lower()
+    return any(indicator.lower() in response_lower for indicator in api_key_indicators)
+
+
+def skip_if_rate_limited(
+    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+) -> None:
+    """
+    Skip a test if the response indicates rate limiting.
+
+    Args:
+        test_instance: The test case instance
+        response_text: The response text to check
+        provider: The name of the provider (for clearer skip messages)
+    """
+    if is_rate_limited(response_text):
+        test_instance.skipTest(f"{provider} rate limit reached - skipping test")
+
+
+def skip_if_api_key_error(
+    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+) -> None:
+    """
+    Skip a test if the response indicates API key issues.
+
+    Args:
+        test_instance: The test case instance
+        response_text: The response text to check
+        provider: The name of the provider (for clearer skip messages)
+    """
+    if is_api_key_error(response_text):
+        test_instance.skipTest(f"{provider} API key invalid/missing - skipping test")
+
+
+def skip_if_provider_error(
+    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+) -> None:
+    """
+    Skip a test if the response indicates common provider errors (rate limiting or API key issues).
+
+    Args:
+        test_instance: The test case instance
+        response_text: The response text to check
+        provider: The name of the provider (for clearer skip messages)
+    """
+    skip_if_rate_limited(test_instance, response_text, provider)
+    skip_if_api_key_error(test_instance, response_text, provider)
+
+
+class AgentTestMixin:
+    """
+    Mixin class providing utility methods for agent testing.
+    """
+
+    @contextmanager
+    def with_provider_fallback(self, provider: str = "LLM"):
+        """
+        Context manager that catches and handles provider errors from any agent method.
+
+        Args:
+            provider: Provider name for error messages
+
+        Usage:
+            with self.with_provider_fallback("GROQ"):
+                response = agent.chat("test")
+
+            with self.with_provider_fallback("GROQ"):
+                async for chunk in agent.astream_chat("test"):
+                    pass
+
+        Raises:
+            unittest.SkipTest: If rate limiting or API key errors occur
+        """
+        try:
+            yield
+        except Exception as e:
+            error_text = str(e)
+            if is_rate_limited(error_text) or is_api_key_error(error_text):
+                self.skipTest(f"{provider} error: {error_text}")
+            raise
+
+    def check_response_and_skip(self, response: Any, provider: str = "LLM") -> Any:
+        """
+        Check response content and skip test if provider errors are detected.
+
+        Args:
+            response: The response object from agent method
+            provider: Provider name for error messages
+
+        Returns:
+            The response object if no errors detected
+
+        Raises:
+            unittest.SkipTest: If rate limiting or API key errors detected in response
+        """
+        response_text = getattr(response, "response", str(response))
+        skip_if_provider_error(self, response_text, provider)
+        return response
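
A note on usage: the mixin and shared configs above are meant to be combined in ordinary unittest classes. The following is a minimal, hypothetical sketch (not part of the package); it follows the Agent and ToolsFactory call patterns visible in tests/test_agent.py below, and assumes Agent accepts an agent_config keyword.

    # Hypothetical usage sketch combining conftest.py fixtures in a test case.
    import unittest

    from conftest import AgentTestMixin, STANDARD_TEST_TOPIC, fc_config_openai, mult
    from vectara_agentic.agent import Agent
    from vectara_agentic.tools import ToolsFactory


    class TestWithSharedFixtures(AgentTestMixin, unittest.TestCase):
        def test_mult_tool(self):
            tools = [ToolsFactory().create_tool(mult)]
            # agent_config keyword assumed from the vectara-agentic API
            agent = Agent(tools, STANDARD_TEST_TOPIC, "Answer math questions.",
                          agent_config=fc_config_openai)
            # Provider errors (rate limits, bad keys) become skips, not failures
            with self.with_provider_fallback("OpenAI"):
                response = self.check_response_and_skip(
                    agent.chat("What is 5 times 10?"), "OpenAI"
                )
                self.assertIn("50", str(response))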
tests/endpoint.py CHANGED
@@ -1,47 +1,84 @@
-from openai import OpenAI
-from flask import Flask, request, jsonify
+#!/usr/bin/env python3
+import json
 import logging
 from functools import wraps
+from flask import Flask, request, Response, jsonify
+from openai import OpenAI
 
 app = Flask(__name__)
-app.config['TESTING'] = True
 
-log = logging.getLogger('werkzeug')
-log.setLevel(logging.ERROR)
-
-# Set your OpenAI API key (ensure you've set this in your environment)
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+app.logger.setLevel(logging.INFO)
+werkzeug_log = logging.getLogger('werkzeug')
+werkzeug_log.setLevel(logging.ERROR)
 
+# Load expected API key from environment (fallback for testing)
 EXPECTED_API_KEY = "TEST_API_KEY"
 
+# Authentication decorator
 def require_api_key(f):
     @wraps(f)
     def decorated_function(*args, **kwargs):
-        api_key = request.headers.get("Authorization").split("Bearer ")[-1]
-        if not api_key or api_key != EXPECTED_API_KEY:
+        auth_header = request.headers.get("Authorization", "")
+        if not auth_header.startswith("Bearer "):
+            return jsonify({"error": "Unauthorized"}), 401
+        api_key = auth_header.split(" ", 1)[1]
+        if api_key != EXPECTED_API_KEY:
             return jsonify({"error": "Unauthorized"}), 401
         return f(*args, **kwargs)
     return decorated_function
 
 @app.before_request
 def log_request_info():
-    app.logger.info("Request received: %s %s", request.method, request.path)
+    app.logger.info("%s %s", request.method, request.path)
 
 @app.route("/v1/chat/completions", methods=["POST"])
 @require_api_key
 def chat_completions():
-    app.logger.info("Received request on /v1/chat/completions")
-    data = request.get_json()
-    if not data:
+    """
+    Proxy endpoint for OpenAI Chat Completions.
+    Supports both streaming and non-streaming modes.
+    """
+    try:
+        data = request.get_json(force=True)
+    except Exception:
         return jsonify({"error": "Invalid JSON payload"}), 400
 
     client = OpenAI()
+    is_stream = data.get('stream', False)
+
+    if is_stream:
+        # Stream each chunk to the client as Server-Sent Events
+        def generate():
+            try:
+                for chunk in client.chat.completions.create(**data):
+                    # Convert chunk to dict and then JSON
+                    event = chunk.model_dump()
+                    yield f"data: {json.dumps(event)}\n\n"
+            except Exception as e:
+                # On error, send an SSE event with error info
+                error_msg = json.dumps({"error": str(e)})
+                yield f"data: {error_msg}\n\n"
+
+        headers = {
+            'Content-Type': 'text/event-stream',
+            'Cache-Control': 'no-cache',
+            'Connection': 'keep-alive'
+        }
+        return Response(generate(), headers=headers)
+
+    # Non-streaming path
     try:
         completion = client.chat.completions.create(**data)
-        return jsonify(completion.model_dump()), 200
+        result = completion.model_dump()
+        app.logger.info(f"Non-stream response: {result}")
+        return jsonify(result), 200
     except Exception as e:
-        return jsonify({"error": str(e)}), 400
+        app.logger.error(f"Error during completion: {e}")
+        return jsonify({"error": str(e)}), 500
 
 
 if __name__ == "__main__":
-    # Run on port 5000 by default; adjust as needed.
-    app.run(debug=True, port=5000, use_reloader=False)
+    # Bind to all interfaces on port 5000
+    app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
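
For reference, here is a minimal, hypothetical client for the test endpoint above (not part of the package). It assumes the server is running on localhost:5000 with a real OPENAI_API_KEY in its environment; the model name is illustrative.

    # Hypothetical client exercising both response modes of the test proxy.
    import json
    import requests

    URL = "http://localhost:5000/v1/chat/completions"
    HEADERS = {"Authorization": "Bearer TEST_API_KEY"}
    payload = {
        "model": "gpt-4o-mini",  # illustrative model name
        "messages": [{"role": "user", "content": "Say hello"}],
    }

    # Non-streaming: the proxy returns the whole completion as one JSON body
    print(requests.post(URL, headers=HEADERS, json=payload).json())

    # Streaming: the proxy relays each chunk as an SSE "data: {...}" event
    payload["stream"] = True
    with requests.post(URL, headers=HEADERS, json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if not line.startswith(b"data: "):
                continue
            event = json.loads(line[len(b"data: "):])
            if "choices" in event:  # error events carry an "error" key instead
                print(event["choices"][0]["delta"].get("content") or "", end="")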
tests/run_tests.py ADDED
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+"""
+Custom test runner that suppresses Pydantic deprecation warnings.
+Usage: python run_tests.py [test_pattern]
+"""
+
+import sys
+import warnings
+import unittest
+import argparse
+import asyncio
+import gc
+
+
+def suppress_pydantic_warnings():
+    """Comprehensive warning suppression before unittest starts."""
+    # Multiple layers of suppression
+    warnings.resetwarnings()
+    warnings.simplefilter("ignore", DeprecationWarning)
+
+    # Specific Pydantic patterns
+    pydantic_patterns = [
+        ".*PydanticDeprecatedSince.*",
+        ".*__fields__.*deprecated.*",
+        ".*__fields_set__.*deprecated.*",
+        ".*model_fields.*deprecated.*",
+        ".*model_computed_fields.*deprecated.*",
+        ".*use.*model_fields.*instead.*",
+        ".*use.*model_fields_set.*instead.*",
+        ".*Accessing.*model_.*attribute.*deprecated.*",
+    ]
+
+    # Resource warning patterns (reduce noise, not critical)
+    resource_patterns = [
+        ".*unclosed transport.*",
+        ".*unclosed <socket\\.socket.*",
+    ]
+
+    for pattern in pydantic_patterns:
+        warnings.filterwarnings("ignore", category=DeprecationWarning, message=pattern)
+
+    for pattern in resource_patterns:
+        warnings.filterwarnings("ignore", category=ResourceWarning, message=pattern)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run tests with warning suppression")
+    parser.add_argument(
+        "pattern",
+        nargs="?",
+        default="test_*.py",
+        help="Test file pattern (default: test_*.py)",
+    )
+    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+
+    args = parser.parse_args()
+
+    # Apply comprehensive warning suppression BEFORE unittest starts
+    suppress_pydantic_warnings()
+
+    print(f"🧪 Running tests with pattern: {args.pattern}")
+    print("🔇 Pydantic deprecation warnings suppressed")
+
+    # Add tests directory to Python path for relative imports
+    import os
+    sys.path.insert(0, os.path.abspath("tests"))
+
+    # Discover and run tests
+    loader = unittest.TestLoader()
+    start_dir = "tests"
+    suite = loader.discover(start_dir, pattern=args.pattern)
+
+    # Run tests
+    verbosity = 2 if args.verbose else 1
+    runner = unittest.TextTestRunner(verbosity=verbosity)
+    result = runner.run(suite)
+
+    # Cleanup to reduce resource warnings
+    try:
+        # Close any remaining event loops
+        loop = None
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            pass
+
+        if loop and not loop.is_closed():
+            # Cancel all pending tasks
+            pending = asyncio.all_tasks(loop)
+            for task in pending:
+                task.cancel()
+
+            # Give tasks a chance to complete cancellation
+            if pending:
+                loop.run_until_complete(
+                    asyncio.gather(*pending, return_exceptions=True)
+                )
+
+        # Force garbage collection
+        gc.collect()
+
+    except Exception:
+        # Don't let cleanup errors affect test results
+        pass
+
+    # Exit with proper code
+    sys.exit(0 if result.wasSuccessful() else 1)
+
+
+if __name__ == "__main__":
+    main()
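
Typical invocations, assuming the script is run from the repository root (test discovery hard-codes the tests directory via start_dir = "tests"):

    python tests/run_tests.py                      # run all test_*.py files
    python tests/run_tests.py "test_agent*.py" -v  # verbose, agent tests only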
tests/test_agent.py CHANGED
@@ -1,13 +1,18 @@
+# Suppress external dependency warnings before any other imports
+import warnings
+warnings.simplefilter("ignore", DeprecationWarning)
+
 import unittest
 import threading
 from datetime import date
 
-from vectara_agentic.agent import _get_prompt, Agent, AgentType
+from vectara_agentic.agent import Agent, AgentType
+from vectara_agentic.agent_core.utils.prompt_formatting import format_prompt
 from vectara_agentic.agent_config import AgentConfig
 from vectara_agentic.types import ModelProvider, ObserverType
 from vectara_agentic.tools import ToolsFactory
 
-from vectara_agentic._prompts import GENERAL_INSTRUCTIONS
+from vectara_agentic.agent_core.prompts import GENERAL_INSTRUCTIONS
 
 
 def mult(x: float, y: float) -> float:
@@ -28,7 +33,7 @@ class TestAgentPackage(unittest.TestCase):
             + " with Always do as your mother tells you!"
         )
         self.assertEqual(
-            _get_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
+            format_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
         )
 
     def test_agent_init(self):
@@ -36,11 +41,11 @@
         topic = "AI"
         custom_instructions = "Always do as your mother tells you!"
         agent = Agent(tools, topic, custom_instructions)
-        self.assertEqual(agent.agent_type, AgentType.OPENAI)
+        self.assertEqual(agent.agent_type, AgentType.FUNCTION_CALLING)
         self.assertEqual(agent._topic, topic)
         self.assertEqual(agent._custom_instructions, custom_instructions)
 
-        # To run this test, you must have OPENAI_API_KEY in your environment
+        # To run this test, you must have appropriate API key in your environment
         self.assertEqual(
             agent.chat(
                 "What is 5 times 10. Only give the answer, nothing else"