airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,166 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import subprocess
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
from .registry import StatelessTool, register_tool
|
7
|
+
|
8
|
+
|
9
|
+
@register_tool("list_directory")
class ListDirectoryTool(StatelessTool):
    """Tool for listing contents of a directory."""

    def __init__(self):
        self.name = "list_directory"
        self.description = "List the contents of a directory, showing files and subdirectories"
        # JSON Schema advertised to the LLM; no argument is required.
        self.parameters = {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Path to the directory to list. "
                                   "Defaults to current directory if not provided."
                },
                "show_hidden": {
                    "type": "boolean",
                    "description": "Whether to show hidden files (starting with .)"
                }
            },
            "required": []
        }

    def __call__(self, path: str = ".", show_hidden: bool = False) -> str:
        """List the contents of a directory.

        Args:
            path: Directory to list; a leading ``~`` is expanded.
            show_hidden: Include entries whose names start with ``.``.

        Returns:
            A JSON string ``{"path": ..., "items": [...]}`` where each item has
            ``name``, ``type`` (``"directory"`` or ``"file"``) and ``size`` (bytes
            for regular files, ``None`` otherwise). Items are sorted by name.
            On failure a string starting with ``"Error"`` is returned instead.
        """
        try:
            path = os.path.expanduser(path)
            if not os.path.exists(path):
                return f"Error: Path '{path}' does not exist"

            if not os.path.isdir(path):
                return f"Error: Path '{path}' is not a directory"

            # Directory iteration order is arbitrary; sort for stable output.
            with os.scandir(path) as iterator:
                entries = sorted(iterator, key=lambda entry: entry.name)

            items = []
            for entry in entries:
                if not show_hidden and entry.name.startswith('.'):
                    continue

                try:
                    # Like os.path.isdir/isfile/getsize, these follow symlinks.
                    is_dir = entry.is_dir()
                    size = entry.stat().st_size if entry.is_file() else None
                except OSError:
                    # The entry vanished or cannot be stat'ed: list it without a
                    # size instead of failing the whole listing.
                    is_dir, size = False, None

                items.append({
                    "name": entry.name,
                    "type": "directory" if is_dir else "file",
                    "size": size
                })

            return json.dumps({"path": path, "items": items}, indent=2)
        except Exception as e:
            return f"Error listing directory: {str(e)}"

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters
            }
        }
|
71
|
+
|
72
|
+
|
73
|
+
@register_tool("directory_tree")
class DirectoryTreeTool(StatelessTool):
    """Tool for displaying directory structure as a tree."""

    def __init__(self):
        self.name = "directory_tree"
        self.description = ("Display the directory structure as a tree, "
                            "showing the hierarchy of files and directories")
        # JSON Schema advertised to the LLM; no argument is required.
        self.parameters = {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Path to the root directory. "
                                   "Defaults to current directory if not provided."
                },
                "max_depth": {
                    "type": "integer",
                    "description": "Maximum depth of subdirectories to display"
                },
                "show_hidden": {
                    "type": "boolean",
                    "description": "Whether to show hidden files (starting with .)"
                }
            },
            "required": []
        }

    def __call__(self, path: str = ".", max_depth: int = 3, show_hidden: bool = False) -> str:
        """Display the directory structure as a tree.

        The external ``tree`` command is used when it is installed and exits
        successfully; otherwise a pure-Python renderer is used. Both paths show
        at most ``max_depth`` levels of entries below ``path`` (the meaning of
        ``tree -L max_depth``).

        Args:
            path: Root directory; a leading ``~`` is expanded.
            max_depth: Number of directory levels to display.
            show_hidden: Include entries whose names start with ``.``.

        Returns:
            The rendered tree, or a string starting with ``"Error"`` on failure.
        """
        try:
            path = os.path.expanduser(path)
            if not os.path.exists(path):
                return f"Error: Path '{path}' does not exist"

            if not os.path.isdir(path):
                return f"Error: Path '{path}' is not a directory"

            # Try to use external 'tree' command if available
            try:
                cmd = ["tree", "-L", str(max_depth)]
                if show_hidden:
                    # tree hides dot-files by default; -a is required to list them.
                    cmd.append("-a")
                else:
                    cmd.extend(["-I", ".*"])
                cmd.append(path)

                completed = subprocess.run(cmd, capture_output=True, text=True)
                if completed.returncode == 0:
                    return completed.stdout
            except (subprocess.SubprocessError, FileNotFoundError):
                # If 'tree' command fails or is not available, fall back to custom implementation
                pass

            # Custom tree implementation
            lines = [f"Directory tree for {path}:"]

            def add_to_tree(directory: Path, prefix: str = "", depth: int = 0):
                # `depth` levels have already been rendered above this directory's
                # entries; stop once max_depth levels are shown (like `tree -L`).
                if depth >= max_depth:
                    return

                try:
                    entries = sorted(directory.iterdir(),
                                     key=lambda x: (x.is_file(), x.name))
                except PermissionError:
                    lines.append(f"{prefix}└── [Permission denied]")
                    return

                # Drop hidden entries *before* choosing connectors; otherwise a
                # trailing hidden entry leaves the last visible one drawn with '├── '.
                if not show_hidden:
                    entries = [e for e in entries if not e.name.startswith('.')]

                for i, entry in enumerate(entries):
                    is_last = i == len(entries) - 1
                    lines.append(f"{prefix}{'└── ' if is_last else '├── '}{entry.name}")

                    if entry.is_dir():
                        # Continue the parent's vertical rule unless this was its
                        # last visible child; width matches the 4-char connectors.
                        add_to_tree(
                            entry,
                            prefix + ('    ' if is_last else '│   '),
                            depth + 1
                        )

            add_to_tree(Path(path))
            return "\n".join(lines)
        except Exception as e:
            return f"Error creating directory tree: {str(e)}"

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters
            }
        }
