airtrain 0.1.53__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +61 -2
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/skills.py +102 -0
- airtrain/integrations/combined/list_models_factory.py +9 -3
- airtrain/integrations/groq/__init__.py +18 -1
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +93 -17
- airtrain/integrations/together/__init__.py +15 -1
- airtrain/integrations/together/models_config.py +123 -1
- airtrain/integrations/together/skills.py +117 -20
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +41 -0
- airtrain/tools/command.py +211 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/METADATA +37 -1
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/RECORD +31 -13
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/WHEEL +1 -1
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,38 @@
|
|
1
|
+
"""
|
2
|
+
Airtrain Telemetry
|
3
|
+
|
4
|
+
This package provides telemetry functionality for Airtrain usage.
|
5
|
+
Telemetry is enabled by default to help improve the library and can be disabled by
|
6
|
+
setting AIRTRAIN_TELEMETRY_ENABLED=false in your environment variables or .env file.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from airtrain.telemetry.service import ProductTelemetry
|
10
|
+
from airtrain.telemetry.views import (
|
11
|
+
AgentRunTelemetryEvent,
|
12
|
+
AgentStepTelemetryEvent,
|
13
|
+
AgentEndTelemetryEvent,
|
14
|
+
ModelInvocationTelemetryEvent,
|
15
|
+
ErrorTelemetryEvent,
|
16
|
+
UserFeedbackTelemetryEvent,
|
17
|
+
SkillInitTelemetryEvent,
|
18
|
+
SkillProcessTelemetryEvent,
|
19
|
+
PackageInstallTelemetryEvent,
|
20
|
+
PackageImportTelemetryEvent,
|
21
|
+
)
|
22
|
+
|
23
|
+
__all__ = [
|
24
|
+
"ProductTelemetry",
|
25
|
+
"AgentRunTelemetryEvent",
|
26
|
+
"AgentStepTelemetryEvent",
|
27
|
+
"AgentEndTelemetryEvent",
|
28
|
+
"ModelInvocationTelemetryEvent",
|
29
|
+
"ErrorTelemetryEvent",
|
30
|
+
"UserFeedbackTelemetryEvent",
|
31
|
+
"SkillInitTelemetryEvent",
|
32
|
+
"SkillProcessTelemetryEvent",
|
33
|
+
"PackageInstallTelemetryEvent",
|
34
|
+
"PackageImportTelemetryEvent",
|
35
|
+
]
|
36
|
+
|
37
|
+
# Create a singleton instance for easy import.
# NOTE: this call executes while the package is being imported, so whatever
# ProductTelemetry() does in its constructor happens as an import side effect.
telemetry = ProductTelemetry()
|
@@ -0,0 +1,167 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import platform
|
4
|
+
import sys
|
5
|
+
import uuid
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from dotenv import load_dotenv
|
9
|
+
from posthog import Posthog
|
10
|
+
|
11
|
+
from airtrain.telemetry.views import BaseTelemetryEvent
|
12
|
+
|
13
|
+
load_dotenv()
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
# Extra settings attached to telemetry events to collect more data.
POSTHOG_EVENT_SETTINGS = dict(
    process_person_profile=True,
    enable_sent_at=True,  # Add timing information
    capture_performance=True,  # Collect performance data
    capture_pageview=True,  # More detailed usage tracking
)
|
24
|
+
|
25
|
+
|
26
|
+
def singleton(cls):
    """Class decorator that limits *cls* to one lazily created instance.

    The decorated name becomes a factory function: its first call builds
    ``cls(*args, **kwargs)`` and stores it; every later call returns that
    stored object and ignores whatever arguments were passed.
    """
    created = {}

    def get_instance(*args, **kwargs):
        # Fast path: the instance already exists, arguments are ignored.
        if cls in created:
            return created[cls]
        created[cls] = cls(*args, **kwargs)
        return created[cls]

    return get_instance
|
36
|
+
|
37
|
+
|
38
|
+
@singleton
class ProductTelemetry:
    """
    Service for capturing telemetry data from Airtrain usage.

    Telemetry is enabled by default but can be disabled by setting
    AIRTRAIN_TELEMETRY_ENABLED=false in your environment. The opt-out is
    honoured unconditionally (including while the package is in beta).

    Telemetry is best-effort: failures while collecting system information,
    identifying the user or sending an event are logged and never raised to
    the caller.
    """

    USER_ID_PATH = str(
        Path.home() / '.cache' / 'airtrain' / 'telemetry_user_id'
    )
    # API key for PostHog
    PROJECT_API_KEY = 'phc_1pLNkG3QStYEXIz0CAPQaOGpcmxpE3CJXhE1HANWgIz'
    HOST = 'https://us.i.posthog.com'
    UNKNOWN_USER_ID = 'UNKNOWN'

    # Cached user id; filled lazily by the ``user_id`` property.
    _curr_user_id = None

    def __init__(self) -> None:
        """Read the environment switches and, unless disabled, create the PostHog client."""
        telemetry_disabled = os.getenv('AIRTRAIN_TELEMETRY_ENABLED', 'true').lower() == 'false'
        self.debug_logging = os.getenv('AIRTRAIN_LOGGING_LEVEL', 'info').lower() == 'debug'

        # System information to include with telemetry
        self.system_info = {
            'os': platform.system(),
            'os_version': platform.version(),
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'platform': platform.platform(),
            'machine': platform.machine(),
            'hostname': platform.node(),
            'username': self._get_username(),
        }

        is_beta = True  # TODO: remove this once out of beta
        if telemetry_disabled:
            # The documented opt-out always wins, beta or not.
            self._posthog_client = None
        else:
            # Use the module logger: the root-level logging.info() helper
            # would implicitly configure the host application's root logger.
            if is_beta:
                logger.info(
                    'You are currently in beta. Telemetry is enabled by default.'
                )
            else:
                logger.info(
                    'Telemetry enabled. To disable, set '
                    'AIRTRAIN_TELEMETRY_ENABLED=false in your environment.'
                )
            self._posthog_client = Posthog(
                project_api_key=self.PROJECT_API_KEY,
                host=self.HOST,
                disable_geoip=False  # Collect geographical data
            )

            # Set debug mode if enabled
            if self.debug_logging:
                self._posthog_client.debug = True

            # Identify user more specifically. Best-effort: a failure here
            # must not break construction of the telemetry service.
            try:
                self._posthog_client.identify(
                    self.user_id,
                    {
                        **self.system_info,
                        'first_seen': True
                    }
                )
            except Exception as e:
                logger.debug(f'Failed to identify telemetry user: {e}')

            # Silence posthog's logging only if debug is off
            if not self.debug_logging:
                posthog_logger = logging.getLogger('posthog')
                posthog_logger.disabled = True

        if self._posthog_client is None:
            logger.debug('Telemetry disabled')

    @staticmethod
    def _get_username() -> str:
        """Return the login name, or 'unknown' when it cannot be determined.

        ``os.getlogin()`` raises ``OSError`` when the process has no
        controlling terminal, and is missing on some platforms.
        """
        try:
            return os.getlogin()
        except (AttributeError, OSError):
            return 'unknown'

    def capture(self, event: BaseTelemetryEvent) -> None:
        """Capture a telemetry event and send it to PostHog if telemetry is enabled."""
        if self._posthog_client is None:
            return

        # Add system information to all events. Building the properties
        # runs event code, so guard it: telemetry must not raise into callers.
        try:
            enhanced_properties = {
                **event.properties,
                **POSTHOG_EVENT_SETTINGS,
                **self.system_info
            }
        except Exception as e:
            logger.error(f'Failed to build telemetry event {event.name}: {e}')
            return

        if self.debug_logging:
            logger.debug(f'Telemetry event: {event.name} {enhanced_properties}')
        self._direct_capture(event, enhanced_properties)

    def _direct_capture(self, event: BaseTelemetryEvent, enhanced_properties: dict) -> None:
        """
        Send the event to PostHog. Should not be thread blocking because posthog handles it.

        Errors raised by the client are logged and swallowed.
        """
        if self._posthog_client is None:
            return

        try:
            self._posthog_client.capture(
                self.user_id,
                event.name,
                enhanced_properties
            )
        except Exception as e:
            logger.error(f'Failed to send telemetry event {event.name}: {e}')

    @property
    def user_id(self) -> str:
        """
        Get the user ID for telemetry.

        The ID is read from USER_ID_PATH; a new one is generated and stored
        when the file is missing or empty. Returns UNKNOWN_USER_ID when the
        file cannot be read or written.
        """
        if self._curr_user_id:
            return self._curr_user_id

        # File access may fail due to permissions or other reasons.
        # We don't want to crash so we catch all exceptions.
        try:
            stored_id = ''
            if os.path.exists(self.USER_ID_PATH):
                with open(self.USER_ID_PATH, 'r', encoding='utf-8') as f:
                    # Strip so a trailing newline does not become part of the ID.
                    stored_id = f.read().strip()

            if not stored_id:
                os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
                # Use a more identifiable ID prefix
                stored_id = f"airtrain-user-{uuid.uuid4()}"
                with open(self.USER_ID_PATH, 'w', encoding='utf-8') as f:
                    f.write(stored_id)

            self._curr_user_id = stored_id
        except Exception:
            self._curr_user_id = self.UNKNOWN_USER_ID
        return self._curr_user_id
|
@@ -0,0 +1,237 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from dataclasses import asdict, dataclass
|
3
|
+
from typing import Any, Dict, List, Optional, Sequence
|
4
|
+
import datetime
|
5
|
+
import socket
|
6
|
+
import os
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class BaseTelemetryEvent(ABC):
    """Abstract base for telemetry events.

    Subclasses are dataclasses that provide ``name`` (the event name) and
    whose remaining fields become the event payload.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Event name; concrete events override this with a ``name`` field."""
        pass

    @property
    def properties(self) -> Dict[str, Any]:
        """Return the event payload as a plain dict.

        The payload is ``asdict(self)`` without the ``name`` key, plus the
        common keys ``timestamp``, ``ip_address`` and ``working_directory``.
        ``ip_address`` and ``working_directory`` fall back to ``'unknown'``
        when they cannot be determined, so building a payload never fails
        because of host configuration.
        """
        data = asdict(self)
        # Remove name from properties if it exists
        data.pop('name', None)

        # Hostname resolution fails on machines whose hostname has no
        # DNS/hosts entry (gaierror is an OSError subclass).
        try:
            ip_address = socket.gethostbyname(socket.gethostname())
        except (OSError, UnicodeError):
            ip_address = 'unknown'

        # getcwd() raises when the current directory has been removed.
        try:
            working_directory = os.getcwd()
        except OSError:
            working_directory = 'unknown'

        # Add the common properties
        data.update({
            'timestamp': datetime.datetime.now().isoformat(),
            'ip_address': ip_address,
            'working_directory': working_directory,
        })
        return data
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
class AgentRunTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event emitted under the name ``agent_run``.

    When ``environment_variables`` is not supplied, ``__post_init__`` fills
    it from ``os.environ`` (see below). When ``api_key_hash`` is not
    supplied and the provider's API-key variable is set, a SHA-256 hex
    digest of that key is stored instead of the key itself.
    """

    agent_id: str
    task: str
    model_name: str
    model_provider: str
    version: str
    source: str
    environment_variables: Optional[Dict[str, str]] = None
    api_key_hash: Optional[str] = None  # Store hash of API key for debugging/tracking
    user_prompt: Optional[str] = None  # Store actual prompt text
    name: str = 'agent_run'

    def __post_init__(self):
        """Fill in ``environment_variables`` and ``api_key_hash`` when absent."""
        if self.environment_variables is None:
            # Collect relevant environment variables that might affect behavior
            relevant_markers = (
                'python', 'openai', 'anthropic', 'groq', 'airtrain',
                'api_key', 'path', 'home', 'user',
            )
            # Variables whose names look like credentials keep their name
            # (so it is visible that they are set) but never their value:
            # raw secrets must not be shipped in a telemetry payload.
            secret_markers = (
                'key', 'token', 'secret', 'password', 'passwd', 'credential',
            )
            collected = {}
            for var_name, var_value in os.environ.items():
                lowered = var_name.lower()
                if not any(marker in lowered for marker in relevant_markers):
                    continue
                if any(marker in lowered for marker in secret_markers):
                    collected[var_name] = '[REDACTED]'
                else:
                    collected[var_name] = var_value
            self.environment_variables = collected

        # If there's an API key for the provider, store a hash for support/debugging
        provider_key_map = {
            'openai': 'OPENAI_API_KEY',
            'anthropic': 'ANTHROPIC_API_KEY',
            'groq': 'GROQ_API_KEY',
            'together': 'TOGETHER_API_KEY',
            'fireworks': 'FIREWORKS_API_KEY'
        }

        key_var = provider_key_map.get(self.model_provider.lower())
        # Do not overwrite a hash the caller passed in explicitly.
        if self.api_key_hash is None and key_var and key_var in os.environ:
            import hashlib
            self.api_key_hash = hashlib.sha256(os.environ[key_var].encode()).hexdigest()
