aixtools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of aixtools has been flagged as possibly problematic.

Files changed (58)
  1. aixtools/__init__.py +5 -0
  2. aixtools/a2a/__init__.py +5 -0
  3. aixtools/a2a/app.py +126 -0
  4. aixtools/a2a/utils.py +115 -0
  5. aixtools/agents/__init__.py +12 -0
  6. aixtools/agents/agent.py +164 -0
  7. aixtools/agents/agent_batch.py +74 -0
  8. aixtools/app.py +143 -0
  9. aixtools/context.py +12 -0
  10. aixtools/db/__init__.py +17 -0
  11. aixtools/db/database.py +110 -0
  12. aixtools/db/vector_db.py +115 -0
  13. aixtools/log_view/__init__.py +17 -0
  14. aixtools/log_view/app.py +195 -0
  15. aixtools/log_view/display.py +285 -0
  16. aixtools/log_view/export.py +51 -0
  17. aixtools/log_view/filters.py +41 -0
  18. aixtools/log_view/log_utils.py +26 -0
  19. aixtools/log_view/node_summary.py +229 -0
  20. aixtools/logfilters/__init__.py +7 -0
  21. aixtools/logfilters/context_filter.py +67 -0
  22. aixtools/logging/__init__.py +30 -0
  23. aixtools/logging/log_objects.py +227 -0
  24. aixtools/logging/logging_config.py +116 -0
  25. aixtools/logging/mcp_log_models.py +102 -0
  26. aixtools/logging/mcp_logger.py +172 -0
  27. aixtools/logging/model_patch_logging.py +87 -0
  28. aixtools/logging/open_telemetry.py +36 -0
  29. aixtools/mcp/__init__.py +9 -0
  30. aixtools/mcp/example_client.py +30 -0
  31. aixtools/mcp/example_server.py +22 -0
  32. aixtools/mcp/fast_mcp_log.py +31 -0
  33. aixtools/mcp/faulty_mcp.py +320 -0
  34. aixtools/model_patch/model_patch.py +65 -0
  35. aixtools/server/__init__.py +23 -0
  36. aixtools/server/app_mounter.py +90 -0
  37. aixtools/server/path.py +72 -0
  38. aixtools/server/utils.py +70 -0
  39. aixtools/testing/__init__.py +9 -0
  40. aixtools/testing/aix_test_model.py +147 -0
  41. aixtools/testing/mock_tool.py +66 -0
  42. aixtools/testing/model_patch_cache.py +279 -0
  43. aixtools/tools/doctor/__init__.py +3 -0
  44. aixtools/tools/doctor/tool_doctor.py +61 -0
  45. aixtools/tools/doctor/tool_recommendation.py +44 -0
  46. aixtools/utils/__init__.py +35 -0
  47. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  48. aixtools/utils/chainlit/cl_utils.py +168 -0
  49. aixtools/utils/config.py +118 -0
  50. aixtools/utils/config_util.py +69 -0
  51. aixtools/utils/enum_with_description.py +37 -0
  52. aixtools/utils/persisted_dict.py +99 -0
  53. aixtools/utils/utils.py +160 -0
  54. aixtools-0.1.0.dist-info/METADATA +355 -0
  55. aixtools-0.1.0.dist-info/RECORD +58 -0
  56. aixtools-0.1.0.dist-info/WHEEL +5 -0
  57. aixtools-0.1.0.dist-info/entry_points.txt +2 -0
  58. aixtools-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,147 @@
+ """
+ Test model implementation for AI agent testing with predefined responses.
+ """
+
+ import asyncio
+ from collections.abc import AsyncIterator, Callable
+ from contextlib import asynccontextmanager
+ from types import AsyncGeneratorType
+
+ from pydantic import BaseModel
+ from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart
+ from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+ from pydantic_ai.models.function import _estimate_usage
+ from pydantic_ai.models.test import TestStreamedResponse
+ from pydantic_ai.settings import ModelSettings
+
+ from ..utils.utils import async_iter
+
+ FINAL_RESULT_TOOL_NAME = "final_result"
+
+
+ def final_result_tool(result: BaseModel | dict) -> ToolCallPart:
+     """Create a ToolCallPart for the final result."""
+     if isinstance(result, BaseModel):
+         args = result.model_dump()
+     else:
+         args = result
+     return ToolCallPart(tool_name=FINAL_RESULT_TOOL_NAME, args=args)
+
+
+ class AixTestModel(Model):
+     """
+     Test model that returns a predefined list of answers (text messages or tool calls).
+     It is used for testing agent and model interaction with the rest of the system.
+
+     responses: A list of strings or ToolCallPart objects that the model returns in order.
+
+     Note: The agent will keep invoking 'request()' until it returns text (i.e. not a ToolCallPart).
+
+     Example: unstructured output (text)
+     ```
+     model = AixTestModel(
+         responses=[
+             ToolCallPart(tool_name='send_message_to_user', args={'message': 'First, let me say hi...'}),
+             "Please invoke the agent again to continue the conversation....",
+             # ---------- The first time you invoke the agent, it will stop here ----------
+
+             ToolCallPart(tool_name='send_message_to_user', args={'message': 'Hi there again!'}),
+             "The 10th prime number is 29.",
+             # ---------- The second time you invoke the agent, it will stop here ----------
+
+             # If you invoke the agent again, it will raise an exception
+             # because there are no more responses
+         ]
+     )
+     ```
+
+     Example: structured output
+     ```
+     # Define a model for the final result
+     class MyResult(BaseModel):
+         text: str
+
+     model = AixTestModel(
+         responses=[
+             ToolCallPart(tool_name='send_message_to_user', args={'message': 'First, let me say hi...'}),
+             ToolCallPart(tool_name='send_message_to_user', args={'message': 'Hi there again!'}),
+             final_result_tool(MyResult(text='The 10th prime is 29')),
+         ]
+     )
+     ```
+     """
+
+     def __init__(  # pylint: disable=super-init-not-called
+         self,
+         responses: list[str | TextPart | ToolCallPart] | AsyncGeneratorType | Callable,
+         sleep_time: float | None = None,
+     ):
+         self.response_iter = responses(self) if callable(responses) else async_iter(responses)
+         self.messages: list[ModelMessage] | None = None
+         self.sleep_time = sleep_time
+         self.last_model_request_parameters: ModelRequestParameters | None = None
+
+     @property
+     def last_message(self) -> ModelResponse | None:
+         """Return the last response."""
+         return self.messages[-1] if self.messages else None
+
+     @property
+     def last_message_part(self) -> TextPart | ToolCallPart | None:
+         """Return the last part of the response."""
+         return self.last_message.parts[-1] if self.last_message and self.last_message.parts else None
+
+     async def request(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> ModelResponse:
+         self.last_model_request_parameters = model_request_parameters
+         model_response = await self._request(messages, model_settings, model_request_parameters)
+         model_response.usage = _estimate_usage([*messages, model_response])
+         return model_response
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         model_response = await self._request(messages, model_settings, model_request_parameters)
+         yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
+
+     @property
+     def model_name(self) -> str:
+         return self.__class__.__name__.lower()
+
+     @property
+     def system(self) -> str:
+         """The system / model provider."""
+         return "test_system"
+
+     async def _request(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,  # pylint: disable=unused-argument
+         model_request_parameters: ModelRequestParameters,  # pylint: disable=unused-argument
+     ) -> ModelResponse:
+         self.messages = messages
+         res = await anext(self.response_iter, None)
+         assert res, "No more responses available."
+         if callable(res):
+             res = res(self, messages)
+         # Simulate model latency before returning, if requested
+         if self.sleep_time:
+             await asyncio.sleep(self.sleep_time)
+         match res:
+             case str():
+                 return ModelResponse(parts=[TextPart(res)], model_name=self.model_name)
+             case ToolCallPart():
+                 return ModelResponse(parts=[res], model_name=self.model_name)
+             case TextPart():
+                 return ModelResponse(parts=[res], model_name=self.model_name)
+             case Exception():
+                 raise res
+             case _:
+                 raise ValueError(f"Invalid response type: {type(res)}, response: {res}")
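For aixtools/testing/aix_test_model.py above, here is a hypothetical end-to-end sketch of driving AixTestModel through a pydantic_ai Agent. The Agent wiring, the `tool_plain` decorator, and the `send_message_to_user` tool name are assumptions for illustration; only AixTestModel itself comes from this package.

```python
# Hypothetical usage sketch -- assumes pydantic_ai's Agent API;
# 'send_message_to_user' is an illustrative tool name, not part of aixtools.
import asyncio

from pydantic_ai import Agent
from pydantic_ai.messages import ToolCallPart

from aixtools.testing.aix_test_model import AixTestModel

model = AixTestModel(
    responses=[
        ToolCallPart(tool_name="send_message_to_user", args={"message": "hi"}),
        "The 10th prime number is 29.",  # plain text, so the run ends here
    ]
)
agent = Agent(model)

@agent.tool_plain
def send_message_to_user(message: str) -> str:
    """Test stub that just records the message."""
    print(f"agent says: {message}")
    return "ok"

async def main():
    result = await agent.run("What is the 10th prime number?")
    print(result.output)  # 'result.data' on older pydantic_ai versions

asyncio.run(main())
```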
@@ -0,0 +1,66 @@
+ """
+ Mock tool implementation for testing agent interactions.
+ """
+
+ import functools
+ import inspect
+ from typing import Any, Awaitable, Callable, List, TypeVar, Union
+
+ T = TypeVar("T")
+
+
+ def mock_tool(func: Callable[..., Union[T, Awaitable[T]]], return_values: List[Any]) -> Callable[..., Any]:
+     """
+     Create a mock version of the provided function that returns values from a predefined list.
+     Supports both synchronous and asynchronous functions.
+
+     Args:
+         func: The original function to mock (can be sync or async)
+         return_values: A list of values to be returned sequentially on each call
+
+     Returns:
+         A new function with the same signature and docstring as the original function
+         that returns values from the provided list sequentially.
+         If the original function is async, the mock will also be async.
+
+     Raises:
+         IndexError: When the mock function is called more times than there are return values
+     """
+     # Make a copy of the return values to avoid modifying the original list
+     values = return_values.copy()
+
+     # Check if the function is asynchronous
+     is_async = inspect.iscoroutinefunction(func)
+
+     if is_async:
+         # Create an async wrapper for async functions
+         @functools.wraps(func)
+         async def async_mock_wrapper(*args, **kwargs):  # pylint: disable=unused-argument
+             # Check if we have any return values left
+             if not values:
+                 raise IndexError(
+                     f"No more mock return values available for {func.__name__}. "
+                     "The function has been called more times than the number of provided return values."
+                 )
+
+             # Return and remove the first value from our list
+             return values.pop(0)
+
+         # Return the async wrapper function
+         return async_mock_wrapper
+
+     # Create a sync wrapper for sync functions
+     @functools.wraps(func)
+     def sync_mock_wrapper(*args, **kwargs):  # pylint: disable=unused-argument
+         # Check if we have any return values left
+         if not values:
+             raise IndexError(
+                 f"No more mock return values available for {func.__name__}. "
+                 "The function has been called more times than the number of provided return values."
+             )
+
+         # Return and remove the first value from our list
+         return values.pop(0)
+
+     # Return the sync wrapper function
+     return sync_mock_wrapper
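Since mock_tool (aixtools/testing/mock_tool.py above) preserves the wrapped function's signature and docstring via functools.wraps, it can stand in for a real tool in tests. A minimal sketch; the `get_weather` function is hypothetical:

```python
import asyncio

from aixtools.testing.mock_tool import mock_tool

async def get_weather(city: str) -> str:
    """Look up the current weather for a city."""
    raise NotImplementedError("real implementation not needed in tests")

# Each call pops the next canned value; a third call raises IndexError.
fake_weather = mock_tool(get_weather, ["sunny", "rainy"])

async def main():
    print(await fake_weather("Paris"))   # sunny
    print(await fake_weather("London"))  # rainy

asyncio.run(main())
```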
@@ -0,0 +1,279 @@
+ """
+ This module is useful for testing: it lets tests run without calling the real function.
+ The cached values are stored in `cache_file`.
+
+ In `learn=False` mode:
+ - The model_request_cache_wrapper first tries to answer from the cached value.
+ - If the cached value does not exist, it raises an exception.
+
+ In `learn=True` mode:
+ - The model_request_cache_wrapper first tries to answer from the cached value.
+ - If the cached value does not exist, it invokes the real function
+   and adds the new result to `cache_file` for future use.
+
+ Values are saved to pickle files, but since objects could be non-picklable, we use `safe_deepcopy`.
+ """
+
+ import datetime
+ import functools
+ import hashlib
+ import json
+ import pickle
+ import uuid
+ from contextlib import asynccontextmanager
+ from pathlib import Path
+ from typing import Any, Dict, Mapping, Sequence, Set, Tuple, Type
+
+ from aixtools.logging.log_objects import safe_deepcopy
+ from aixtools.logging.logging_config import get_logger
+ from aixtools.model_patch.model_patch import get_request_fn, get_request_stream_fn, model_patch
+
+ logger = get_logger(__name__)
+
+
+ class CacheKeyError(Exception):
+     """Exception raised when a key is not found in the cache."""
+
+
+ def safe_deepcopy_for_cache_key(obj: Any, normalize_types: Set[Type] | None = None) -> Any:
+     """
+     A modified version of safe_deepcopy that normalizes or skips fields based on their types.
+     This is useful for generating cache keys where some fields (like timestamps and UUIDs)
+     should be normalized to ensure consistent cache hits.
+
+     Args:
+         obj: The object to copy
+         normalize_types: Set of types to normalize (replace with placeholders)
+
+     Returns:
+         A deep copy of the object with normalized values for the specified types
+     """
+     # Default types to normalize if none provided
+     if normalize_types is None:
+         normalize_types = {datetime.datetime, datetime.date, datetime.time, uuid.UUID}
+
+     # Check if the object is of a type that should be normalized
+     if any(isinstance(obj, t) for t in normalize_types):
+         # Replace with a placeholder indicating the type
+         return f"<{type(obj).__name__}>"
+
+     # Check if the object is primitive (int, float, str, bool)
+     if isinstance(obj, (int, float, str, bool)):
+         return obj
+
+     # Handle mappings (dict and other mapping types)
+     if isinstance(obj, Mapping):
+         return {k: safe_deepcopy_for_cache_key(v, normalize_types) for k, v in obj.items()}
+
+     # Handle sequences (list, tuple, and other sequence types) but not strings
+     if isinstance(obj, Sequence) and not isinstance(obj, str):
+         return [safe_deepcopy_for_cache_key(item, normalize_types) for item in obj]
+
+     # Handle objects with a __dict__ attribute (custom classes)
+     if hasattr(obj, "__dict__"):
+         # For objects, we create a dictionary representation
+         result = {}
+         for attr, value in vars(obj).items():
+             result[attr] = safe_deepcopy_for_cache_key(value, normalize_types)
+         return result
+
+     # For other types, return a string representation
+     return f"<{type(obj).__name__}>"
+
+
+ def generate_cache_key(method_name: str, args: Tuple, kwargs: Dict) -> str:
+     """
+     Generate a unique cache key based on the method name and its arguments.
+     Uses safe_deepcopy_for_cache_key to normalize values that change frequently.
+
+     Args:
+         method_name: Name of the method being called
+         args: Positional arguments to the method
+         kwargs: Keyword arguments to the method
+
+     Returns:
+         A string hash that uniquely identifies this method call
+     """
+     # Normalize the arguments and kwargs
+     normalized_args = safe_deepcopy_for_cache_key(args)
+     normalized_kwargs = safe_deepcopy_for_cache_key(kwargs)
+
+     # Create a dictionary with the normalized information
+     key_dict = {"method_name": method_name, "args": normalized_args, "kwargs": normalized_kwargs}
+
+     # Convert to a stable string representation and hash it
+     key_str = json.dumps(key_dict, sort_keys=True, default=str)
+     return hashlib.md5(key_str.encode()).hexdigest()
+
+
+ def save_to_cache(cache_file: Path, key: str, value: Any) -> None:
+     """
+     Save a value to the cache file using the given key.
+
+     Args:
+         cache_file: Path to the cache file
+         key: The cache key
+         value: The value to cache (it will be deep-copied to ensure it is picklable)
+     """
+     # Create parent directories if they don't exist
+     cache_file.parent.mkdir(parents=True, exist_ok=True)
+
+     # Load the existing cache
+     cache = {}
+     if cache_file.exists():
+         try:
+             with open(cache_file, "rb") as f:
+                 cache = pickle.load(f)
+         except (pickle.PickleError, EOFError):
+             # If the file is corrupted, start with an empty cache
+             cache = {}
+
+     # Make a safe copy of the value and add it to the cache
+     safe_value = safe_deepcopy(value)
+     cache[key] = safe_value
+     logger.debug("Cache updated: %s -> %r", key, safe_value)
+
+     # Save the updated cache
+     with open(cache_file, "wb") as f:
+         pickle.dump(cache, f)
+
+
+ def get_from_cache(cache_file: Path, key: str) -> Any:
+     """
+     Retrieve a value from the cache file using the given key.
+
+     Args:
+         cache_file: Path to the cache file
+         key: The cache key
+
+     Returns:
+         The cached value
+
+     Raises:
+         CacheKeyError: If the key is not found in the cache
+     """
+     if not cache_file.exists():
+         raise CacheKeyError(f"Cache file {cache_file} does not exist")
+     try:
+         with open(cache_file, "rb") as f:
+             cache = pickle.load(f)
+     except (pickle.PickleError, EOFError) as e:
+         raise CacheKeyError(f"Cache file {cache_file} is corrupted") from e
+     if key not in cache:
+         raise CacheKeyError(f"Key {key} not found in cache")
+     return cache[key]
+
+
+ def request_cache(fn, cache_file: Path, learn=False):
+     """
+     Caching wrapper for async method calls.
+
+     In learn=False mode:
+     - Tries to answer from the cached value
+     - If the cached value does not exist, raises an exception
+
+     In learn=True mode:
+     - Tries to answer from the cached value
+     - If the cached value does not exist, invokes the real function and adds the result to the cache
+
+     Args:
+         fn: The async function to wrap
+         cache_file: Path to the cache file
+         learn: Whether to learn new responses (True) or only use cached ones (False)
+     """
+
+     @functools.wraps(fn)
+     async def model_request_cache_wrapper(*args, **kwargs):
+         # Generate a unique cache key for this request
+         cache_key = generate_cache_key(fn.__name__, args, kwargs)
+         try:
+             # Try to get the result from the cache
+             result = get_from_cache(cache_file, cache_key)
+             logger.debug("Cache hit for %s with key %s", fn.__name__, cache_key)
+             return result
+         except CacheKeyError as e:
+             # If not in cache and learn=False, raise an exception
+             if not learn:
+                 raise CacheKeyError(f"No cached response for {fn.__name__} with key {cache_key} and learn=False") from e
+             # If learn=True, invoke the original method
+             result = await fn(*args, **kwargs)
+             # Save the result to the cache
+             save_to_cache(cache_file, cache_key, result)
+             return result
+
+     return model_request_cache_wrapper
+
+
+ def request_stream_cache(fn, cache_file: Path, learn=False):
+     """
+     Caching wrapper for async streaming method calls.
+
+     Similar to request_cache, but handles streaming responses by caching all items
+     and replaying them when retrieved from the cache.
+
+     Args:
+         fn: The async context manager function to wrap
+         cache_file: Path to the cache file
+         learn: Whether to learn new responses (True) or only use cached ones (False)
+     """
+
+     @functools.wraps(fn)
+     @asynccontextmanager
+     async def model_request_stream_cache_wrapper(*args, **kwargs):
+         # Generate a unique cache key for this request
+         cache_key = generate_cache_key(fn.__name__, args, kwargs)
+
+         try:
+             # Try to get the cached items
+             cached_items = get_from_cache(cache_file, cache_key)
+
+             # Define a generator that yields the cached items
+             async def cached_gen():
+                 for item in cached_items:
+                     yield item
+
+             yield cached_gen()
+
+         except CacheKeyError as e:
+             # If not in cache and learn=False, raise an exception
+             if not learn:
+                 raise CacheKeyError(f"No cached stream for {fn.__name__} with key {cache_key} and learn=False") from e
+
+             # If learn=True, invoke the original method
+             async with fn(*args, **kwargs) as stream:
+                 # Collect all items so they can be saved to the cache
+                 all_items = []
+
+                 async def gen():
+                     async for item in stream:
+                         all_items.append(item)
+                         yield item
+
+                 # Yield the generator to the caller
+                 gen_instance = gen()
+                 yield gen_instance
+
+                 # After the caller is done, drain any remaining items
+                 # so the full stream ends up in the cache
+                 async for _ in gen_instance:
+                     pass
+
+                 # Save the collected items to the cache
+                 save_to_cache(cache_file, cache_key, all_items)
+
+     return model_request_stream_cache_wrapper
+
+
+ def model_patch_cache(model, cache_file: Path, learn=False):
+     """Patch a model with methods that cache requests and responses."""
+     logger.debug("Using cache file: %s", cache_file)
+     return model_patch(
+         model,
+         request_method=request_cache(get_request_fn(model), cache_file, learn),
+         request_stream_method=request_stream_cache(get_request_stream_fn(model), cache_file, learn),
+     )
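A plausible way to use model_patch_cache (aixtools/testing/model_patch_cache.py above) in a test suite, assuming model_patch returns the patched model: run once with learn=True to record real responses, commit the pickle file, then run with learn=False so tests never hit the network. The OpenAIModel instance and cache path below are illustrative assumptions.

```python
from pathlib import Path

from pydantic_ai.models.openai import OpenAIModel  # example model; any pydantic_ai model should do

from aixtools.testing.model_patch_cache import model_patch_cache

CACHE_FILE = Path("tests/data/model_cache.pkl")  # hypothetical location

real_model = OpenAIModel("gpt-4o")

# First run with learn=True: cache misses fall through to the real model and
# are recorded. Later runs with learn=False: a miss raises CacheKeyError,
# so tests stay deterministic and offline.
model = model_patch_cache(real_model, CACHE_FILE, learn=False)
```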
@@ -0,0 +1,3 @@
+ from .tool_doctor import tool_doctor
+
+ __all__ = ["tool_doctor"]
@@ -0,0 +1,61 @@
+ from aixtools.agents import get_agent, run_agent
+ from aixtools.tools.doctor.tool_recommendation import ToolRecommendation
+
+ TOOL_DOCTOR_PROMPT = """
+ ## Tool doctor
+ You are helping to debug common errors in tools and tool definitions.
+
+ Given the tools, for each tool you will give feedback about the tool definition.
+
+ 1. Name: Check if the tool name is descriptive and follows naming conventions.
+ 2. Description: Ensure the tool's description is clear and provides enough detail
+    so that users understand its purpose and functionality.
+ 3. Return type: Check if the return type is specified and matches the tool's functionality.
+ 4. Arguments: Verify that the tool arguments are well-defined and include types
+    and descriptions. Ensure that argument names are descriptive and follow naming conventions.
+ 5. Look for any missing or redundant information in the tool definition.
+
+ Some rules:
+ - Ignore a tool called 'final_result'.
+ - Do not suggest changes if the tool is already well-defined.
+ - Don't nitpick; focus on significant improvements that can be made to the tool definition.
+ - Don't suggest trivial improvements or changes for things that are self-evident or already well-defined.
+ """
+
+
+ async def tool_doctor_single(tool) -> ToolRecommendation:
+     """Run the tool doctor agent to analyze a single tool."""
+     agent = get_agent(tools=[tool], output_type=ToolRecommendation)
+     ret = await run_agent(agent, TOOL_DOCTOR_PROMPT, log_model_requests=True, verbose=True, debug=True)
+     return ret[0]  # type: ignore
+
+
+ async def tool_doctor_multiple(tools: list) -> list[ToolRecommendation]:
+     """Run the tool doctor agent to analyze the given tools."""
+     agent = get_agent(tools=tools, output_type=list[ToolRecommendation])
+     ret = await run_agent(agent, TOOL_DOCTOR_PROMPT)
+     return ret[0]  # type: ignore
+
+
+ async def tool_doctor(tools: list, max_tools_per_batch=5, verbose=True) -> list[ToolRecommendation]:
+     """Run the tool doctor agent to analyze tools and give recommendations."""
+     # Split the tools into batches that respect the max_tools_per_batch limit
+     results = []
+     for i in range(0, len(tools), max_tools_per_batch):
+         batch = tools[i : i + max_tools_per_batch]
+         batch_num = i // max_tools_per_batch + 1
+         tool_names = [t.__name__ for t in batch]
+         print(f"Processing batch {batch_num} with {len(batch)} tools: {tool_names}")
+         ret = await tool_doctor_multiple(batch)
+         results.extend(ret)
+     # Print the results if verbose
+     if verbose:
+         for r in results:
+             print(r)
+     return results
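A minimal invocation sketch for tool_doctor (aixtools/tools/doctor/tool_doctor.py above). It reads each tool's __name__ per batch, so plain callables should work; which LLM backs get_agent is whatever aixtools.agents defaults to, which is not shown in this diff.

```python
import asyncio

from aixtools.tools.doctor import tool_doctor

def calc(x: int, y: int) -> int:
    """Add two integers and return the sum."""
    return x + y

def do_stuff(a):  # intentionally vague: the doctor should flag it
    return a

async def main():
    recommendations = await tool_doctor([calc, do_stuff], max_tools_per_batch=5)
    for rec in recommendations:
        print(rec.signature(), "->", rec.needs_improvement)

asyncio.run(main())
```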
@@ -0,0 +1,44 @@
+ from pydantic import BaseModel
+
+
+ class ArgumentRecommendation(BaseModel):
+     """A recommendation for an argument"""
+
+     arg_name: str
+     arg_type: str
+     arg_description_improvement: str
+
+     def __str__(self):
+         return f"{self.arg_name} ({self.arg_type}): {self.arg_description_improvement}"
+
+
+ class ToolRecommendation(BaseModel):
+     """A recommendation for a tool"""
+
+     name: str
+     needs_improvement: bool
+     description_improvement: str
+     arguments: list[ArgumentRecommendation]
+     return_description_improvement: str
+
+     def signature(self):
+         """Generate the function signature for the tool"""
+         signature = f"def {self.name}("
+         if self.arguments:
+             signature += ", ".join(f"{arg.arg_name}: {arg.arg_type}" for arg in self.arguments)
+         signature += ")"
+         return signature
+
+     def __str__(self):
+         out = f"Tool name: '{self.name}'\n"
+         out += f"  - Signature: {self.signature()}\n"
+         out += f"  - Needs improvement: {self.needs_improvement}\n"
+         if self.description_improvement:
+             out += f"  - Description improvement: {self.description_improvement}\n"
+         if self.return_description_improvement:
+             out += f"  - Return description improvement: {self.return_description_improvement}\n"
+         if self.arguments:
+             out += "  - Arguments:\n"
+             for arg in self.arguments:
+                 out += f"    - {arg}\n"
+         return out
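Because these are plain pydantic models (aixtools/tools/doctor/tool_recommendation.py above), the report format is easy to preview without running an agent. A small hand-built example with made-up field values:

```python
from aixtools.tools.doctor.tool_recommendation import ArgumentRecommendation, ToolRecommendation

rec = ToolRecommendation(
    name="do_stuff",
    needs_improvement=True,
    description_improvement="Explain what 'stuff' the tool actually does.",
    return_description_improvement="Document the return value.",
    arguments=[
        ArgumentRecommendation(
            arg_name="a",
            arg_type="str",
            arg_description_improvement="Rename to something descriptive and add a description.",
        )
    ],
)
print(rec)  # renders the "Tool name: 'do_stuff'" report via __str__
```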
@@ -0,0 +1,35 @@
+ """
+ Utils package initialization.
+ """
+
+ from aixtools.logging.logging_config import get_logger  # pylint: disable=import-error
+ from aixtools.utils import config
+ from aixtools.utils.enum_with_description import EnumWithDescription
+ from aixtools.utils.persisted_dict import PersistedDict
+ from aixtools.utils.utils import (
+     escape_backticks,
+     escape_newline,
+     find_file,
+     prepend_all_lines,
+     remove_quotes,
+     tabit,
+     to_str,
+     tripple_quote_strip,
+     truncate,
+ )
+
+ __all__ = [
+     "config",
+     "PersistedDict",
+     "EnumWithDescription",
+     "escape_newline",
+     "escape_backticks",
+     "find_file",
+     "get_logger",
+     "prepend_all_lines",
+     "remove_quotes",
+     "tabit",
+     "to_str",
+     "truncate",
+     "tripple_quote_strip",
+ ]