|
@@ -0,0 +1,111 @@
|
|
1
|
+
import requests
|
2
|
+
from typing import Dict, Any, Optional, Union
|
3
|
+
from urllib.parse import urlparse
|
4
|
+
|
5
|
+
from .registry import StatelessTool, register_tool
|
6
|
+
|
7
|
+
|
8
|
+
@register_tool("api_call")
class ApiCallTool(StatelessTool):
    """Stateless tool that performs a single HTTP request with ``requests``."""

    def __init__(self):
        self.name = "api_call"
        self.description = "Make HTTP requests to external APIs"

        # JSON Schema for the call arguments (keys mirror the __call__ keywords).
        properties = {
            "url": {"type": "string", "description": "URL to make the request to"},
            "method": {
                "type": "string",
                "enum": ["GET", "POST", "PUT", "DELETE", "PATCH"],
                "description": "HTTP method to use for the request",
            },
            "headers": {"type": "object", "description": "HTTP headers to include in the request"},
            "params": {"type": "object", "description": "URL parameters for the request"},
            "data": {
                "type": "object",
                "description": "Data to send in the request body (for POST, PUT, PATCH)",
            },
            "json_data": {
                "type": "object",
                "description": "JSON data to send in the request body (for POST, PUT, PATCH)",
            },
            "timeout": {"type": "number", "description": "Request timeout in seconds"},
        }
        self.parameters = {
            "type": "object",
            "properties": properties,
            "required": ["url", "method"],
        }

    def __call__(
        self,
        url: str,
        method: str = "GET",
        headers: Optional[Dict[str, str]] = None,
        params: Optional[Dict[str, str]] = None,
        data: Optional[Union[Dict[str, Any], str]] = None,
        json_data: Optional[Dict[str, Any]] = None,
        timeout: float = 10.0
    ) -> Dict[str, Any]:
        """Send one HTTP request and summarise the response.

        Returns:
            ``{"status_code", "headers", "content"}`` on success, where ``content``
            is the decoded JSON body when it parses and the raw text otherwise;
            ``{"error": message}`` for an invalid URL, an unsupported method, or
            any exception raised while performing the request.
        """
        try:
            parts = urlparse(url)
            # Both a scheme and a host are required (rejects e.g. "example.com/x").
            if not (parts.scheme and parts.netloc):
                return {"error": f"Invalid URL '{url}'"}

            verb = method.upper()
            if verb not in ("GET", "POST", "PUT", "DELETE", "PATCH"):
                return {"error": f"Unsupported HTTP method '{verb}'"}

            response = requests.request(
                verb,
                url,
                headers=headers,
                params=params,
                data=data,
                json=json_data,
                timeout=timeout,
            )

            try:
                content = response.json()
            except ValueError:
                # Body is not valid JSON; hand back the raw text instead.
                content = response.text

            return {
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "content": content,
            }
        except requests.exceptions.RequestException as exc:
            return {"error": f"Error making API request: {str(exc)}"}
        except Exception as exc:
            return {"error": f"Error: {str(exc)}"}

    def to_dict(self):
        """Describe this tool in the {"type": "function", ...} shape used for LLM function calling."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": function_spec}
|
@@ -0,0 +1,320 @@
|
|
1
|
+
"""
|
2
|
+
Tool Registry System for AirTrain
|
3
|
+
|
4
|
+
This module provides a registry system for tools that can be used by AI agents.
|
5
|
+
It supports both stateful tools (requiring fresh instances) and stateless tools
|
6
|
+
(which can be shared/reused).
|
7
|
+
|
8
|
+
The registry system includes:
|
9
|
+
- Validation mechanisms for tools
|
10
|
+
- Registration decorators
|
11
|
+
- Factory methods for tool creation
|
12
|
+
- Discovery utilities for finding available tools
|
13
|
+
"""
|
14
|
+
|
15
|
+
from abc import ABC, abstractmethod
|
16
|
+
from typing import Dict, List, Type, Any, Optional, TypeVar
|
17
|
+
|
18
|
+
|
19
|
+
# Generic type for tool classes, bound to BaseTool (forward reference; defined below).
T = TypeVar('T', bound='BaseTool')

# Global registry of tool classes, keyed by tool type and then by tool name:
# {"stateful": {name: cls, ...}, "stateless": {name: cls, ...}}
TOOL_REGISTRY: Dict[str, Dict[str, Type["BaseTool"]]] = dict(stateful={}, stateless={})
|
27
|
+
|
28
|
+
|
29
|
+
class ToolValidationError(Exception):
    """Raised when a tool class does not satisfy the requirements of its registry type."""
|
32
|
+
|
33
|
+
|
34
|
+
class BaseTool(ABC):
    """Abstract base class for all tools.

    Concrete tools must implement ``__call__`` (run the tool) and ``to_dict``
    (describe the tool for LLM function calling).
    """

    # Set on the class by the registration decorator; None until registered.
    tool_name: Optional[str] = None
    tool_type: Optional[str] = None

    @abstractmethod
    def __call__(self, **kwargs) -> Any:
        """Execute the tool with the given keyword parameters."""

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Convert tool to dictionary format for LLM function calling."""
|
50
|
+
|
51
|
+
|
52
|
+
class StatefulTool(BaseTool):
    """Base class for tools that keep internal state and need fresh instances.

    Subclasses provide ``create_instance`` to build a new instance and ``reset``
    to return an instance to a clean state.
    """

    @classmethod
    @abstractmethod
    def create_instance(cls: Type[T]) -> T:
        """Build and return a new instance of the tool with fresh state."""

    @abstractmethod
    def reset(self) -> None:
        """Clear the tool's internal state."""
