vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


Files changed (56)
  1. tests/__init__.py +7 -0
  2. tests/conftest.py +316 -0
  3. tests/endpoint.py +54 -17
  4. tests/run_tests.py +112 -0
  5. tests/test_agent.py +35 -33
  6. tests/test_agent_fallback_memory.py +270 -0
  7. tests/test_agent_memory_consistency.py +229 -0
  8. tests/test_agent_type.py +86 -143
  9. tests/test_api_endpoint.py +4 -0
  10. tests/test_bedrock.py +50 -31
  11. tests/test_fallback.py +4 -0
  12. tests/test_gemini.py +27 -59
  13. tests/test_groq.py +50 -31
  14. tests/test_private_llm.py +11 -2
  15. tests/test_return_direct.py +6 -2
  16. tests/test_serialization.py +7 -6
  17. tests/test_session_memory.py +252 -0
  18. tests/test_streaming.py +109 -0
  19. tests/test_together.py +62 -0
  20. tests/test_tools.py +10 -82
  21. tests/test_vectara_llms.py +4 -0
  22. tests/test_vhc.py +67 -0
  23. tests/test_workflow.py +13 -28
  24. vectara_agentic/__init__.py +27 -4
  25. vectara_agentic/_callback.py +65 -67
  26. vectara_agentic/_observability.py +30 -30
  27. vectara_agentic/_version.py +1 -1
  28. vectara_agentic/agent.py +565 -859
  29. vectara_agentic/agent_config.py +15 -14
  30. vectara_agentic/agent_core/__init__.py +22 -0
  31. vectara_agentic/agent_core/factory.py +383 -0
  32. vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
  33. vectara_agentic/agent_core/serialization.py +348 -0
  34. vectara_agentic/agent_core/streaming.py +483 -0
  35. vectara_agentic/agent_core/utils/__init__.py +29 -0
  36. vectara_agentic/agent_core/utils/hallucination.py +157 -0
  37. vectara_agentic/agent_core/utils/logging.py +52 -0
  38. vectara_agentic/agent_core/utils/schemas.py +87 -0
  39. vectara_agentic/agent_core/utils/tools.py +125 -0
  40. vectara_agentic/agent_endpoint.py +4 -6
  41. vectara_agentic/db_tools.py +37 -12
  42. vectara_agentic/llm_utils.py +42 -43
  43. vectara_agentic/sub_query_workflow.py +9 -14
  44. vectara_agentic/tool_utils.py +138 -83
  45. vectara_agentic/tools.py +36 -21
  46. vectara_agentic/tools_catalog.py +16 -16
  47. vectara_agentic/types.py +106 -8
  48. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
  49. vectara_agentic-0.4.1.dist-info/RECORD +53 -0
  50. tests/test_agent_planning.py +0 -64
  51. tests/test_hhem.py +0 -100
  52. vectara_agentic/hhem.py +0 -82
  53. vectara_agentic-0.3.3.dist-info/RECORD +0 -39
  54. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
  55. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
  56. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
tests/__init__.py CHANGED
@@ -0,0 +1,7 @@
+ """
+ Tests package for vectara_agentic.
+ """
+
+ # Suppress external dependency warnings globally for all tests
+ import warnings
+ warnings.simplefilter("ignore", DeprecationWarning)
tests/conftest.py ADDED
@@ -0,0 +1,316 @@
+ # Suppress external dependency warnings before any other imports
+ import warnings
+
+ warnings.simplefilter("ignore", DeprecationWarning)
+
+ """
+ Common test utilities, configurations, and fixtures for the vectara-agentic test suite.
+ """
+
+ import unittest
+ from contextlib import contextmanager
+ from typing import Any
+
+ from vectara_agentic.agent_config import AgentConfig
+ from vectara_agentic.types import AgentType, ModelProvider
+
+
+ # ========================================
+ # Common Test Functions
+ # ========================================
+
+
+ def mult(x: float, y: float) -> float:
+     """Multiply two numbers - common test function used across multiple test files."""
+     return x * y
+
+
+ def add(x: float, y: float) -> float:
+     """Add two numbers - common test function used in workflow tests."""
+     return x + y
+
+
+ # ========================================
+ # Common Test Data
+ # ========================================
+
+ # Standard test topic used across most tests
+ STANDARD_TEST_TOPIC = "AI topic"
+
+ # Standard test instructions used across most tests
+ STANDARD_TEST_INSTRUCTIONS = (
+     "Always do as your father tells you, if your mother agrees!"
+ )
+
+ # Alternative instructions for specific tests
+ WORKFLOW_TEST_INSTRUCTIONS = "You are a helpful AI assistant."
+ MATH_AGENT_INSTRUCTIONS = "you are an agent specializing in math, assisting a user."
+
+
+ # ========================================
+ # Agent Configuration Objects
+ # ========================================
+
+ # Default configurations
+ default_config = AgentConfig()
+
+ # Function Calling configurations for all providers
+ fc_config_anthropic = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.ANTHROPIC,
+     tool_llm_provider=ModelProvider.ANTHROPIC,
+ )
+
+ fc_config_gemini = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.GEMINI,
+     tool_llm_provider=ModelProvider.GEMINI,
+ )
+
+ fc_config_together = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.TOGETHER,
+     tool_llm_provider=ModelProvider.TOGETHER,
+ )
+
+ fc_config_openai = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.OPENAI,
+     tool_llm_provider=ModelProvider.OPENAI,
+ )
+
+ fc_config_groq = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.GROQ,
+     tool_llm_provider=ModelProvider.GROQ,
+ )
+
+ fc_config_bedrock = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.BEDROCK,
+     tool_llm_provider=ModelProvider.BEDROCK,
+ )
+
+ # ReAct configurations for all providers
+ react_config_anthropic = AgentConfig(
+     agent_type=AgentType.REACT,
+     main_llm_provider=ModelProvider.ANTHROPIC,
+     tool_llm_provider=ModelProvider.ANTHROPIC,
+ )
+
+ react_config_gemini = AgentConfig(
+     agent_type=AgentType.REACT,
+     main_llm_provider=ModelProvider.GEMINI,
+     main_llm_model_name="models/gemini-2.5-flash",
+     tool_llm_provider=ModelProvider.GEMINI,
+     tool_llm_model_name="models/gemini-2.5-flash",
+ )
+
+ react_config_together = AgentConfig(
+     agent_type=AgentType.REACT,
+     main_llm_provider=ModelProvider.TOGETHER,
+     tool_llm_provider=ModelProvider.TOGETHER,
+ )
+
+ react_config_groq = AgentConfig(
+     agent_type=AgentType.REACT,
+     main_llm_provider=ModelProvider.GROQ,
+     tool_llm_provider=ModelProvider.GROQ,
+ )
+
+ # Private LLM configurations
+ private_llm_react_config = AgentConfig(
+     agent_type=AgentType.REACT,
+     main_llm_provider=ModelProvider.PRIVATE,
+     main_llm_model_name="gpt-4o",
+     private_llm_api_base="http://localhost:8000/v1",
+     tool_llm_provider=ModelProvider.PRIVATE,
+     tool_llm_model_name="gpt-4o",
+ )
+
+ private_llm_fc_config = AgentConfig(
+     agent_type=AgentType.FUNCTION_CALLING,
+     main_llm_provider=ModelProvider.PRIVATE,
+     main_llm_model_name="gpt-4.1",
+     private_llm_api_base="http://localhost:8000/v1",
+     tool_llm_provider=ModelProvider.PRIVATE,
+     tool_llm_model_name="gpt-4.1",
+ )
+
+
+ # ========================================
+ # Error Detection and Testing Utilities
+ # ========================================
+
+
+ def is_rate_limited(response_text: str) -> bool:
+     """
+     Check if a response indicates a rate limit error from any LLM provider.
+
+     Args:
+         response_text: The response text from the agent
+
+     Returns:
+         bool: True if the response indicates rate limiting
+     """
+     rate_limit_indicators = [
+         # Generic indicators
+         "Error code: 429",
+         "rate_limit_exceeded",
+         "Rate limit reached",
+         "rate limit",
+         "quota exceeded",
+         "usage limit",
+         # GROQ-specific
+         "tokens per day",
+         "TPD",
+         "service tier",
+         "on_demand",
+         "deepseek-r1-distill-llama-70b",
+         "Upgrade to Dev Tier",
+         "console.groq.com/settings/billing",
+         # OpenAI-specific
+         "requests per minute",
+         "RPM",
+         "tokens per minute",
+         "TPM",
+         # Anthropic-specific
+         "overloaded_error",
+         "Overloaded",
+         "APIStatusError",
+         "anthropic.APIStatusError",
+         "usage_limit_exceeded",
+         # General API limit indicators
+         "try again in",
+         "Please wait",
+         "Too many requests",
+         "throttled",
+         # Additional rate limit patterns
+         "Limit.*Used.*Requested",
+         "Need more tokens",
+     ]
+
+     response_lower = response_text.lower()
+     return any(
+         indicator.lower() in response_lower for indicator in rate_limit_indicators
+     )
+
+
+ def is_api_key_error(response_text: str) -> bool:
+     """
+     Check if a response indicates an API key authentication error.
+
+     Args:
+         response_text: The response text from the agent
+
+     Returns:
+         bool: True if the response indicates API key issues
+     """
+     api_key_indicators = [
+         "Error code: 401",
+         "Invalid API Key",
+         "authentication",
+         "unauthorized",
+         "invalid_api_key",
+         "missing api key",
+         "api key not found",
+     ]
+
+     response_lower = response_text.lower()
+     return any(indicator.lower() in response_lower for indicator in api_key_indicators)
+
+
+ def skip_if_rate_limited(
+     test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+ ) -> None:
+     """
+     Skip a test if the response indicates rate limiting.
+
+     Args:
+         test_instance: The test case instance
+         response_text: The response text to check
+         provider: The name of the provider (for clearer skip messages)
+     """
+     if is_rate_limited(response_text):
+         test_instance.skipTest(f"{provider} rate limit reached - skipping test")
+
+
+ def skip_if_api_key_error(
+     test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+ ) -> None:
+     """
+     Skip a test if the response indicates API key issues.
+
+     Args:
+         test_instance: The test case instance
+         response_text: The response text to check
+         provider: The name of the provider (for clearer skip messages)
+     """
+     if is_api_key_error(response_text):
+         test_instance.skipTest(f"{provider} API key invalid/missing - skipping test")
+
+
+ def skip_if_provider_error(
+     test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
+ ) -> None:
+     """
+     Skip a test if the response indicates common provider errors (rate limiting or API key issues).
+
+     Args:
+         test_instance: The test case instance
+         response_text: The response text to check
+         provider: The name of the provider (for clearer skip messages)
+     """
+     skip_if_rate_limited(test_instance, response_text, provider)
+     skip_if_api_key_error(test_instance, response_text, provider)
+
+
+ class AgentTestMixin:
+     """
+     Mixin class providing utility methods for agent testing.
+     """
+
+     @contextmanager
+     def with_provider_fallback(self, provider: str = "LLM"):
+         """
+         Context manager that catches and handles provider errors from any agent method.
+
+         Args:
+             provider: Provider name for error messages
+
+         Usage:
+             with self.with_provider_fallback("GROQ"):
+                 response = agent.chat("test")
+
+             with self.with_provider_fallback("GROQ"):
+                 async for chunk in agent.astream_chat("test"):
+                     pass
+
+         Raises:
+             unittest.SkipTest: If rate limiting or API key errors occur
+         """
+         try:
+             yield
+         except Exception as e:
+             error_text = str(e)
+             if is_rate_limited(error_text) or is_api_key_error(error_text):
+                 self.skipTest(f"{provider} error: {error_text}")
+             raise
+
+     def check_response_and_skip(self, response: Any, provider: str = "LLM") -> Any:
+         """
+         Check response content and skip test if provider errors are detected.
+
+         Args:
+             response: The response object from agent method
+             provider: Provider name for error messages
+
+         Returns:
+             The response object if no errors detected
+
+         Raises:
+             unittest.SkipTest: If rate limiting or API key errors detected in response
+         """
+         response_text = getattr(response, "response", str(response))
+         skip_if_provider_error(self, response_text, provider)
+         return response
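Taken together, these helpers let tests degrade to skips rather than failures when a provider rate-limits or rejects a key. A minimal sketch of how a test might combine them (the test class name and prompt below are illustrative, not taken from the package's own suite):

import unittest

from vectara_agentic.agent import Agent
from vectara_agentic.tools import ToolsFactory
from conftest import (
    AgentTestMixin,
    fc_config_openai,
    mult,
    STANDARD_TEST_TOPIC,
    STANDARD_TEST_INSTRUCTIONS,
)


class ExampleProviderTest(unittest.TestCase, AgentTestMixin):
    def test_mult_tool(self):
        agent = Agent(
            tools=[ToolsFactory().create_tool(mult)],
            topic=STANDARD_TEST_TOPIC,
            custom_instructions=STANDARD_TEST_INSTRUCTIONS,
            agent_config=fc_config_openai,
        )
        # Exceptions that look like rate limits or bad keys become skips
        with self.with_provider_fallback("OpenAI"):
            response = agent.chat("What is 5 times 10. Only give the answer, nothing else")
        # Error text embedded in an otherwise successful response also triggers a skip
        self.check_response_and_skip(response, "OpenAI")
        self.assertEqual(response.response, "50")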
tests/endpoint.py CHANGED
@@ -1,47 +1,84 @@
- from openai import OpenAI
- from flask import Flask, request, jsonify
+ #!/usr/bin/env python3
+ import json
  import logging
  from functools import wraps
+ from flask import Flask, request, Response, jsonify
+ from openai import OpenAI

  app = Flask(__name__)
- app.config['TESTING'] = True

- log = logging.getLogger('werkzeug')
- log.setLevel(logging.ERROR)
-
- # Set your OpenAI API key (ensure you've set this in your environment)
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ app.logger.setLevel(logging.INFO)
+ werkzeug_log = logging.getLogger('werkzeug')
+ werkzeug_log.setLevel(logging.ERROR)

+ # Load expected API key from environment (fallback for testing)
  EXPECTED_API_KEY = "TEST_API_KEY"

+ # Authentication decorator
  def require_api_key(f):
      @wraps(f)
      def decorated_function(*args, **kwargs):
-         api_key = request.headers.get("Authorization").split("Bearer ")[-1]
-         if not api_key or api_key != EXPECTED_API_KEY:
+         auth_header = request.headers.get("Authorization", "")
+         if not auth_header.startswith("Bearer "):
+             return jsonify({"error": "Unauthorized"}), 401
+         api_key = auth_header.split(" ", 1)[1]
+         if api_key != EXPECTED_API_KEY:
              return jsonify({"error": "Unauthorized"}), 401
          return f(*args, **kwargs)
      return decorated_function

  @app.before_request
  def log_request_info():
-     app.logger.info("Request received: %s %s", request.method, request.path)
+     app.logger.info("%s %s", request.method, request.path)

  @app.route("/v1/chat/completions", methods=["POST"])
  @require_api_key
  def chat_completions():
-     app.logger.info("Received request on /v1/chat/completions")
-     data = request.get_json()
-     if not data:
+     """
+     Proxy endpoint for OpenAI Chat Completions.
+     Supports both streaming and non-streaming modes.
+     """
+     try:
+         data = request.get_json(force=True)
+     except Exception:
          return jsonify({"error": "Invalid JSON payload"}), 400

      client = OpenAI()
+     is_stream = data.get('stream', False)
+
+     if is_stream:
+         # Stream each chunk to the client as Server-Sent Events
+         def generate():
+             try:
+                 for chunk in client.chat.completions.create(**data):
+                     # Convert chunk to dict and then JSON
+                     event = chunk.model_dump()
+                     yield f"data: {json.dumps(event)}\n\n"
+             except Exception as e:
+                 # On error, send an SSE event with error info
+                 error_msg = json.dumps({"error": str(e)})
+                 yield f"data: {error_msg}\n\n"
+
+         headers = {
+             'Content-Type': 'text/event-stream',
+             'Cache-Control': 'no-cache',
+             'Connection': 'keep-alive'
+         }
+         return Response(generate(), headers=headers)
+
+     # Non-streaming path
      try:
          completion = client.chat.completions.create(**data)
-         return jsonify(completion.model_dump()), 200
+         result = completion.model_dump()
+         app.logger.info(f"Non-stream response: {result}")
+         return jsonify(result), 200
      except Exception as e:
-         return jsonify({"error": str(e)}), 400
+         app.logger.error(f"Error during completion: {e}")
+         return jsonify({"error": str(e)}), 500


  if __name__ == "__main__":
-     # Run on port 5000 by default; adjust as needed.
-     app.run(debug=True, port=5000, use_reloader=False)
+     # Bind to all interfaces on port 5000
+     app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
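Since the proxy mirrors the OpenAI Chat Completions route, it can be exercised with the standard OpenAI SDK pointed at the local server. A minimal sketch, assuming the proxy is running on localhost:5000 and has a real OPENAI_API_KEY in its own environment (the model name here is illustrative):

from openai import OpenAI

# The Bearer token must match the proxy's EXPECTED_API_KEY ("TEST_API_KEY")
client = OpenAI(base_url="http://localhost:5000/v1", api_key="TEST_API_KEY")

# Non-streaming request; pass stream=True to exercise the SSE path instead
completion = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Say hello"}],
)
print(completion.choices[0].message.content)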
tests/run_tests.py ADDED
@@ -0,0 +1,112 @@
+ #!/usr/bin/env python3
+ """
+ Custom test runner that suppresses Pydantic deprecation warnings.
+ Usage: python run_tests.py [test_pattern]
+ """
+
+ import sys
+ import warnings
+ import unittest
+ import argparse
+ import asyncio
+ import gc
+
+
+ def suppress_pydantic_warnings():
+     """Comprehensive warning suppression before unittest starts."""
+     # Multiple layers of suppression
+     warnings.resetwarnings()
+     warnings.simplefilter("ignore", DeprecationWarning)
+
+     # Specific Pydantic patterns
+     pydantic_patterns = [
+         ".*PydanticDeprecatedSince.*",
+         ".*__fields__.*deprecated.*",
+         ".*__fields_set__.*deprecated.*",
+         ".*model_fields.*deprecated.*",
+         ".*model_computed_fields.*deprecated.*",
+         ".*use.*model_fields.*instead.*",
+         ".*use.*model_fields_set.*instead.*",
+         ".*Accessing.*model_.*attribute.*deprecated.*",
+     ]
+
+     # Resource warning patterns (reduce noise, not critical)
+     resource_patterns = [
+         ".*unclosed transport.*",
+         ".*unclosed <socket\\.socket.*",
+         ".*unclosed event loop.*",
+     ]
+
+     for pattern in pydantic_patterns:
+         warnings.filterwarnings("ignore", category=DeprecationWarning, message=pattern)
+
+     for pattern in resource_patterns:
+         warnings.filterwarnings("ignore", category=ResourceWarning, message=pattern)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run tests with warning suppression")
+     parser.add_argument(
+         "pattern",
+         nargs="?",
+         default="test_*.py",
+         help="Test file pattern (default: test_*.py)",
+     )
+     parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+
+     args = parser.parse_args()
+
+     # Apply comprehensive warning suppression BEFORE unittest starts
+     suppress_pydantic_warnings()
+
+     print(f"🧪 Running tests with pattern: {args.pattern}")
+     print("🔇 Pydantic deprecation warnings suppressed")
+
+     # Add tests directory to Python path for relative imports
+     import os
+     sys.path.insert(0, os.path.abspath("tests"))
+
+     # Discover and run tests
+     loader = unittest.TestLoader()
+     start_dir = "tests"
+     suite = loader.discover(start_dir, pattern=args.pattern)
+
+     # Run tests
+     verbosity = 2 if args.verbose else 1
+     runner = unittest.TextTestRunner(verbosity=verbosity)
+     result = runner.run(suite)
+
+     # Cleanup to reduce resource warnings
+     try:
+         # Close any remaining event loops
+         loop = None
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             pass
+
+         if loop and not loop.is_closed():
+             # Cancel all pending tasks
+             pending = asyncio.all_tasks(loop)
+             for task in pending:
+                 task.cancel()
+
+             # Give tasks a chance to complete cancellation
+             if pending:
+                 loop.run_until_complete(
+                     asyncio.gather(*pending, return_exceptions=True)
+                 )
+
+         # Force garbage collection
+         gc.collect()
+
+     except Exception:
+         # Don't let cleanup errors affect test results
+         pass
+
+     # Exit with proper code
+     sys.exit(0 if result.wasSuccessful() else 1)
+
+
+ if __name__ == "__main__":
+     main()
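One detail worth noting about the runner: the message argument to warnings.filterwarnings is a regular expression matched against the start of the warning text, which is why the patterns above are wrapped in .* on both sides. A self-contained demonstration of that behavior:

import warnings

# "message" is a regex anchored at the start of the warning text,
# so a leading .* is needed to match anywhere in the message
warnings.filterwarnings(
    "ignore", category=DeprecationWarning, message=".*__fields__.*deprecated.*"
)

warnings.warn("The __fields__ attribute is deprecated", DeprecationWarning)  # silenced
warnings.warn("some unrelated deprecation", DeprecationWarning)  # still emitted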
tests/test_agent.py CHANGED
@@ -1,18 +1,19 @@
+ # Suppress external dependency warnings before any other imports
+ import warnings
+ warnings.simplefilter("ignore", DeprecationWarning)
+
  import unittest
  import threading
  from datetime import date

- from vectara_agentic.agent import _get_prompt, Agent, AgentType
+ from vectara_agentic.agent import Agent, AgentType
+ from vectara_agentic.agent_core.factory import format_prompt
  from vectara_agentic.agent_config import AgentConfig
  from vectara_agentic.types import ModelProvider, ObserverType
  from vectara_agentic.tools import ToolsFactory

- from vectara_agentic._prompts import GENERAL_INSTRUCTIONS
-
-
- def mult(x: float, y: float) -> float:
-     "Multiply two numbers"
-     return x * y
+ from vectara_agentic.agent_core.prompts import GENERAL_INSTRUCTIONS
+ from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS


  ARIZE_LOCK = threading.Lock()
@@ -28,19 +29,17 @@ class TestAgentPackage(unittest.TestCase):
              + " with Always do as your mother tells you!"
          )
          self.assertEqual(
-             _get_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
+             format_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
          )

      def test_agent_init(self):
          tools = [ToolsFactory().create_tool(mult)]
-         topic = "AI"
-         custom_instructions = "Always do as your mother tells you!"
-         agent = Agent(tools, topic, custom_instructions)
-         self.assertEqual(agent.agent_type, AgentType.OPENAI)
-         self.assertEqual(agent._topic, topic)
-         self.assertEqual(agent._custom_instructions, custom_instructions)
+         agent = Agent(tools, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS)
+         self.assertEqual(agent.agent_type, AgentType.FUNCTION_CALLING)
+         self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
+         self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)

-         # To run this test, you must have OPENAI_API_KEY in your environment
+         # To run this test, you must have appropriate API key in your environment
          self.assertEqual(
              agent.chat(
                  "What is 5 times 10. Only give the answer, nothing else"
@@ -51,8 +50,6 @@ class TestAgentPackage(unittest.TestCase):
      def test_agent_config(self):
          with ARIZE_LOCK:
              tools = [ToolsFactory().create_tool(mult)]
-             topic = "AI topic"
-             instructions = "Always do as your father tells you, if your mother agrees!"
              config = AgentConfig(
                  agent_type=AgentType.REACT,
                  main_llm_provider=ModelProvider.ANTHROPIC,
@@ -64,12 +61,12 @@ class TestAgentPackage(unittest.TestCase):

              agent = Agent(
                  tools=tools,
-                 topic=topic,
-                 custom_instructions=instructions,
+                 topic=STANDARD_TEST_TOPIC,
+                 custom_instructions=STANDARD_TEST_INSTRUCTIONS,
                  agent_config=config
              )
-             self.assertEqual(agent._topic, topic)
-             self.assertEqual(agent._custom_instructions, instructions)
+             self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
+             self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)
              self.assertEqual(agent.agent_type, AgentType.REACT)
              self.assertEqual(agent.agent_config.observer, ObserverType.ARIZE_PHOENIX)
              self.assertEqual(agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC)
@@ -84,19 +81,20 @@ class TestAgentPackage(unittest.TestCase):
          )

      def test_multiturn(self):
-         tools = [ToolsFactory().create_tool(mult)]
-         topic = "AI topic"
-         instructions = "Always do as your father tells you, if your mother agrees!"
-         agent = Agent(
-             tools=tools,
-             topic=topic,
-             custom_instructions=instructions,
-         )
+         with ARIZE_LOCK:
+             tools = [ToolsFactory().create_tool(mult)]
+             topic = "AI topic"
+             instructions = "Always do as your father tells you, if your mother agrees!"
+             agent = Agent(
+                 tools=tools,
+                 topic=topic,
+                 custom_instructions=instructions,
+             )

-         agent.chat("What is 5 times 10. Only give the answer, nothing else")
-         agent.chat("what is 3 times 7. Only give the answer, nothing else")
-         res = agent.chat("multiply the results of the last two questions. Output only the answer.")
-         self.assertEqual(res.response, "1050")
+             agent.chat("What is 5 times 10. Only give the answer, nothing else")
+             agent.chat("what is 3 times 7. Only give the answer, nothing else")
+             res = agent.chat("multiply the results of the last two questions. Output only the answer.")
+             self.assertEqual(res.response, "1050")

      def test_from_corpus(self):
          agent = Agent.from_corpus(
@@ -121,6 +119,10 @@ class TestAgentPackage(unittest.TestCase):
              chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")]
          )

+         data = agent.dumps()
+         clone = Agent.loads(data)
+         assert clone.memory.get() == agent.memory.get()
+

          res = agent.chat("multiply the results of the last two questions. Output only the answer.")
          self.assertEqual(res.response, "1050")
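The new assertions above round-trip an agent through Agent.dumps() and Agent.loads() and require the conversation memory to survive serialization. A minimal standalone sketch of that check, using the shared mult tool and constants from conftest:

from vectara_agentic.agent import Agent
from vectara_agentic.tools import ToolsFactory
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS

agent = Agent(
    tools=[ToolsFactory().create_tool(mult)],
    topic=STANDARD_TEST_TOPIC,
    custom_instructions=STANDARD_TEST_INSTRUCTIONS,
    chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")],
)

# Serialize to a string, restore a clone, and compare memory contents
data = agent.dumps()
clone = Agent.loads(data)
assert clone.memory.get() == agent.memory.get()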