|
68
|
+
|
69
|
+
|
70
|
+
@dataclass
class AgentStepTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event emitted under the name ``agent_step``.

    A pure data holder: it defines no behaviour of its own.
    """

    agent_id: str
    step: int
    step_error: List[str]
    consecutive_failures: int
    actions: List[Dict[str, Any]]
    action_details: Optional[str] = None  # Store complete action data including inputs
    thinking: Optional[str] = None  # Store agent's reasoning
    memory_state: Optional[Dict[str, Any]] = None  # Track memory state changes
    # Event name reported for this event type.
    name: str = 'agent_step'
|
81
|
+
|
82
|
+
|
83
|
+
@dataclass
class AgentEndTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event emitted under the name ``agent_end``.

    ``cpu_usage`` and ``memory_usage`` are filled in from ``psutil`` for
    the current process when the caller did not supply them and ``psutil``
    is importable; otherwise they stay ``None``.
    """

    agent_id: str
    steps: int
    is_done: bool
    success: Optional[bool]
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    total_duration_seconds: float
    errors: Sequence[Optional[str]]
    full_conversation: Optional[List[Dict[str, Any]]] = None  # Complete conversation history
    cpu_usage: Optional[float] = None  # CPU usage during execution
    memory_usage: Optional[float] = None  # Memory usage during execution
    name: str = 'agent_end'

    def __post_init__(self):
        """Gather resource usage for any of the two usage fields left unset."""
        # Values passed by the caller take precedence over sampled ones.
        if self.cpu_usage is not None and self.memory_usage is not None:
            return

        # Try to gather resource usage. Deliberately best-effort: psutil is
        # optional and sampling may fail, and neither case should prevent
        # the event from being created.
        try:
            import psutil
            process = psutil.Process(os.getpid())
            if self.cpu_usage is None:
                # NOTE(review): psutil documents that the first
                # cpu_percent() call on a Process returns a meaningless
                # 0.0 - confirm whether this value is useful.
                self.cpu_usage = process.cpu_percent()
            if self.memory_usage is None:
                self.memory_usage = process.memory_info().rss / (1024 * 1024)  # MB
        except Exception:
            pass