|
65
|
+
|
66
|
+
|
67
|
+
class StatelessTool(BaseTool):
    """Base class for tools without per-use state, so one instance can be shared and reused."""
|
70
|
+
|
71
|
+
|
72
|
+
def validate_tool(cls: Type[BaseTool], tool_type: str) -> Type[BaseTool]:
    """
    Validate that a tool class meets the requirements for its type.

    A required method only counts as implemented if it is callable AND not
    left abstract. A plain ``hasattr``/``callable`` test cannot fail here:
    every class exposes a callable ``__call__`` (inherited from ``type``), and
    ``BaseTool``/``StatefulTool`` define callable abstract stubs for the rest.

    Args:
        cls: The tool class to validate
        tool_type: Either "stateful" or "stateless"

    Returns:
        The validated tool class

    Raises:
        ToolValidationError: If the tool does not meet requirements
    """
    # Abstract methods the class (or its bases) still leaves unimplemented.
    still_abstract = getattr(cls, "__abstractmethods__", frozenset())

    def _implements(method_name: str) -> bool:
        return (callable(getattr(cls, method_name, None))
                and method_name not in still_abstract)

    # Check that the class implements the required methods
    if not _implements("__call__"):
        raise ToolValidationError(f"Tool {cls.__name__} must implement __call__ method")

    if not _implements("to_dict"):
        raise ToolValidationError(f"Tool {cls.__name__} must implement to_dict method")

    # Validate stateful tool specific requirements
    if tool_type == "stateful":
        if not issubclass(cls, StatefulTool):
            raise ToolValidationError(
                f"Stateful tool {cls.__name__} must inherit from StatefulTool"
            )

        if not _implements("create_instance"):
            raise ToolValidationError(
                f"Stateful tool {cls.__name__} must implement create_instance class method"
            )

        if not _implements("reset"):
            raise ToolValidationError(
                f"Stateful tool {cls.__name__} must implement reset method"
            )

    # Validate stateless tool specific requirements
    if tool_type == "stateless":
        if not issubclass(cls, StatelessTool):
            raise ToolValidationError(
                f"Stateless tool {cls.__name__} must inherit from StatelessTool"
            )

    return cls
