aixtools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aixtools might be problematic.
- aixtools/__init__.py +5 -0
- aixtools/a2a/__init__.py +5 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +74 -0
- aixtools/app.py +143 -0
- aixtools/context.py +12 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +116 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +320 -0
- aixtools/model_patch/model_patch.py +65 -0
- aixtools/server/__init__.py +23 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +147 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +118 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +160 -0
- aixtools-0.1.0.dist-info/METADATA +355 -0
- aixtools-0.1.0.dist-info/RECORD +58 -0
- aixtools-0.1.0.dist-info/WHEEL +5 -0
- aixtools-0.1.0.dist-info/entry_points.txt +2 -0
- aixtools-0.1.0.dist-info/top_level.txt +1 -0
aixtools/testing/aix_test_model.py
@@ -0,0 +1,147 @@
+"""
+Test model implementation for AI agent testing with predefined responses.
+"""
+
+import asyncio
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from types import AsyncGeneratorType
+
+from pydantic import BaseModel
+from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+from pydantic_ai.models.function import _estimate_usage
+from pydantic_ai.models.test import TestStreamedResponse
+from pydantic_ai.settings import ModelSettings
+
+from ..utils.utils import async_iter
+
+FINAL_RESULT_TOOL_NAME = "final_result"
+
+
+def final_result_tool(result: BaseModel | dict) -> ToolCallPart:
+    """Create a ToolCallPart for the final result."""
+    if isinstance(result, BaseModel):
+        args = result.model_dump()
+    else:
+        args = result
+    return ToolCallPart(tool_name=FINAL_RESULT_TOOL_NAME, args=args)
+
+
+class AixTestModel(Model):
+    """
+    Test model that returns a specified list of answers, including messages or tool calls.
+    This is used for testing the agent and model interaction with the rest of the system.
+
+    responses: A list of strings or ToolCallPart objects that the model will return in order.
+
+    Note: The agent will continue to invoke 'request()' until it returns text (i.e. not a ToolCallPart).
+
+    Example: Unstructured output (text)
+    ```
+    model = AixTestModel(
+        responses=[
+            ToolCallPart(tool_name='send_message_to_user', args={'message': 'First, let me say hi...'}),
+            "Please invoke the agent again to continue the conversation....",
+            # ---------- The first time you invoke the agent, it will stop here ----------
+
+            ToolCallPart(tool_name='send_message_to_user', args={'message': 'Hi there again!'}),
+            "The 10th prime number is 29.",
+            # ---------- The second time you invoke the agent, it will stop here ----------
+
+            # If you invoke the agent again, it will raise an exception
+            # because there are no more responses
+        ]
+    )
+    ```
+
+    Example: Structured output
+    ```
+    # Define a model for the final result
+    class MyResult(BaseModel):
+        text: str
+
+    model = AixTestModel(
+        responses=[
+            ToolCallPart(tool_name='send_message_to_user', args={'message': 'First, let me say hi...'}),
+            ToolCallPart(tool_name='send_message_to_user', args={'message': 'Hi there again!'}),
+            final_result_tool(MyResult(text='The 10th prime is 29')),
+        ]
+    )
+    ```
+    """
+
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        responses: list[str | TextPart | ToolCallPart] | AsyncGeneratorType,
+        sleep_time: float | None = None,
+    ):
+        self.response_iter = responses(self) if callable(responses) else async_iter(responses)
+        self.messages: list[ModelMessage] | None = None
+        self.sleep_time = sleep_time
+        self.last_model_request_parameters: ModelRequestParameters | None = None
+
+    @property
+    def last_message(self) -> ModelResponse | None:
+        """Return the last response."""
+        return self.messages[-1] if self.messages else None
+
+    @property
+    def last_message_part(self) -> TextPart | ToolCallPart | None:
+        """Return the last part of the response."""
+        return self.last_message.parts[-1] if self.last_message and self.last_message.parts else None
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        self.last_model_request_parameters = model_request_parameters
+        model_response = await self._request(messages, model_settings, model_request_parameters)
+        model_response.usage = _estimate_usage([*messages, model_response])
+        return model_response
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model_response = await self._request(messages, model_settings, model_request_parameters)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
+
+    @property
+    def model_name(self) -> str:
+        return self.__class__.__name__.lower()
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return "test_system"
+
+    async def _request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,  # pylint: disable=unused-argument
+        model_request_parameters: ModelRequestParameters,  # pylint: disable=unused-argument
+    ) -> ModelResponse:
+        self.messages = messages
+        res = await anext(self.response_iter, None)
+        assert res is not None, "No more responses available."
+        if callable(res):
+            res = res(self, messages)
+        # Optional delay to simulate model latency (must run before returning,
+        # since every match branch below returns or raises)
+        if self.sleep_time:
+            await asyncio.sleep(self.sleep_time)
+        match res:
+            case str():
+                return ModelResponse(parts=[TextPart(res)], model_name=self.model_name)
+            case ToolCallPart() | TextPart():
+                return ModelResponse(parts=[res], model_name=self.model_name)
+            case Exception():
+                raise res
+            case _:
+                raise ValueError(f"Invalid response type: {type(res)}, response: {res}")
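For context, a minimal usage sketch (not part of the package): assuming a pydantic-ai version in which `Agent` accepts a `Model` instance directly, the test model above can stand in for a real LLM; the prompt and the canned answer here are illustrative.

```
from pydantic_ai import Agent

# Hypothetical test: the model replays canned answers instead of calling an LLM.
model = AixTestModel(responses=["The 10th prime number is 29."])
agent = Agent(model)

result = agent.run_sync("What is the 10th prime number?")
print(result.output)  # "The 10th prime number is 29." (older pydantic-ai versions: result.data)
```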
aixtools/testing/mock_tool.py
@@ -0,0 +1,66 @@
+"""
+Mock tool implementation for testing agent interactions.
+"""
+
+import functools
+import inspect
+from typing import Any, Awaitable, Callable, List, TypeVar, Union
+
+T = TypeVar("T")
+
+
+def mock_tool(func: Callable[..., Union[T, Awaitable[T]]], return_values: List[Any]) -> Callable[..., Any]:
+    """
+    Create a mock version of the provided function that returns values from a predefined list.
+    Supports both synchronous and asynchronous functions.
+
+    Args:
+        func: The original function to mock (can be sync or async)
+        return_values: A list of values to be returned sequentially on each call
+
+    Returns:
+        A new function with the same signature and docstring as the original function
+        that returns values from the provided list sequentially.
+        If the original function is async, the mock will also be async.
+
+    Raises:
+        IndexError: When the mock function is called more times than there are return values
+    """
+    # Make a copy of the return values to avoid modifying the original list
+    values = return_values.copy()
+
+    # Check if the function is asynchronous
+    is_async = inspect.iscoroutinefunction(func)
+
+    if is_async:
+        # Create an async wrapper for async functions
+        @functools.wraps(func)
+        async def async_mock_wrapper(*args, **kwargs):  # pylint: disable=unused-argument
+            # Check if we have any return values left
+            if not values:
+                raise IndexError(
+                    f"No more mock return values available for {func.__name__}. "
+                    "The function has been called more times than the number of provided return values."
+                )
+
+            # Return and remove the first value from our list
+            return values.pop(0)
+
+        # Return the async wrapper function
+        return async_mock_wrapper
+
+    # Create a sync wrapper for sync functions
+    @functools.wraps(func)
+    def sync_mock_wrapper(*args, **kwargs):  # pylint: disable=unused-argument
+        # Check if we have any return values left
+        if not values:
+            raise IndexError(
+                f"No more mock return values available for {func.__name__}. "
+                "The function has been called more times than the number of provided return values."
+            )
+
+        # Return and remove the first value from our list
+        return values.pop(0)
+
+    # Return the sync wrapper function
+    return sync_mock_wrapper
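A short usage sketch (the tool name and values are illustrative, not from the package): `mock_tool` preserves the wrapped function's signature and docstring via `functools.wraps`, and each call pops the next canned value regardless of the arguments passed.

```
# Hypothetical tool to be mocked in a test.
def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    raise RuntimeError("real implementation not available in tests")

mocked = mock_tool(get_weather, ["sunny", "rainy"])
assert mocked("Paris") == "sunny"   # first call pops the first value
assert mocked("Tokyo") == "rainy"   # arguments are ignored by the mock
# A third call raises IndexError: no more return values left.
```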
aixtools/testing/model_patch_cache.py
@@ -0,0 +1,279 @@
+"""
+This module is useful for testing, so that we can run tests without having to call the real function.
+The cached values are stored in `cache_file`.
+
+In `learn=False` mode:
+- The model_request_cache_wrapper first tries to answer from the cached value.
+- If the cached value does not exist, it raises an exception.
+
+In `learn=True` mode:
+- The model_request_cache_wrapper first tries to answer from the cached value.
+- If the cached value does not exist, it invokes the real function
+  and adds the new result to `cache_file` for future use.
+
+Values are saved to pickle files, but since objects could be non-picklable, we use `safe_deepcopy`.
+"""
+
+import datetime
+import functools
+import hashlib
+import json
+import pickle
+import uuid
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import Any, Dict, Mapping, Sequence, Set, Tuple, Type
+
+from aixtools.logging.log_objects import safe_deepcopy
+from aixtools.logging.logging_config import get_logger
+from aixtools.model_patch.model_patch import get_request_fn, get_request_stream_fn, model_patch
+
+logger = get_logger(__name__)
+
+
+class CacheKeyError(Exception):
+    """Exception raised when a key is not found in the cache."""
+
+
+def safe_deepcopy_for_cache_key(obj: Any, normalize_types: Set[Type] | None = None) -> Any:
+    """
+    A modified version of safe_deepcopy that normalizes or skips fields based on their types.
+    This is useful for generating cache keys where some fields (like timestamps and UUIDs)
+    should be normalized to ensure consistent cache hits.
+
+    Args:
+        obj: The object to copy
+        normalize_types: Set of types to normalize (replace with placeholders)
+
+    Returns:
+        A deep copy of the object with normalized values for the specified types
+    """
+    # Default types to normalize if none provided
+    if normalize_types is None:
+        normalize_types = {datetime.datetime, datetime.date, datetime.time, uuid.UUID}
+
+    # Check if the object is of a type that should be normalized
+    if any(isinstance(obj, t) for t in normalize_types):
+        # Replace with a placeholder indicating the type
+        return f"<{type(obj).__name__}>"
+
+    # Check if the object is primitive (int, float, str, bool)
+    if isinstance(obj, (int, float, str, bool)):
+        return obj
+
+    # Handle mappings (dict and other mapping types)
+    if isinstance(obj, Mapping):
+        return {k: safe_deepcopy_for_cache_key(v, normalize_types) for k, v in obj.items()}
+
+    # Handle sequences (list, tuple, and other sequence types) but not strings
+    if isinstance(obj, Sequence) and not isinstance(obj, str):
+        return [safe_deepcopy_for_cache_key(item, normalize_types) for item in obj]
+
+    # Handle objects with a __dict__ attribute (custom classes)
+    if hasattr(obj, "__dict__"):
+        # For objects, we create a dictionary representation
+        result = {}
+        for attr, value in vars(obj).items():
+            result[attr] = safe_deepcopy_for_cache_key(value, normalize_types)
+        return result
+
+    # For other types, return a string representation
+    return f"<{type(obj).__name__}>"
+
+
+def generate_cache_key(method_name: str, args: Tuple, kwargs: Dict) -> str:
+    """
+    Generate a unique cache key based on the method name and its arguments.
+    Uses safe_deepcopy_for_cache_key to normalize values that change frequently.
+
+    Args:
+        method_name: Name of the method being called
+        args: Positional arguments to the method
+        kwargs: Keyword arguments to the method
+
+    Returns:
+        A string hash that uniquely identifies this method call
+    """
+    # Normalize the arguments and kwargs
+    normalized_args = safe_deepcopy_for_cache_key(args)
+    normalized_kwargs = safe_deepcopy_for_cache_key(kwargs)
+
+    # Create a dictionary with the normalized information
+    key_dict = {"method_name": method_name, "args": normalized_args, "kwargs": normalized_kwargs}
+
+    # Convert to a stable string representation and hash it
+    key_str = json.dumps(key_dict, sort_keys=True, default=str)
+    return hashlib.md5(key_str.encode()).hexdigest()
+
+
+def save_to_cache(cache_file: Path, key: str, value: Any) -> None:
+    """
+    Save a value to the cache file using the given key.
+
+    Args:
+        cache_file: Path to the cache file
+        key: The cache key
+        value: The value to cache (it will be deep-copied to ensure it is picklable)
+    """
+    # Create parent directories if they don't exist
+    cache_file.parent.mkdir(parents=True, exist_ok=True)
+
+    # Load the existing cache
+    cache = {}
+    if cache_file.exists():
+        try:
+            with open(cache_file, "rb") as f:
+                cache = pickle.load(f)
+        except (pickle.PickleError, EOFError):
+            # If the file is corrupted, start with an empty cache
+            cache = {}
+
+    # Make a safe copy of the value and add it to the cache
+    safe_value = safe_deepcopy(value)
+    cache[key] = safe_value
+    logger.debug("Cache updated: %s -> %r", key, safe_value)
+
+    # Save the updated cache
+    with open(cache_file, "wb") as f:
+        pickle.dump(cache, f)
+
+
+def get_from_cache(cache_file: Path, key: str) -> Any:
+    """
+    Retrieve a value from the cache file using the given key.
+
+    Args:
+        cache_file: Path to the cache file
+        key: The cache key
+
+    Returns:
+        The cached value
+
+    Raises:
+        CacheKeyError: If the key is not found in the cache
+    """
+    if not cache_file.exists():
+        raise CacheKeyError(f"Cache file {cache_file} does not exist")
+    try:
+        with open(cache_file, "rb") as f:
+            cache = pickle.load(f)
+    except (pickle.PickleError, EOFError) as e:
+        raise CacheKeyError(f"Cache file {cache_file} is corrupted") from e
+    if key not in cache:
+        raise CacheKeyError(f"Key {key} not found in cache")
+    return cache[key]
+
+
+def request_cache(fn, cache_file: Path, learn=False):
+    """
+    Cache wrapper for async method calls.
+
+    In learn=False mode:
+    - Tries to answer from the cached value
+    - If the cached value does not exist, raises an exception
+
+    In learn=True mode:
+    - Tries to answer from the cached value
+    - If the cached value does not exist, invokes the real function and adds the result to the cache
+
+    Args:
+        fn: The async function to wrap
+        cache_file: Path to the cache file
+        learn: Whether to learn new responses (True) or only use cached ones (False)
+    """
+
+    @functools.wraps(fn)
+    async def model_request_cache_wrapper(*args, **kwargs):
+        # Generate a unique cache key for this request
+        cache_key = generate_cache_key(fn.__name__, args, kwargs)
+        try:
+            # Try to get the result from the cache
+            result = get_from_cache(cache_file, cache_key)
+            logger.debug("Cache hit for %s with key %s", fn.__name__, cache_key)
+            return result
+        except CacheKeyError as e:
+            # If not in the cache and learn=False, raise an exception
+            if not learn:
+                raise CacheKeyError(f"No cached response for {fn.__name__} with key {cache_key} and learn=False") from e
+            # If learn=True, invoke the original method
+            result = await fn(*args, **kwargs)
+            # Save the result to the cache
+            save_to_cache(cache_file, cache_key, result)
+            return result
+
+    return model_request_cache_wrapper
+
+
+def request_stream_cache(fn, cache_file: Path, learn=False):
+    """
+    Cache wrapper for async streaming method calls.
+
+    Similar to request_cache, but handles streaming responses by caching all items
+    and then replaying them when retrieved from the cache.
+
+    Args:
+        fn: The async context manager function to wrap
+        cache_file: Path to the cache file
+        learn: Whether to learn new responses (True) or only use cached ones (False)
+    """
+
+    @functools.wraps(fn)
+    @asynccontextmanager
+    async def model_request_stream_cache_wrapper(*args, **kwargs):
+        # Generate a unique cache key for this request
+        cache_key = generate_cache_key(fn.__name__, args, kwargs)
+
+        try:
+            # Try to get the cached items
+            cached_items = get_from_cache(cache_file, cache_key)
+
+            # Define a generator that replays the cached items
+            async def cached_gen():
+                for item in cached_items:
+                    yield item
+
+            yield cached_gen()
+
+        except CacheKeyError as e:
+            # If not in the cache and learn=False, raise an exception
+            if not learn:
+                raise CacheKeyError(f"No cached stream for {fn.__name__} with key {cache_key} and learn=False") from e
+
+            # If learn=True, invoke the original method
+            async with fn(*args, **kwargs) as stream:
+                # Collect all items to save to the cache
+                all_items = []
+
+                async def gen():
+                    async for item in stream:
+                        all_items.append(item)
+                        yield item
+
+                # Yield the generator to the caller
+                gen_instance = gen()
+                yield gen_instance
+
+                # After the caller is done, drain any items it did not consume
+                # so the full stream is captured in the cache
+                async for _ in gen_instance:
+                    pass
+
+                # Save the collected items to the cache
+                save_to_cache(cache_file, cache_key, all_items)
+
+    return model_request_stream_cache_wrapper
+
+
+def model_patch_cache(model, cache_file: Path, learn=False):
+    """Patch the model with methods for caching requests and responses."""
+    logger.debug("Using cache file: %s", cache_file)
+    return model_patch(
+        model,
+        request_method=request_cache(get_request_fn(model), cache_file, learn),
+        request_stream_method=request_stream_cache(get_request_stream_fn(model), cache_file, learn),
+    )
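A hedged sketch of the intended record/replay flow (the model object and file path below are illustrative): run once with `learn=True` against a real model to populate the pickle cache, then run the same test with `learn=False` so it answers purely from the cache and fails loudly on any unrecorded request.

```
from pathlib import Path

cache_file = Path("tests/data/model_cache.pkl")

# real_model: a previously constructed pydantic-ai model instance (assumed).
# Recording run: cache misses fall through to the real model and are saved.
patched = model_patch_cache(real_model, cache_file, learn=True)

# Replay run: answers come only from the cache; an unrecorded request
# raises CacheKeyError instead of silently hitting the network.
patched = model_patch_cache(real_model, cache_file, learn=False)
```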
aixtools/tools/doctor/tool_doctor.py
@@ -0,0 +1,61 @@
+from aixtools.agents import get_agent, run_agent
+from aixtools.tools.doctor.tool_recommendation import ToolRecommendation
+
+TOOL_DOCTOR_PROMPT = """
+## Tool doctor
+You are helping to debug common errors in tools and tool definitions.
+
+Given the tools, for each tool you will give feedback about the tool definition.
+
+1. Name: Check if the tool name is descriptive and follows naming conventions.
+2. Description: Ensure the tool's description is clear and provides enough detail
+   so that users understand its purpose and functionality.
+3. Return type: Check if the return type is specified and matches the tool's functionality.
+4. Arguments: Verify that the tool arguments are well-defined and include types
+   and descriptions. Ensure that argument names are descriptive and follow naming conventions.
+5. Look for any missing or redundant information in the tool definition.
+
+Some rules:
+- Ignore a tool called 'final_result'.
+- Do not suggest changes if the tool is already well-defined.
+- Don't nitpick; focus on significant improvements that can be made to the tool definition.
+- Don't suggest trivial improvements or changes for things that are self-evident or already well-defined.
+"""
+
+
+async def tool_doctor_single(tool) -> ToolRecommendation:
+    """Run the tool doctor agent to analyze a single tool."""
+    agent = get_agent(tools=[tool], output_type=ToolRecommendation)
+    ret = await run_agent(agent, TOOL_DOCTOR_PROMPT, log_model_requests=True, verbose=True, debug=True)
+    return ret[0]  # type: ignore
+
+
+async def tool_doctor_multiple(tools: list) -> list[ToolRecommendation]:
+    """Run the tool doctor agent to analyze the given tools."""
+    agent = get_agent(tools=tools, output_type=list[ToolRecommendation])
+    ret = await run_agent(agent, TOOL_DOCTOR_PROMPT)
+    return ret[0]  # type: ignore
+
+
+async def tool_doctor(tools: list, max_tools_per_batch=5, verbose=True) -> list[ToolRecommendation]:
+    """Run the tool doctor agent to analyze tools and give recommendations."""
+    # Split the tools into batches of at most max_tools_per_batch
+    results = []
+    for i in range(0, len(tools), max_tools_per_batch):
+        batch = tools[i : i + max_tools_per_batch]
+        batch_num = i // max_tools_per_batch + 1
+        tool_names = [t.__name__ for t in batch]
+        if verbose:
+            print(f"Processing batch {batch_num} with {len(batch)} tools: {tool_names}")
+        ret = await tool_doctor_multiple(batch)
+        results.extend(ret)
+    # Print the results if verbose
+    if verbose:
+        for r in results:
+            print(r)
+    return results
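A usage sketch under stated assumptions: `tool_doctor` batches plain callables and reads their `__name__`, so the two toy tools below (invented for the example) are enough to exercise it; actually running it requires whatever model backend `get_agent` is configured with.

```
import asyncio

def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

def mystery(x):  # untyped and undocumented -- should draw a recommendation
    return x * 2

recommendations = asyncio.run(tool_doctor([add, mystery], max_tools_per_batch=5))
```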
aixtools/tools/doctor/tool_recommendation.py
@@ -0,0 +1,44 @@
+from pydantic import BaseModel
+
+
+class ArgumentRecommendation(BaseModel):
+    """A recommendation for an argument"""
+
+    arg_name: str
+    arg_type: str
+    arg_description_improvement: str
+
+    def __str__(self):
+        return f"{self.arg_name} ({self.arg_type}): {self.arg_description_improvement}"
+
+
+class ToolRecommendation(BaseModel):
+    """A recommendation for a tool"""
+
+    name: str
+    needs_improvement: bool
+    description_improvement: str
+    arguments: list[ArgumentRecommendation]
+    return_description_improvement: str
+
+    def signature(self):
+        """Generate the function signature for the tool"""
+        signature = f"def {self.name}("
+        if self.arguments:
+            signature += ", ".join([f"{arg.arg_name}: {arg.arg_type}" for arg in self.arguments])
+        signature += ")"
+        return signature
+
+    def __str__(self):
+        out = f"Tool name: '{self.name}'\n"
+        out += f" - Signature: {self.signature()}\n"
+        out += f" - Needs improvement: {self.needs_improvement}\n"
+        if self.description_improvement:
+            out += f" - Description improvement: {self.description_improvement}\n"
+        if self.return_description_improvement:
+            out += f" - Return description improvement: {self.return_description_improvement}\n"
+        if self.arguments:
+            out += " - Arguments:\n"
+            for arg in self.arguments:
+                out += f"   - {arg}\n"
+        return out
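To illustrate the report format, a recommendation built by hand (all field values invented for the example) renders via `__str__` as an indented block:

```
rec = ToolRecommendation(
    name="get_weather",
    needs_improvement=True,
    description_improvement="State that temperatures are returned in Celsius.",
    arguments=[
        ArgumentRecommendation(
            arg_name="city",
            arg_type="str",
            arg_description_improvement="Clarify whether a country code is required.",
        )
    ],
    return_description_improvement="Describe the structure of the returned forecast.",
)
print(rec)  # "Tool name: 'get_weather'" followed by the indented fields above
```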
aixtools/utils/__init__.py
@@ -0,0 +1,35 @@
+"""
+Utils package initialization.
+"""
+
+from aixtools.logging.logging_config import get_logger  # pylint: disable=import-error
+from aixtools.utils import config
+from aixtools.utils.enum_with_description import EnumWithDescription
+from aixtools.utils.persisted_dict import PersistedDict
+from aixtools.utils.utils import (
+    escape_backticks,
+    escape_newline,
+    find_file,
+    prepend_all_lines,
+    remove_quotes,
+    tabit,
+    to_str,
+    tripple_quote_strip,
+    truncate,
+)
+
+__all__ = [
+    "config",
+    "PersistedDict",
+    "EnumWithDescription",
+    "escape_newline",
+    "escape_backticks",
+    "find_file",
+    "get_logger",
+    "prepend_all_lines",
+    "remove_quotes",
+    "tabit",
+    "to_str",
+    "truncate",
+    "tripple_quote_strip",
+]