vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic.
- tests/__init__.py +7 -0
- tests/conftest.py +316 -0
- tests/endpoint.py +54 -17
- tests/run_tests.py +112 -0
- tests/test_agent.py +35 -33
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +86 -143
- tests/test_api_endpoint.py +4 -0
- tests/test_bedrock.py +50 -31
- tests/test_fallback.py +4 -0
- tests/test_gemini.py +27 -59
- tests/test_groq.py +50 -31
- tests/test_private_llm.py +11 -2
- tests/test_return_direct.py +6 -2
- tests/test_serialization.py +7 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +109 -0
- tests/test_together.py +62 -0
- tests/test_tools.py +10 -82
- tests/test_vectara_llms.py +4 -0
- tests/test_vhc.py +67 -0
- tests/test_workflow.py +13 -28
- vectara_agentic/__init__.py +27 -4
- vectara_agentic/_callback.py +65 -67
- vectara_agentic/_observability.py +30 -30
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +565 -859
- vectara_agentic/agent_config.py +15 -14
- vectara_agentic/agent_core/__init__.py +22 -0
- vectara_agentic/agent_core/factory.py +383 -0
- vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
- vectara_agentic/agent_core/serialization.py +348 -0
- vectara_agentic/agent_core/streaming.py +483 -0
- vectara_agentic/agent_core/utils/__init__.py +29 -0
- vectara_agentic/agent_core/utils/hallucination.py +157 -0
- vectara_agentic/agent_core/utils/logging.py +52 -0
- vectara_agentic/agent_core/utils/schemas.py +87 -0
- vectara_agentic/agent_core/utils/tools.py +125 -0
- vectara_agentic/agent_endpoint.py +4 -6
- vectara_agentic/db_tools.py +37 -12
- vectara_agentic/llm_utils.py +42 -43
- vectara_agentic/sub_query_workflow.py +9 -14
- vectara_agentic/tool_utils.py +138 -83
- vectara_agentic/tools.py +36 -21
- vectara_agentic/tools_catalog.py +16 -16
- vectara_agentic/types.py +106 -8
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
- vectara_agentic-0.4.1.dist-info/RECORD +53 -0
- tests/test_agent_planning.py +0 -64
- tests/test_hhem.py +0 -100
- vectara_agentic/hhem.py +0 -82
- vectara_agentic-0.3.3.dist-info/RECORD +0 -39
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
tests/__init__.py
CHANGED
tests/conftest.py
ADDED
@@ -0,0 +1,316 @@

```python
# Suppress external dependency warnings before any other imports
import warnings

warnings.simplefilter("ignore", DeprecationWarning)

"""
Common test utilities, configurations, and fixtures for the vectara-agentic test suite.
"""

import unittest
from contextlib import contextmanager
from typing import Any

from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.types import AgentType, ModelProvider


# ========================================
# Common Test Functions
# ========================================


def mult(x: float, y: float) -> float:
    """Multiply two numbers - common test function used across multiple test files."""
    return x * y


def add(x: float, y: float) -> float:
    """Add two numbers - common test function used in workflow tests."""
    return x + y


# ========================================
# Common Test Data
# ========================================

# Standard test topic used across most tests
STANDARD_TEST_TOPIC = "AI topic"

# Standard test instructions used across most tests
STANDARD_TEST_INSTRUCTIONS = (
    "Always do as your father tells you, if your mother agrees!"
)

# Alternative instructions for specific tests
WORKFLOW_TEST_INSTRUCTIONS = "You are a helpful AI assistant."
MATH_AGENT_INSTRUCTIONS = "you are an agent specializing in math, assisting a user."


# ========================================
# Agent Configuration Objects
# ========================================

# Default configurations
default_config = AgentConfig()

# Function Calling configurations for all providers
fc_config_anthropic = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.ANTHROPIC,
    tool_llm_provider=ModelProvider.ANTHROPIC,
)

fc_config_gemini = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.GEMINI,
    tool_llm_provider=ModelProvider.GEMINI,
)

fc_config_together = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.TOGETHER,
    tool_llm_provider=ModelProvider.TOGETHER,
)

fc_config_openai = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.OPENAI,
    tool_llm_provider=ModelProvider.OPENAI,
)

fc_config_groq = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.GROQ,
    tool_llm_provider=ModelProvider.GROQ,
)

fc_config_bedrock = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.BEDROCK,
    tool_llm_provider=ModelProvider.BEDROCK,
)

# ReAct configurations for all providers
react_config_anthropic = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.ANTHROPIC,
    tool_llm_provider=ModelProvider.ANTHROPIC,
)

react_config_gemini = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.GEMINI,
    main_llm_model_name="models/gemini-2.5-flash",
    tool_llm_provider=ModelProvider.GEMINI,
    tool_llm_model_name="models/gemini-2.5-flash",
)

react_config_together = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.TOGETHER,
    tool_llm_provider=ModelProvider.TOGETHER,
)

react_config_groq = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.GROQ,
    tool_llm_provider=ModelProvider.GROQ,
)

# Private LLM configurations
private_llm_react_config = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.PRIVATE,
    main_llm_model_name="gpt-4o",
    private_llm_api_base="http://localhost:8000/v1",
    tool_llm_provider=ModelProvider.PRIVATE,
    tool_llm_model_name="gpt-4o",
)

private_llm_fc_config = AgentConfig(
    agent_type=AgentType.FUNCTION_CALLING,
    main_llm_provider=ModelProvider.PRIVATE,
    main_llm_model_name="gpt-4.1",
    private_llm_api_base="http://localhost:8000/v1",
    tool_llm_provider=ModelProvider.PRIVATE,
    tool_llm_model_name="gpt-4.1",
)


# ========================================
# Error Detection and Testing Utilities
# ========================================


def is_rate_limited(response_text: str) -> bool:
    """
    Check if a response indicates a rate limit error from any LLM provider.

    Args:
        response_text: The response text from the agent

    Returns:
        bool: True if the response indicates rate limiting
    """
    rate_limit_indicators = [
        # Generic indicators
        "Error code: 429",
        "rate_limit_exceeded",
        "Rate limit reached",
        "rate limit",
        "quota exceeded",
        "usage limit",
        # GROQ-specific
        "tokens per day",
        "TPD",
        "service tier",
        "on_demand",
        "deepseek-r1-distill-llama-70b",
        "Upgrade to Dev Tier",
        "console.groq.com/settings/billing",
        # OpenAI-specific
        "requests per minute",
        "RPM",
        "tokens per minute",
        "TPM",
        # Anthropic-specific
        "overloaded_error",
        "Overloaded",
        "APIStatusError",
        "anthropic.APIStatusError",
        "usage_limit_exceeded",
        # General API limit indicators
        "try again in",
        "Please wait",
        "Too many requests",
        "throttled",
        # Additional rate limit patterns
        "Limit.*Used.*Requested",
        "Need more tokens",
    ]

    response_lower = response_text.lower()
    return any(
        indicator.lower() in response_lower for indicator in rate_limit_indicators
    )


def is_api_key_error(response_text: str) -> bool:
    """
    Check if a response indicates an API key authentication error.

    Args:
        response_text: The response text from the agent

    Returns:
        bool: True if the response indicates API key issues
    """
    api_key_indicators = [
        "Error code: 401",
        "Invalid API Key",
        "authentication",
        "unauthorized",
        "invalid_api_key",
        "missing api key",
        "api key not found",
    ]

    response_lower = response_text.lower()
    return any(indicator.lower() in response_lower for indicator in api_key_indicators)


def skip_if_rate_limited(
    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
) -> None:
    """
    Skip a test if the response indicates rate limiting.

    Args:
        test_instance: The test case instance
        response_text: The response text to check
        provider: The name of the provider (for clearer skip messages)
    """
    if is_rate_limited(response_text):
        test_instance.skipTest(f"{provider} rate limit reached - skipping test")


def skip_if_api_key_error(
    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
) -> None:
    """
    Skip a test if the response indicates API key issues.

    Args:
        test_instance: The test case instance
        response_text: The response text to check
        provider: The name of the provider (for clearer skip messages)
    """
    if is_api_key_error(response_text):
        test_instance.skipTest(f"{provider} API key invalid/missing - skipping test")


def skip_if_provider_error(
    test_instance: unittest.TestCase, response_text: str, provider: str = "LLM"
) -> None:
    """
    Skip a test if the response indicates common provider errors (rate limiting or API key issues).

    Args:
        test_instance: The test case instance
        response_text: The response text to check
        provider: The name of the provider (for clearer skip messages)
    """
    skip_if_rate_limited(test_instance, response_text, provider)
    skip_if_api_key_error(test_instance, response_text, provider)


class AgentTestMixin:
    """
    Mixin class providing utility methods for agent testing.
    """

    @contextmanager
    def with_provider_fallback(self, provider: str = "LLM"):
        """
        Context manager that catches and handles provider errors from any agent method.

        Args:
            provider: Provider name for error messages

        Usage:
            with self.with_provider_fallback("GROQ"):
                response = agent.chat("test")

            with self.with_provider_fallback("GROQ"):
                async for chunk in agent.astream_chat("test"):
                    pass

        Raises:
            unittest.SkipTest: If rate limiting or API key errors occur
        """
        try:
            yield
        except Exception as e:
            error_text = str(e)
            if is_rate_limited(error_text) or is_api_key_error(error_text):
                self.skipTest(f"{provider} error: {error_text}")
            raise

    def check_response_and_skip(self, response: Any, provider: str = "LLM") -> Any:
        """
        Check response content and skip test if provider errors are detected.

        Args:
            response: The response object from agent method
            provider: Provider name for error messages

        Returns:
            The response object if no errors detected

        Raises:
            unittest.SkipTest: If rate limiting or API key errors detected in response
        """
        response_text = getattr(response, "response", str(response))
        skip_if_provider_error(self, response_text, provider)
        return response
```
tests/endpoint.py
CHANGED
```diff
@@ -1,47 +1,84 @@
-
-
+#!/usr/bin/env python3
+import json
 import logging
 from functools import wraps
+from flask import Flask, request, Response, jsonify
+from openai import OpenAI
 
 app = Flask(__name__)
-app.config['TESTING'] = True
 
-
-
-
-
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+app.logger.setLevel(logging.INFO)
+werkzeug_log = logging.getLogger('werkzeug')
+werkzeug_log.setLevel(logging.ERROR)
 
+# Load expected API key from environment (fallback for testing)
 EXPECTED_API_KEY = "TEST_API_KEY"
 
+# Authentication decorator
 def require_api_key(f):
     @wraps(f)
     def decorated_function(*args, **kwargs):
-
-        if not
+        auth_header = request.headers.get("Authorization", "")
+        if not auth_header.startswith("Bearer "):
+            return jsonify({"error": "Unauthorized"}), 401
+        api_key = auth_header.split(" ", 1)[1]
+        if api_key != EXPECTED_API_KEY:
             return jsonify({"error": "Unauthorized"}), 401
         return f(*args, **kwargs)
     return decorated_function
 
 @app.before_request
 def log_request_info():
-    app.logger.info("
+    app.logger.info("%s %s", request.method, request.path)
 
 @app.route("/v1/chat/completions", methods=["POST"])
 @require_api_key
 def chat_completions():
-
-
-
+    """
+    Proxy endpoint for OpenAI Chat Completions.
+    Supports both streaming and non-streaming modes.
+    """
+    try:
+        data = request.get_json(force=True)
+    except Exception:
         return jsonify({"error": "Invalid JSON payload"}), 400
 
     client = OpenAI()
+    is_stream = data.get('stream', False)
+
+    if is_stream:
+        # Stream each chunk to the client as Server-Sent Events
+        def generate():
+            try:
+                for chunk in client.chat.completions.create(**data):
+                    # Convert chunk to dict and then JSON
+                    event = chunk.model_dump()
+                    yield f"data: {json.dumps(event)}\n\n"
+            except Exception as e:
+                # On error, send an SSE event with error info
+                error_msg = json.dumps({"error": str(e)})
+                yield f"data: {error_msg}\n\n"
+
+        headers = {
+            'Content-Type': 'text/event-stream',
+            'Cache-Control': 'no-cache',
+            'Connection': 'keep-alive'
+        }
+        return Response(generate(), headers=headers)
+
+    # Non-streaming path
     try:
         completion = client.chat.completions.create(**data)
-
+        result = completion.model_dump()
+        app.logger.info(f"Non-stream response: {result}")
+        return jsonify(result), 200
     except Exception as e:
-
+        app.logger.error(f"Error during completion: {e}")
+        return jsonify({"error": str(e)}), 500
 
 
 if __name__ == "__main__":
-    #
-    app.run(
+    # Bind to all interfaces on port 5000
+    app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
```
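
Since the endpoint now speaks the OpenAI wire format in both modes, the standard OpenAI client can target it directly. A minimal sketch, assuming the server is running locally (the model name is illustrative; the base URL and key match the `app.run` binding and `EXPECTED_API_KEY` above):

```python
from openai import OpenAI

# Point the client at the local test proxy instead of api.openai.com
client = OpenAI(base_url="http://localhost:5000/v1", api_key="TEST_API_KEY")

# Non-streaming request; pass stream=True to exercise the SSE branch instead
completion = client.chat.completions.create(
    model="gpt-4o",  # illustrative; the proxy forwards whatever payload it receives
    messages=[{"role": "user", "content": "Say hello"}],
)
print(completion.choices[0].message.content)
```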
tests/run_tests.py
ADDED
@@ -0,0 +1,112 @@

```python
#!/usr/bin/env python3
"""
Custom test runner that suppresses Pydantic deprecation warnings.
Usage: python run_tests.py [test_pattern]
"""

import sys
import warnings
import unittest
import argparse
import asyncio
import gc


def suppress_pydantic_warnings():
    """Comprehensive warning suppression before unittest starts."""
    # Multiple layers of suppression
    warnings.resetwarnings()
    warnings.simplefilter("ignore", DeprecationWarning)

    # Specific Pydantic patterns
    pydantic_patterns = [
        ".*PydanticDeprecatedSince.*",
        ".*__fields__.*deprecated.*",
        ".*__fields_set__.*deprecated.*",
        ".*model_fields.*deprecated.*",
        ".*model_computed_fields.*deprecated.*",
        ".*use.*model_fields.*instead.*",
        ".*use.*model_fields_set.*instead.*",
        ".*Accessing.*model_.*attribute.*deprecated.*",
    ]

    # Resource warning patterns (reduce noise, not critical)
    resource_patterns = [
        ".*unclosed transport.*",
        ".*unclosed <socket\\.socket.*",
        ".*unclosed event loop.*",
    ]

    for pattern in pydantic_patterns:
        warnings.filterwarnings("ignore", category=DeprecationWarning, message=pattern)

    for pattern in resource_patterns:
        warnings.filterwarnings("ignore", category=ResourceWarning, message=pattern)


def main():
    parser = argparse.ArgumentParser(description="Run tests with warning suppression")
    parser.add_argument(
        "pattern",
        nargs="?",
        default="test_*.py",
        help="Test file pattern (default: test_*.py)",
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = parser.parse_args()

    # Apply comprehensive warning suppression BEFORE unittest starts
    suppress_pydantic_warnings()

    print(f"🧪 Running tests with pattern: {args.pattern}")
    print("🔇 Pydantic deprecation warnings suppressed")

    # Add tests directory to Python path for relative imports
    import os
    sys.path.insert(0, os.path.abspath("tests"))

    # Discover and run tests
    loader = unittest.TestLoader()
    start_dir = "tests"
    suite = loader.discover(start_dir, pattern=args.pattern)

    # Run tests
    verbosity = 2 if args.verbose else 1
    runner = unittest.TextTestRunner(verbosity=verbosity)
    result = runner.run(suite)

    # Cleanup to reduce resource warnings
    try:
        # Close any remaining event loops
        loop = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            pass

        if loop and not loop.is_closed():
            # Cancel all pending tasks
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()

            # Give tasks a chance to complete cancellation
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )

        # Force garbage collection
        gc.collect()

    except Exception:
        # Don't let cleanup errors affect test results
        pass

    # Exit with proper code
    sys.exit(0 if result.wasSuccessful() else 1)


if __name__ == "__main__":
    main()
```
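
The important detail in this runner is that the filters are installed before unittest discovery imports any test module. A sketch of the same effect when running a single module outside the runner:

```python
import warnings

# Install filters before unittest (and the test modules it imports) emit warnings
warnings.simplefilter("ignore", DeprecationWarning)

import unittest

suite = unittest.TestLoader().discover("tests", pattern="test_agent.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```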
tests/test_agent.py
CHANGED
```diff
@@ -1,18 +1,19 @@
+# Suppress external dependency warnings before any other imports
+import warnings
+warnings.simplefilter("ignore", DeprecationWarning)
+
 import unittest
 import threading
 from datetime import date
 
-from vectara_agentic.agent import
+from vectara_agentic.agent import Agent, AgentType
+from vectara_agentic.agent_core.factory import format_prompt
 from vectara_agentic.agent_config import AgentConfig
 from vectara_agentic.types import ModelProvider, ObserverType
 from vectara_agentic.tools import ToolsFactory
 
-from vectara_agentic.
-
-
-def mult(x: float, y: float) -> float:
-    "Multiply two numbers"
-    return x * y
+from vectara_agentic.agent_core.prompts import GENERAL_INSTRUCTIONS
+from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS
 
 
 ARIZE_LOCK = threading.Lock()
@@ -28,19 +29,17 @@ class TestAgentPackage(unittest.TestCase):
             + " with Always do as your mother tells you!"
         )
         self.assertEqual(
-
+            format_prompt(prompt_template, GENERAL_INSTRUCTIONS, topic, custom_instructions), expected_output
         )
 
     def test_agent_init(self):
         tools = [ToolsFactory().create_tool(mult)]
-
-
-        agent
-        self.assertEqual(agent.
-        self.assertEqual(agent._topic, topic)
-        self.assertEqual(agent._custom_instructions, custom_instructions)
+        agent = Agent(tools, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS)
+        self.assertEqual(agent.agent_type, AgentType.FUNCTION_CALLING)
+        self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
+        self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)
 
-        # To run this test, you must have
+        # To run this test, you must have appropriate API key in your environment
         self.assertEqual(
             agent.chat(
                 "What is 5 times 10. Only give the answer, nothing else"
@@ -51,8 +50,6 @@ class TestAgentPackage(unittest.TestCase):
     def test_agent_config(self):
         with ARIZE_LOCK:
             tools = [ToolsFactory().create_tool(mult)]
-            topic = "AI topic"
-            instructions = "Always do as your father tells you, if your mother agrees!"
             config = AgentConfig(
                 agent_type=AgentType.REACT,
                 main_llm_provider=ModelProvider.ANTHROPIC,
@@ -64,12 +61,12 @@ class TestAgentPackage(unittest.TestCase):
 
             agent = Agent(
                 tools=tools,
-                topic=
-                custom_instructions=
+                topic=STANDARD_TEST_TOPIC,
+                custom_instructions=STANDARD_TEST_INSTRUCTIONS,
                 agent_config=config
             )
-            self.assertEqual(agent._topic,
-            self.assertEqual(agent._custom_instructions,
+            self.assertEqual(agent._topic, STANDARD_TEST_TOPIC)
+            self.assertEqual(agent._custom_instructions, STANDARD_TEST_INSTRUCTIONS)
             self.assertEqual(agent.agent_type, AgentType.REACT)
             self.assertEqual(agent.agent_config.observer, ObserverType.ARIZE_PHOENIX)
             self.assertEqual(agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC)
@@ -84,19 +81,20 @@ class TestAgentPackage(unittest.TestCase):
         )
 
     def test_multiturn(self):
-
-
-
-
-
-
-
-
+        with ARIZE_LOCK:
+            tools = [ToolsFactory().create_tool(mult)]
+            topic = "AI topic"
+            instructions = "Always do as your father tells you, if your mother agrees!"
+            agent = Agent(
+                tools=tools,
+                topic=topic,
+                custom_instructions=instructions,
+            )
 
-
-
-
-
+            agent.chat("What is 5 times 10. Only give the answer, nothing else")
+            agent.chat("what is 3 times 7. Only give the answer, nothing else")
+            res = agent.chat("multiply the results of the last two questions. Output only the answer.")
+            self.assertEqual(res.response, "1050")
 
     def test_from_corpus(self):
         agent = Agent.from_corpus(
@@ -121,6 +119,10 @@ class TestAgentPackage(unittest.TestCase):
             chat_history=[("What is 5 times 10", "50"), ("What is 3 times 7", "21")]
         )
 
+        data = agent.dumps()
+        clone = Agent.loads(data)
+        assert clone.memory.get() == agent.memory.get()
+
         res = agent.chat("multiply the results of the last two questions. Output only the answer.")
         self.assertEqual(res.response, "1050")
 
```
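
The new `dumps()`/`loads()` assertion exercises the serialization module added in this release (`vectara_agentic/agent_core/serialization.py`). A sketch of the same round-trip outside the test class, with illustrative tool and prompt values taken from the conftest above:

```python
from vectara_agentic.agent import Agent
from vectara_agentic.tools import ToolsFactory
from conftest import mult, STANDARD_TEST_TOPIC, STANDARD_TEST_INSTRUCTIONS

agent = Agent(
    tools=[ToolsFactory().create_tool(mult)],
    topic=STANDARD_TEST_TOPIC,
    custom_instructions=STANDARD_TEST_INSTRUCTIONS,
)
agent.chat("What is 5 times 10. Only give the answer, nothing else")

# Serialize to a string and rebuild; chat memory should survive the round-trip
data = agent.dumps()
clone = Agent.loads(data)
assert clone.memory.get() == agent.memory.get()
```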