|
108
|
+
|
109
|
+
|
110
|
+
@dataclass
class ModelInvocationTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event emitted under the name ``model_invocation``.

    A pure data holder: it defines no behaviour of its own.
    """

    agent_id: str
    model_name: str
    model_provider: str
    tokens: int
    prompt_tokens: int
    completion_tokens: int
    duration_seconds: float
    request_id: Optional[str] = None  # Track vendor request ID for debugging
    full_prompt: Optional[str] = None  # Full text of the prompt
    full_response: Optional[str] = None  # Full text of the response
    parameters: Optional[Dict[str, Any]] = None  # Model parameters used
    error: Optional[str] = None
    # Event name reported for this event type.
    name: str = 'model_invocation'
|
125
|
+
|
126
|
+
|
127
|
+
@dataclass
class ErrorTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event emitted under the name ``error``.

    When ``stack_trace`` is not supplied it is captured automatically: the
    traceback of the exception currently being handled if there is one,
    otherwise the call stack at the point where the event is created.
    """

    error_type: str
    error_message: str
    component: str
    agent_id: Optional[str] = None
    stack_trace: Optional[str] = None  # Full stack trace
    context: Optional[Dict[str, Any]] = None  # Extra context about the error
    name: str = 'error'

    def __post_init__(self):
        """Capture a stack trace when the caller did not provide one."""
        if self.stack_trace is None:
            import sys
            import traceback
            if sys.exc_info()[0] is not None:
                # Created inside an ``except`` block: the exception's own
                # traceback shows where the error was raised, which the
                # current call stack does not.
                self.stack_trace = traceback.format_exc()
            else:
                self.stack_trace = ''.join(traceback.format_stack())