|
121
|
+
|
122
|
+
|
123
|
+
def register_tool(name: str, tool_type: str = "stateless"):
    """
    Build a class decorator that validates a tool and records it in TOOL_REGISTRY.

    Args:
        name: Registry key for the tool (also stored on the class as ``tool_name``)
        tool_type: Either "stateful" or "stateless"

    Returns:
        A decorator that validates the class with ``validate_tool``, registers it,
        sets its ``tool_name``/``tool_type`` class attributes and returns it

    Raises:
        ValueError: Immediately if ``tool_type`` is unknown; when the decorator is
            applied if ``name`` is already registered for that tool type
    """
    if tool_type not in TOOL_REGISTRY:
        raise ValueError(
            f"Invalid tool type: {tool_type}. Must be either 'stateful' or 'stateless'"
        )

    def decorator(cls: Type[BaseTool]) -> Type[BaseTool]:
        bucket = TOOL_REGISTRY[tool_type]
        if name in bucket:
            raise ValueError(f"Tool name '{name}' already registered in {tool_type} registry")

        # validate_tool raises ToolValidationError for a non-conforming class.
        tool_cls = validate_tool(cls, tool_type)
        bucket[name] = tool_cls

        tool_cls.tool_name = name
        tool_cls.tool_type = tool_type
        return tool_cls

    return decorator
