airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/tools/search.py
ADDED
@@ -0,0 +1,450 @@
|
|
1
|
+
"""
|
2
|
+
Search tools for AirTrain agents.
|
3
|
+
|
4
|
+
This module provides tools for searching for content within files and directories.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
import re
|
9
|
+
import subprocess
|
10
|
+
from typing import Dict, Any, List, Optional, cast
|
11
|
+
|
12
|
+
from .registry import StatelessTool, register_tool
|
13
|
+
from airtrain.integrations.search.exa import (
|
14
|
+
ExaCredentials,
|
15
|
+
ExaSearchSkill,
|
16
|
+
ExaSearchInputSchema,
|
17
|
+
ExaSearchOutputSchema,
|
18
|
+
ExaContentConfig,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
@register_tool("search_term")
class SearchTermTool(StatelessTool):
    """Tool for searching for specific terms within files.

    The search shells out to ``grep`` when it can be executed and falls back
    to a pure-Python directory walk otherwise. Both back ends return a
    dictionary of the same shape::

        {
            "success": True,
            "term": ..., "directory": ..., "file_pattern": ...,
            "matches": [
                {"file": ..., "line": ..., "content": ...,
                 "context": {"start_line": ..., "end_line": ..., "lines": [...]}},
            ],
            "match_count": ...,
            "truncated": ...,
        }

    Failures are reported as ``{"success": False, "error": "..."}``.
    """

    def __init__(self):
        self.name = "search_term"
        self.description = "Search for a specific term or pattern within files"
        # JSON-schema description of the arguments accepted by __call__.
        self.parameters = {
            "type": "object",
            "properties": {
                "term": {
                    "type": "string",
                    "description": "The term or pattern to search for",
                },
                "directory": {
                    "type": "string",
                    "description": "Directory to search in (defaults to current directory)",
                },
                "file_pattern": {
                    "type": "string",
                    "description": "Pattern to filter files (e.g., *.py, *.txt)",
                },
                "case_sensitive": {
                    "type": "boolean",
                    "description": "Whether the search should be case-sensitive",
                },
                "regex": {
                    "type": "boolean",
                    "description": "Whether to treat the term as a regular expression",
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results to return",
                },
                "max_context_lines": {
                    "type": "integer",
                    "description": "Number of context lines to show before and after matches",
                },
            },
            "required": ["term"],
        }

    def __call__(
        self,
        term: str,
        directory: str = ".",
        file_pattern: str = "*",
        case_sensitive: bool = False,
        regex: bool = False,
        max_results: int = 100,
        max_context_lines: int = 2,
    ) -> Dict[str, Any]:
        """Search for a specific term within files.

        Args:
            term: Text (or regular expression when ``regex`` is true) to find.
            directory: Root directory of the recursive search; ``~`` is expanded.
            file_pattern: Glob applied to file names (``"*"`` means every file).
            case_sensitive: Match case exactly when true.
            regex: Treat ``term`` as a regular expression instead of a literal.
            max_results: Maximum number of matches to return; a value of zero
                or less means "no limit".
            max_context_lines: Number of lines to report before and after
                each match; negative values are treated as zero.

        Returns:
            The result dictionary described on the class. This method never
            raises: any unexpected exception is converted into an error result.
        """
        try:
            # Try to use grep if available (more efficient than pure Python)
            try:
                return self._search_with_grep(
                    term,
                    directory,
                    file_pattern,
                    case_sensitive,
                    regex,
                    max_results,
                    max_context_lines,
                )
            except (subprocess.SubprocessError, OSError):
                # grep is missing, could not be executed, or reported an
                # error: fall back to the Python implementation.
                return self._search_with_python(
                    term,
                    directory,
                    file_pattern,
                    case_sensitive,
                    regex,
                    max_results,
                    max_context_lines,
                )
        except Exception as e:
            return {"success": False, "error": f"Error searching for term: {str(e)}"}

    def _search_with_grep(
        self,
        term: str,
        directory: str,
        file_pattern: str,
        case_sensitive: bool,
        regex: bool,
        max_results: int,
        max_context_lines: int,
    ) -> Dict[str, Any]:
        """Use grep to search for terms (more efficient).

        Raises:
            subprocess.SubprocessError: grep exited with a status other than
                0 (matches found) or 1 (no matches).
            OSError: the grep executable could not be started.
        """
        directory = os.path.expanduser(directory)
        context = max(0, max_context_lines)

        # -r recurse, -n print line numbers, -H always print the file name
        # (even when `directory` names a single file), --null terminate the
        # file name with NUL so the "<line>:" / "<line>-" marker that follows
        # can be parsed unambiguously whatever characters the path contains.
        cmd = ["grep", "-r", "-n", "-H", "--null"]

        if not case_sensitive:
            cmd.append("-i")

        # -F: literal string. -E: extended regular expressions, the grep
        # dialect closest to the Python `re` syntax used by the fallback.
        cmd.append("-E" if regex else "-F")

        if context > 0:
            cmd.append(f"-C{context}")

        if max_results > 0:
            # --max-count limits matches *per file*; it only bounds the amount
            # of output here. The overall limit is enforced after parsing.
            cmd.append(f"--max-count={max_results}")

        if file_pattern != "*":
            cmd.append(f"--include={file_pattern}")

        # "-e" keeps a term that starts with "-" from being read as an option;
        # "--" does the same for the directory operand.
        cmd.extend(["-e", term, "--", directory])

        # errors="replace" keeps undecodable bytes in matched files from
        # aborting the whole search with a UnicodeDecodeError.
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")

        if result.returncode not in (0, 1):  # 1 means no matches
            raise subprocess.SubprocessError(f"Grep error: {result.stderr}")

        # After the "<path>\0" prefix each line is "<n>:<text>" for a match
        # and "<n>-<text>" for a context line.
        line_format = re.compile(r"(\d+)([:-])(.*)", re.DOTALL)

        matches = []
        seen_lines = {}  # path -> {line number: text} for every reported line
        # Split on "\n" only: str.splitlines() would also break lines at form
        # feeds and other separators that may legitimately occur in content.
        for raw in result.stdout.split("\n"):
            path, nul, rest = raw.partition("\0")
            parsed = line_format.match(rest) if nul else None
            if parsed is None:
                # Group separator ("--"), blank line or a notice such as
                # "Binary file ... matches".
                continue

            line_number = int(parsed.group(1))
            text = parsed.group(3).rstrip("\r")
            seen_lines.setdefault(path, {})[line_number] = text
            if parsed.group(2) == ":":
                matches.append({"file": path, "line": line_number, "content": text})

        # Same rule as the Python fallback: reaching the limit counts as
        # truncated, and no more than max_results matches are returned.
        truncated = 0 < max_results <= len(matches)
        if truncated:
            del matches[max_results:]

        # Attach the surrounding lines in the format of the Python fallback.
        for match in matches:
            file_lines = seen_lines[match["file"]]
            start = end = match["line"]
            while start - 1 >= match["line"] - context and (start - 1) in file_lines:
                start -= 1
            while end + 1 <= match["line"] + context and (end + 1) in file_lines:
                end += 1
            match["context"] = {
                "start_line": start,
                "end_line": end,
                "lines": [file_lines[n] for n in range(start, end + 1)],
            }

        return {
            "success": True,
            "term": term,
            "directory": directory,
            "file_pattern": file_pattern,
            "matches": matches,
            "match_count": len(matches),
            "truncated": truncated,
        }

    def _search_with_python(
        self,
        term: str,
        directory: str,
        file_pattern: str,
        case_sensitive: bool,
        regex: bool,
        max_results: int,
        max_context_lines: int,
    ) -> Dict[str, Any]:
        """Use Python to search for terms (fallback method).

        Raises:
            re.error: ``regex`` is true and ``term`` is not a valid pattern.
        """
        # Expand directory path
        directory = os.path.expanduser(directory)

        if not os.path.exists(directory):
            return {
                "success": False,
                "error": f"Directory '{directory}' does not exist",
            }

        if not os.path.isdir(directory):
            return {"success": False, "error": f"Path '{directory}' is not a directory"}

        flags = 0 if case_sensitive else re.IGNORECASE
        pattern = re.compile(term if regex else re.escape(term), flags)
        context = max(0, max_context_lines)

        matches = []
        truncated = False
        for match in self._iter_python_matches(directory, file_pattern, pattern, context):
            matches.append(match)
            # A limit of zero or less means "unlimited", as in the grep path.
            if 0 < max_results <= len(matches):
                truncated = True
                break

        return {
            "success": True,
            "term": term,
            "directory": directory,
            "file_pattern": file_pattern,
            "matches": matches,
            "match_count": len(matches),
            "truncated": truncated,
        }

    def _iter_python_matches(self, directory, file_pattern, pattern, context):
        """Yield one match dictionary per line under ``directory`` matching ``pattern``."""
        import fnmatch

        for root, _, files in os.walk(directory):
            for filename in fnmatch.filter(files, file_pattern):
                file_path = os.path.join(root, filename)

                try:
                    with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
                        # Drop the newline before matching so that patterns
                        # see the same text grep would.
                        lines = [text.rstrip("\n") for text in handle]
                except (OSError, ValueError):
                    # Skip files that can't be read
                    continue

                for index, text in enumerate(lines):
                    if not pattern.search(text):
                        continue

                    first = max(0, index - context)
                    stop = min(len(lines), index + context + 1)
                    yield {
                        "file": file_path,
                        "line": index + 1,
                        "content": text,
                        "context": {
                            "start_line": first + 1,
                            "end_line": stop,
                            "lines": lines[first:stop],
                        },
                    }

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
|
292
|
+
|
293
|
+
|
294
|
+
@register_tool("web_search")
class WebSearchTool(StatelessTool):
    """Tool for searching the web using the Exa API.

    The API key is read from the ``EXA_API_KEY`` environment variable when the
    tool is constructed. Without a key the tool can still be created, but
    every call returns an error result instead of performing a search.
    """

    def __init__(self):
        self.name = "web_search"
        self.description = "Search the web for information using the Exa search API"
        # JSON-schema description of the arguments accepted by __call__.
        self.parameters = {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query to execute",
                },
                "num_results": {
                    "type": "integer",
                    "description": "Number of results to return (default: 5)",
                },
                "include_domains": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of domains to include in the search",
                },
                "exclude_domains": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of domains to exclude from the search",
                },
                "use_autoprompt": {
                    "type": "boolean",
                    "description": "Whether to use Exa's autoprompt feature for better results",
                },
            },
            "required": ["query"],
        }

        # Exa API key from environment variable
        api_key = os.environ.get("EXA_API_KEY", "")
        if api_key:
            self.credentials = ExaCredentials(api_key=api_key)
            self.skill = ExaSearchSkill(credentials=self.credentials)
        else:
            self.credentials = None
            self.skill = None

    def _not_configured(self) -> Dict[str, Any]:
        """Return the error result used when no Exa API key is available."""
        return {
            "success": False,
            "error": "Exa API key not configured. Set the EXA_API_KEY environment variable.",
        }

    async def _async_search(
        self,
        query: str,
        num_results: int = 5,
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        use_autoprompt: bool = False,
    ) -> Dict[str, Any]:
        """Execute the search asynchronously.

        Takes the same arguments as ``__call__`` and returns the same result
        dictionary. Exceptions raised while building the request or awaiting
        the skill are converted into an error result.
        """
        if not self.credentials or not self.skill:
            return self._not_configured()

        try:
            # Create input for the search
            search_input = ExaSearchInputSchema(
                query=query,
                numResults=num_results,
                includeDomains=include_domains,
                excludeDomains=exclude_domains,
                useAutoprompt=use_autoprompt,
                contents=ExaContentConfig(text=True),
            )

            # Execute search
            result = await self.skill.process(search_input)

            # Process results into a simplified format; page text is cut to
            # the first 1000 characters to keep the payload small.
            search_results = [
                {
                    "title": item.title or "No title",
                    "url": item.url,
                    "content": (
                        item.text[:1000] if item.text else "No content available"
                    ),
                    "score": item.score,
                    "published": item.published,
                }
                for item in result.results
            ]

            return {
                "success": True,
                "query": query,
                "results": search_results,
                "result_count": len(search_results),
                "autoprompt": result.autopromptString,
            }

        except Exception as e:
            return {"success": False, "error": f"Error performing web search: {str(e)}"}

    def __call__(
        self,
        query: str,
        num_results: int = 5,
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        use_autoprompt: bool = False,
    ) -> Dict[str, Any]:
        """
        Search the web for information.

        Args:
            query: The search query to execute
            num_results: Number of results to return
            include_domains: List of domains to include in search results
            exclude_domains: List of domains to exclude from search results
            use_autoprompt: Whether to use Exa's autoprompt feature

        Returns:
            Dictionary containing search results or error information
        """
        import asyncio
        import concurrent.futures

        if not self.credentials or not self.skill:
            return self._not_configured()

        def run_search() -> Dict[str, Any]:
            # asyncio.run creates a fresh event loop and always closes it,
            # including when the coroutine raises.
            return asyncio.run(
                self._async_search(
                    query=query,
                    num_results=num_results,
                    include_domains=include_domains,
                    exclude_domains=exclude_domains,
                    use_autoprompt=use_autoprompt,
                )
            )

        try:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No event loop is running in this thread: run the search here.
                return run_search()

            # This thread is already running an event loop, which cannot be
            # re-entered. Run the search on its own loop in a worker thread
            # and block until it finishes (this call is synchronous anyway).
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                return pool.submit(run_search).result()
        except Exception as e:
            return {"success": False, "error": f"Error executing search: {str(e)}"}

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
|
airtrain/tools/testing.py
ADDED
@@ -0,0 +1,135 @@
|
|
1
|
+
"""
|
2
|
+
Testing tools for AirTrain agents.
|
3
|
+
|
4
|
+
This module provides tools for running tests and test frameworks.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
import subprocess
|
9
|
+
from typing import Dict, Any, Optional, List
|
10
|
+
|
11
|
+
from .registry import StatelessTool, register_tool
|
12
|
+
|
13
|
+
|
14
|
+
@register_tool("run_pytest")
class RunPytestTool(StatelessTool):
    """Tool for running Python pytest on test files.

    The tool starts the ``pytest`` executable as a subprocess and reports its
    exit status, the command line that was used and, when output is captured,
    stdout, stderr and the pytest summary line(s).
    """

    def __init__(self):
        self.name = "run_pytest"
        self.description = "Run pytest on a specific test file or directory"
        # JSON-schema description of the arguments accepted by __call__.
        self.parameters = {
            "type": "object",
            "properties": {
                "test_path": {
                    "type": "string",
                    "description": "Path to the test file or directory to run",
                },
                "args": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Additional pytest arguments",
                },
                "working_dir": {
                    "type": "string",
                    "description": "Working directory to run pytest from",
                },
                "verbose": {
                    "type": "boolean",
                    "description": "Run tests in verbose mode",
                },
                "capture_output": {
                    "type": "boolean",
                    "description": "Capture stdout/stderr or show directly",
                },
                "timeout": {
                    "type": "number",
                    "description": "Timeout in seconds for the test run",
                },
            },
            "required": ["test_path"],
        }

    @staticmethod
    def _extract_summary(stdout: str) -> List[str]:
        """Return the ``=== ... ===`` banner lines of ``stdout`` that report results."""
        keywords = ("failed", "passed", "skipped", "error", "no tests ran")
        summary = []
        for line in stdout.splitlines():
            stripped = line.strip()
            # Result banners look like "==== 1 failed, 2 passed in 0.1s ====".
            if not (stripped.startswith("=") and stripped.endswith("=")):
                continue
            if any(keyword in stripped for keyword in keywords):
                summary.append(stripped)
        return summary

    def __call__(
        self,
        test_path: str,
        args: Optional[List[str]] = None,
        working_dir: Optional[str] = None,
        verbose: bool = False,
        capture_output: bool = True,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """Run pytest on a specific test file or directory.

        Args:
            test_path: Test file, directory or pytest node id
                (``path::test_name``). ``~`` is expanded; a relative path is
                interpreted relative to ``working_dir`` when one is given.
            args: Additional command-line arguments passed to pytest.
            working_dir: Directory to run pytest from; ``~`` is expanded.
                Defaults to the current working directory.
            verbose: Add ``-v`` to the pytest command line.
            capture_output: Capture stdout/stderr and return them in the
                result instead of letting pytest write to the terminal.
            timeout: Maximum run time in seconds.

        Returns:
            A dictionary with ``success`` (true when pytest exits with status
            0), ``return_code``, ``test_path`` and ``command``; plus
            ``stdout``, ``stderr`` and, if summary lines were found,
            ``summary`` when output is captured. Failures to start or finish
            the run are reported as ``{"success": False, "error": "..."}``.
        """
        try:
            # Expand user path if present
            test_path = os.path.expanduser(test_path)
            working_dir = os.path.expanduser(working_dir) if working_dir else None

            if working_dir is not None and not os.path.isdir(working_dir):
                return {
                    "success": False,
                    "error": f"Working directory '{working_dir}' does not exist",
                }

            # Validate the path the way pytest will see it: without the
            # "::node" suffix of a node id and relative to the directory
            # pytest is started in.
            file_part = test_path.split("::", 1)[0]
            if working_dir is not None:
                file_part = os.path.join(working_dir, file_part)
            if not os.path.exists(file_part):
                return {
                    "success": False,
                    "error": f"Test path '{test_path}' does not exist",
                }

            # Prepare pytest command
            cmd = ["pytest", test_path]

            # Add verbosity flag if requested
            if verbose:
                cmd.append("-v")

            # Add any additional arguments
            if args:
                cmd.extend(args)

            # Run pytest
            process = subprocess.run(
                cmd,
                cwd=working_dir,
                capture_output=capture_output,
                text=True,
                timeout=timeout,
            )

            result = {
                "success": process.returncode == 0,
                "return_code": process.returncode,
                "test_path": test_path,
                "command": " ".join(cmd),
            }

            # stdout/stderr are only available when they were captured.
            if capture_output:
                result["stdout"] = process.stdout
                result["stderr"] = process.stderr

                # Parse test summary from output
                summary_lines = self._extract_summary(process.stdout)
                if summary_lines:
                    result["summary"] = summary_lines

            return result

        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "error": f"Pytest run timed out after {timeout} seconds",
            }
        except Exception as e:
            return {"success": False, "error": f"Error running pytest: {str(e)}"}

    def to_dict(self):
        """Convert tool to dictionary format for LLM function calling."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
|