|
142
|
+
|
143
|
+
|
144
|
+
@dataclass
class UserFeedbackTelemetryEvent(BaseTelemetryEvent):
    """New event type to capture user feedback.

    Emitted under the event name ``user_feedback``; a pure data holder.
    """
    agent_id: str
    rating: int  # User rating (1-5)
    feedback_text: Optional[str] = None  # User feedback comments
    interaction_id: Optional[str] = None  # Specific interaction ID
    # Event name reported for this event type.
    name: str = 'user_feedback'
|
152
|
+
|
153
|
+
|
154
|
+
@dataclass
class SkillInitTelemetryEvent(BaseTelemetryEvent):
    """Event type to capture skill initialization.

    Emitted under the event name ``skill_init``; a pure data holder.
    """
    skill_id: str
    skill_class: str
    # Event name reported for this event type.
    name: str = 'skill_init'
|
160
|
+
|
161
|
+
|
162
|
+
@dataclass
class SkillProcessTelemetryEvent(BaseTelemetryEvent):
    """Event type to capture skill process method calls.

    Emitted under the event name ``skill_process``; a pure data holder.
    """
    skill_id: str
    skill_class: str
    input_schema: str
    output_schema: str
    # Serialized input data
    input_data: Optional[Dict[str, Any]] = None
    duration_seconds: float = 0.0
    error: Optional[str] = None
    # Event name reported for this event type.
    name: str = 'skill_process'