|
159
|
+
|
160
|
+
|
161
|
+
class ToolFactory:
    """Factory class for creating and managing tools."""

    @staticmethod
    def _shared_instance(tool_cls: Type[BaseTool]) -> BaseTool:
        """Return the lazily created shared instance of a stateless tool class.

        The cache is read from the class's own ``__dict__``. ``hasattr`` would
        also find an ``_instance`` inherited from a registered parent tool and
        return the parent's instance instead of one of ``tool_cls``.
        """
        instance = tool_cls.__dict__.get("_instance")
        if instance is None:
            instance = tool_cls()
            tool_cls._instance = instance
        return instance

    @staticmethod
    def get_tool(name: str, tool_type: str = "stateless") -> BaseTool:
        """
        Get a tool instance by name and type.

        For stateful tools, this returns a fresh instance.
        For stateless tools, this returns a singleton instance.

        Args:
            name: The name of the tool
            tool_type: Either "stateful" or "stateless"

        Returns:
            An instance of the requested tool

        Raises:
            ValueError: If the tool or tool type is not found
        """
        if tool_type not in TOOL_REGISTRY:
            raise ValueError(f"Invalid tool type: {tool_type}")

        tool_cls = TOOL_REGISTRY[tool_type].get(name)
        if tool_cls is None:
            raise ValueError(f"Tool '{name}' not found in {tool_type} registry")

        # Handle stateful tools - always create a fresh instance
        if tool_type == "stateful":
            instance = tool_cls.create_instance()
            instance.reset()  # Ensure the instance is in a clean state
            return instance

        # Stateless tools share one lazily created instance per class.
        return ToolFactory._shared_instance(tool_cls)

    @staticmethod
    def list_tools(tool_type: Optional[str] = None) -> Dict[str, List[str]]:
        """
        List all registered tools, optionally filtered by type.

        Args:
            tool_type: Optional filter for tool type

        Returns:
            A dictionary mapping tool types to lists of tool names

        Raises:
            ValueError: If ``tool_type`` is given but unknown
        """
        if tool_type:
            if tool_type not in TOOL_REGISTRY:
                raise ValueError(f"Invalid tool type: {tool_type}")
            return {tool_type: list(TOOL_REGISTRY[tool_type].keys())}

        return {t_type: list(tools.keys()) for t_type, tools in TOOL_REGISTRY.items()}

    @staticmethod
    def get_tool_definitions(tool_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get tool definitions for all registered tools or tools of a specific type.

        This is useful for preparing tools for LLM function calling.

        Args:
            tool_type: Optional filter for tool type

        Returns:
            A list of tool definitions in dictionary format

        Raises:
            ValueError: If ``tool_type`` is given but unknown
        """
        tool_defs = []

        if tool_type:
            if tool_type not in TOOL_REGISTRY:
                raise ValueError(f"Invalid tool type: {tool_type}")
            registry = {tool_type: TOOL_REGISTRY[tool_type]}
        else:
            registry = TOOL_REGISTRY

        for t_type, tools in registry.items():
            for cls in tools.values():
                if t_type == "stateless":
                    # For stateless tools, we can use the singleton instance
                    instance = ToolFactory._shared_instance(cls)
                else:
                    # For stateful tools, we need to create a temporary instance
                    instance = cls.create_instance()
                tool_defs.append(instance.to_dict())

        return tool_defs
|
255
|
+
|
256
|
+
|
257
|
+
def get_default_tools(tool_type: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Return LLM function-calling definitions for the registered tools.

    Convenience wrapper around ``ToolFactory.get_tool_definitions``.

    Args:
        tool_type: Restrict the result to "stateful" or "stateless" tools;
            ``None`` returns definitions for every registered tool

    Returns:
        One definition dictionary per matching tool
    """
    definitions = ToolFactory.get_tool_definitions(tool_type)
    return definitions
|
270
|
+
|
271
|
+
|
272
|
+
def execute_tool_call(tool_call: Dict[str, Any]) -> Any:
    """
    Execute a tool call based on LLM function calling format.

    Expects ``{"function": {"name": ..., "arguments": ...}}``. ``arguments`` may be
    a JSON-encoded object string or an already-decoded dict. A missing, ``None``
    or blank value means "no arguments".

    Args:
        tool_call: A dictionary containing the tool call details

    Returns:
        The result of executing the tool call, or an
        ``"Error executing tool ..."`` string if the tool itself raises

    Raises:
        ValueError: If the function name is missing, the tool is not found, or
            the arguments are not a JSON object
    """
    import json

    # Extract tool details from the call
    function_details = tool_call.get("function", {})
    function_name = function_details.get("name")

    if not function_name:
        raise ValueError("Invalid tool call format: Missing function name")

    # Find which registry holds the tool (registries are checked in dict order).
    tool_type = next(
        (t_type for t_type, tools in TOOL_REGISTRY.items() if function_name in tools),
        None,
    )
    if tool_type is None:
        raise ValueError(f"Tool '{function_name}' not found in any registry")

    # Get a tool instance
    tool = ToolFactory.get_tool(function_name, tool_type)

    # Parse arguments
    raw_arguments = function_details.get("arguments")
    if isinstance(raw_arguments, dict):
        arguments = raw_arguments
    elif raw_arguments is None or (isinstance(raw_arguments, str) and not raw_arguments.strip()):
        arguments = {}
    else:
        try:
            arguments = json.loads(raw_arguments)
        except (json.JSONDecodeError, TypeError) as e:
            raise ValueError(f"Invalid arguments format for tool '{function_name}'") from e

    # Keyword expansion below needs a mapping, e.g. reject "[1, 2]" or "5".
    if not isinstance(arguments, dict):
        raise ValueError(f"Invalid arguments format for tool '{function_name}'")

    # Execute the tool
    try:
        return tool(**arguments)
    except Exception as e:
        return f"Error executing tool '{function_name}': {str(e)}"
|