|
174
|
+
|
175
|
+
|
176
|
+
@dataclass
class PackageInstallTelemetryEvent(BaseTelemetryEvent):
    """Event type to capture package installation.

    Emitted under the event name ``package_install``. ``platform`` and
    ``dependencies`` are detected automatically when not supplied.
    """
    version: str
    python_version: str
    install_method: Optional[str] = None  # pip, conda, source, etc.
    platform: Optional[str] = None  # Operating system
    dependencies: Optional[Dict[str, str]] = None  # Installed dependencies
    name: str = 'package_install'

    def __post_init__(self):
        """Fill in ``platform`` and ``dependencies`` when they are ``None``."""
        # Collect platform info if not provided
        if self.platform is None:
            import platform
            self.platform = platform.platform()

        # Collect dependency info if not provided
        if self.dependencies is None:
            self.dependencies = {}
            # The import lives inside the try so that interpreters without
            # importlib.metadata leave the mapping empty instead of raising.
            try:
                import importlib.metadata as importlib_metadata
            except ImportError:
                return

            # Try to get installed package versions for key dependencies
            for package in ("openai", "anthropic", "groq", "together"):
                try:
                    self.dependencies[package] = importlib_metadata.version(package)
                except importlib_metadata.PackageNotFoundError:
                    # Not installed: simply leave it out of the mapping.
                    continue
                except Exception:
                    # Unreadable metadata for one package should not stop
                    # the remaining packages from being reported.
                    continue
|
205
|
+
|
206
|
+
|
207
|
+
@dataclass
class PackageImportTelemetryEvent(BaseTelemetryEvent):
    """Event type to capture package import.

    Emitted under the event name ``package_import``. ``platform`` and
    ``import_context`` are detected automatically when not supplied.
    """
    version: str
    python_version: str
    import_context: Optional[str] = None  # Information about what imported the package
    platform: Optional[str] = None  # Operating system
    name: str = 'package_import'

    def __post_init__(self):
        """Fill in ``platform`` and ``import_context`` when they are ``None``."""
        # Collect platform info if not provided
        if self.platform is None:
            import platform
            self.platform = platform.platform()

        # Try to get import context from the call stack
        if self.import_context is None:
            frames = []
            try:
                import inspect
                # context=0: only module and function names are used, so
                # skip reading source lines for every frame.
                frames = inspect.stack(context=0)
                # Skip the first few frames which are inside our code
                # Look for the first frames that are not in our module
                import_frames = []
                for frame in frames[3:10]:  # Skip first 3, take up to 7 more
                    module = frame.frame.f_globals.get('__name__', '')
                    if not module.startswith('airtrain'):
                        import_frames.append(f"{module}:{frame.function}")
                if import_frames:
                    self.import_context = " -> ".join(import_frames)
            except Exception:
                # Best-effort only: the context is optional diagnostics.
                pass
            finally:
                # Drop the frame references so they do not keep the
                # callers' locals alive longer than necessary.
                del frames
|
@@ -0,0 +1,41 @@
|
|
1
|
+
"""
|
2
|
+
Tools package for AirTrain.
|
3
|
+
|
4
|
+
This package provides a registry of tools that can be used by agents.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# Import registry components
|
8
|
+
from .registry import (
|
9
|
+
BaseTool,
|
10
|
+
StatelessTool,
|
11
|
+
StatefulTool,
|
12
|
+
ToolFactory,
|
13
|
+
ToolValidationError,
|
14
|
+
register_tool,
|
15
|
+
execute_tool_call
|
16
|
+
)
|
17
|
+
|
18
|
+
# Import standard tools
|
19
|
+
from .filesystem import ListDirectoryTool, DirectoryTreeTool
|
20
|
+
from .network import ApiCallTool
|
21
|
+
from .command import ExecuteCommandTool, FindFilesTool
|
22
|
+
|
23
|
+
__all__ = [
|
24
|
+
# Base classes
|
25
|
+
"BaseTool",
|
26
|
+
"StatelessTool",
|
27
|
+
"StatefulTool",
|
28
|
+
|
29
|
+
# Registry components
|
30
|
+
"ToolFactory",
|
31
|
+
"ToolValidationError",
|
32
|
+
"register_tool",
|
33
|
+
"execute_tool_call",
|
34
|
+
|
35
|
+
# Standard tools
|
36
|
+
"ListDirectoryTool",
|
37
|
+
"DirectoryTreeTool",
|
38
|
+
"ApiCallTool",
|
39
|
+
"ExecuteCommandTool",
|
40
|
+
"FindFilesTool",
|
41
|
+
]
|
@@ -0,0 +1,211 @@
|
|
1
|
+
"""
|
2
|
+
Command execution tools for AirTrain agents.
|
3
|
+
|
4
|
+
This module provides tools for executing shell commands in a controlled environment.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
import subprocess
|
9
|
+
from typing import Dict, Any, List, Optional
|
10
|
+
|
11
|
+
from .registry import StatelessTool, register_tool
|
12
|
+
|
13
|
+
|
14
|
+
@register_tool("execute_command")
class ExecuteCommandTool(StatelessTool):
    """Tool for executing shell commands.

    The command string is passed to the system shell (``shell=True``) after
    being screened against ``disallowed_commands``.

    SECURITY NOTE: the screening is a denylist and is NOT a sandbox. A shell
    offers many ways to express the same operation, so this check only
    catches obvious mistakes. Do not expose this tool to untrusted input
    without real isolation (container, restricted user, etc.).
    """

    def __init__(self):
        self.name = "execute_command"
        self.description = "Execute a shell command and return its output"
        self.parameters = {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The command to execute"
                },
                "working_dir": {
                    "type": "string",
                    "description": "Working directory for the command"
                },
                "timeout": {
                    "type": "number",
                    "description": "Timeout in seconds"
                },
                "env_vars": {
                    "type": "object",
                    "description": "Environment variables to set for the command"
                }
            },
            "required": ["command"]
        }

        # List of disallowed commands for security
        self.disallowed_commands = [
            "rm -rf", "sudo", "su", "chown", "chmod", "mkfs",
            "dd", "shred", ">", ">>", "|", "perl -e", "python -c",
            "ruby -e", ":(){ :|:& };:", "eval", "exec", "`"
        ]

    def _find_disallowed(self, command: str) -> Optional[str]:
        """Return the first entry of ``disallowed_commands`` found in *command*.

        Entries made of plain words (``su``, ``dd``, ``rm -rf``) are matched
        as words, so ``su`` no longer rejects ``ls results`` and ``dd`` no
        longer rejects ``git add``; runs of whitespace between the words of
        an entry are tolerated (``rm   -rf``). All other entries (shell
        metacharacters such as ``>`` or ``|``) are matched as substrings.
        Returns ``None`` when nothing matches.
        """
        import re

        for pattern in self.disallowed_commands:
            words = pattern.split()
            is_wordy = bool(words) and all(
                re.fullmatch(r"[\w-]+", word) for word in words
            )
            if not is_wordy:
                if pattern in command:
                    return pattern
                continue

            regex = r"(?<!\w)" + r"\s+".join(re.escape(word) for word in words)
            if len(words) == 1:
                # Single command names must end at a word boundary too;
                # multi-word entries stay prefix matches (``rm -rfv``).
                regex += r"(?!\w)"
            if re.search(regex, command):
                return pattern
        return None

    def __call__(
        self,
        command: str,
        working_dir: Optional[str] = None,
        timeout: Optional[float] = 30.0,
        env_vars: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Execute a shell command and return its output.

        Args:
            command: Shell command line to run.
            working_dir: Directory to run the command in (default: current).
            timeout: Seconds to wait before giving up; ``None`` waits forever.
            env_vars: Extra environment variables layered over the current
                environment. Keys and values are converted to strings.

        Returns:
            On completion: ``success`` (exit code was 0), ``return_code``,
            ``stdout`` and ``stderr``. If the command is rejected, times out
            or cannot be started: ``success`` False and an ``error`` message.
            This method does not raise.
        """
        try:
            # Security check
            blocked = self._find_disallowed(command)
            if blocked is not None:
                return {
                    "success": False,
                    "error": f"Command contains disallowed pattern: {blocked}"
                }

            # Prepare environment
            env = os.environ.copy()
            if env_vars:
                # subprocess requires str values; tool callers may pass
                # numbers or booleans decoded from JSON.
                env.update({str(key): str(value) for key, value in env_vars.items()})

            # Execute command
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                cwd=working_dir,
                timeout=timeout,
                env=env
            )

            return {
                "success": result.returncode == 0,
                "return_code": result.returncode,
                "stdout": result.stdout,
                "stderr": result.stderr
            }
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "error": f"Command timed out after {timeout} seconds"
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Error executing command: {str(e)}"
            }

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters
            }
        }
|
111
|
+
|
112
|
+
|
113
|
+
@register_tool("find_files")
class FindFilesTool(StatelessTool):
    """Tool for finding files matching patterns."""

    def __init__(self):
        self.name = "find_files"
        self.description = "Find files matching the specified pattern"
        self.parameters = {
            "type": "object",
            "properties": {
                "directory": {
                    "type": "string",
                    "description": "Directory to search in"
                },
                "pattern": {
                    "type": "string",
                    "description": "Glob pattern to match (e.g., *.txt, **/*.py)"
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results to return"
                },
                "show_hidden": {
                    "type": "boolean",
                    "description": "Whether to include hidden files (starting with .)"
                }
            },
            "required": ["directory", "pattern"]
        }

    def __call__(
        self,
        directory: str,
        pattern: str,
        max_results: int = 100,
        show_hidden: bool = False
    ) -> Dict[str, Any]:
        """Find files matching the specified pattern.

        Args:
            directory: Directory to search in; ``~`` is expanded.
            pattern: Glob pattern relative to *directory*; ``**`` recurses.
            max_results: Maximum number of entries to return. Zero or a
                negative value returns no entries.
            show_hidden: Include entries whose name starts with ``.``. On
                Python 3.11+ wildcards then also match hidden names; on
                older versions only patterns that spell out the leading dot
                can match them.

        Returns:
            On success: ``success`` True, ``directory``, ``pattern``,
            ``files`` (dicts with ``path``, ``name``, ``type`` and ``size``),
            ``count`` and ``truncated`` (True only when at least one more
            match existed beyond ``max_results``). On failure: ``success``
            False and an ``error`` message. This method does not raise.
        """
        try:
            import glob
            import sys

            directory = os.path.expanduser(directory)
            if not os.path.exists(directory):
                return {
                    "success": False,
                    "error": f"Directory '{directory}' does not exist"
                }

            if not os.path.isdir(directory):
                return {
                    "success": False,
                    "error": f"Path '{directory}' is not a directory"
                }

            # Construct search path
            # NOTE(review): an absolute or ``..`` pattern escapes
            # ``directory`` - confirm whether callers rely on that.
            search_path = os.path.join(directory, pattern)

            glob_options = {"recursive": True}
            if show_hidden and sys.version_info >= (3, 11):
                # Without this, ``*`` and ``**`` never match dot-names.
                glob_options["include_hidden"] = True

            # Find matching files. iglob is lazy, so the search stops as
            # soon as the limit is exceeded instead of listing everything.
            files = []
            truncated = False
            for file_path in glob.iglob(search_path, **glob_options):
                entry_name = os.path.basename(file_path)
                if not show_hidden and entry_name.startswith('.'):
                    continue

                if len(files) >= max_results:
                    # One more match exists than we may return.
                    truncated = True
                    break

                files.append({
                    "path": file_path,
                    "name": entry_name,
                    "type": "dir" if os.path.isdir(file_path) else "file",
                    "size": os.path.getsize(file_path) if os.path.isfile(file_path) else None
                })

            return {
                "success": True,
                "directory": directory,
                "pattern": pattern,
                "files": files,
                "count": len(files),
                "truncated": truncated
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Error finding files: {str(e)}"
            }

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters
            }
        }
|