airtrain 0.1.58__py3-none-any.whl → 0.1.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +72 -44
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +59 -13
- airtrain/integrations/__init__.py +21 -2
- airtrain/integrations/combined/list_models_factory.py +80 -41
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +114 -0
- airtrain/tools/__init__.py +9 -5
- airtrain/tools/command.py +248 -61
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/METADATA +1 -1
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/RECORD +27 -15
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/WHEEL +1 -1
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,114 @@
|
|
1
|
+
"""
|
2
|
+
Schemas for Exa Search API.
|
3
|
+
|
4
|
+
This module defines the input and output schemas for the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import Dict, List, Optional, Any, Union
|
8
|
+
from pydantic import BaseModel, Field, HttpUrl
|
9
|
+
|
10
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
11
|
+
|
12
|
+
|
13
|
+
class ExaContentConfig(BaseModel):
    """Per-result content options for an Exa search (the ``contents`` object).

    Only ``text`` is switched on by default. Every other option defaults to
    ``None``, meaning "left unset".
    """

    text: bool = Field(True, description="Whether to return text content.")
    extractedText: Optional[bool] = Field(
        None, description="Whether to return extracted text content."
    )
    embedded: Optional[bool] = Field(
        None, description="Whether to return embedded content."
    )
    links: Optional[bool] = Field(
        None, description="Whether to return links from the content."
    )
    screenshot: Optional[bool] = Field(
        None, description="Whether to return screenshots of the content."
    )
    highlighted: Optional[bool] = Field(
        None, description="Whether to return highlighted text."
    )
|
32
|
+
|
33
|
+
|
34
|
+
class ExaSearchInputSchema(InputSchema):
    """Parameters for a single Exa search request.

    ``query`` is the only required field. ``contents`` defaults to a fresh
    ``ExaContentConfig`` (text enabled); every other option defaults to
    ``None``. Field names are camelCase so that a dump of this model can be
    used directly as the request body.
    """

    query: str = Field(..., description="The search query to execute.")
    numResults: Optional[int] = Field(
        None, description="Number of results to return."
    )
    # default_factory gives every instance its own config object.
    contents: Optional[ExaContentConfig] = Field(
        default_factory=ExaContentConfig,
        description="Configuration for the content to be returned.",
    )
    highlights: Optional[dict] = Field(
        None, description="Highlighting configuration for search results."
    )
    useAutoprompt: Optional[bool] = Field(
        None, description="Whether to use autoprompt for the search."
    )
    # Named ``type`` to match the request key; it shadows the builtin only
    # inside this class namespace.
    type: Optional[str] = Field(None, description="Type of search to perform.")
    includeDomains: Optional[List[str]] = Field(
        None, description="List of domains to include in the search."
    )
    excludeDomains: Optional[List[str]] = Field(
        None, description="List of domains to exclude from the search."
    )
|
58
|
+
|
59
|
+
|
60
|
+
class ExaModerationConfig(BaseModel):
    """Moderation flags that may be attached to a search result.

    Every flag is optional and is ``None`` when the API does not send it.
    """

    # Presumably Llama Guard safety-category flags (S1, S3, S4, S12) — confirm
    # against the Exa API reference.
    llamaguardS1: Optional[bool] = Field(default=None)
    llamaguardS3: Optional[bool] = Field(default=None)
    llamaguardS4: Optional[bool] = Field(default=None)
    llamaguardS12: Optional[bool] = Field(default=None)
    # Domain-level blacklist indicators.
    domainBlacklisted: Optional[bool] = Field(default=None)
    domainBlacklistedMedia: Optional[bool] = Field(default=None)
|
69
|
+
|
70
|
+
|
71
|
+
class ExaHighlight(BaseModel):
    """A highlighted text snippet paired with its score; both are required."""

    text: str = Field(...)
    score: float = Field(...)
|
76
|
+
|
77
|
+
|
78
|
+
class ExaSearchResult(BaseModel):
    """Individual search result from Exa.

    Fields mirror the keys of one entry in the API's ``results`` array. Keys
    that are not declared here are ignored (pydantic's default ``extra``
    behaviour for ``BaseModel``).
    """

    id: str
    url: str
    title: Optional[str] = None
    text: Optional[str] = None
    extractedText: Optional[str] = None
    embedded: Optional[Dict[str, Any]] = None
    # Optional so that results arriving without a score still validate instead
    # of failing the whole response.
    score: Optional[float] = None
    published: Optional[str] = None
    author: Optional[str] = None
    # The Exa API returns highlight snippets as plain strings, with their
    # scores in ``highlightScores``. Structured ``ExaHighlight`` objects are
    # still accepted for backward compatibility.
    highlights: Optional[List[Union[str, ExaHighlight]]] = None
    robotsAllowed: Optional[bool] = None
    moderationConfig: Optional[ExaModerationConfig] = None
    urls: Optional[List[str]] = None
    # Scores aligned index-by-index with ``highlights`` when the API sends them.
    highlightScores: Optional[List[float]] = None
    # Publication date under the key the API actually uses (``publishedDate``).
    # ``published`` above is kept for backward compatibility.
    publishedDate: Optional[str] = None
|
94
|
+
|
95
|
+
|
96
|
+
class ExaCostDetails(BaseModel):
    """Cost details for an Exa search request.

    Only ``total`` is required. The per-category breakdowns default to empty
    dicts, so a response that omits them does not fail validation of an
    otherwise successful search.
    """

    # Total cost of the request (reported by the API under ``costDollars``).
    total: float
    # Cost breakdown for the search step, keyed as returned by the API.
    search: Dict[str, float] = Field(default_factory=dict)
    # Cost breakdown for content retrieval, keyed as returned by the API.
    contents: Dict[str, float] = Field(default_factory=dict)
|
102
|
+
|
103
|
+
|
104
|
+
class ExaSearchOutputSchema(OutputSchema):
    """Outcome of an Exa search.

    Holds the parsed ``results`` list, the query that produced them, and
    optional metadata: the autoprompt string and the cost breakdown.
    """

    results: List[ExaSearchResult] = Field(
        ..., description="List of search results."
    )
    query: str = Field(..., description="The original search query.")
    autopromptString: Optional[str] = Field(
        None, description="Autoprompt string used for the search if enabled."
    )
    costDollars: Optional[ExaCostDetails] = Field(
        None, description="Cost details for the search request."
    )
|
@@ -0,0 +1,114 @@
|
|
1
|
+
"""
|
2
|
+
Skills for Exa Search API.
|
3
|
+
|
4
|
+
This module provides skills for using the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import json
|
8
|
+
import logging
|
9
|
+
import httpx
|
10
|
+
from typing import Optional, Dict, Any, List, cast
|
11
|
+
|
12
|
+
from pydantic import ValidationError
|
13
|
+
|
14
|
+
from airtrain.core.skills import Skill, ProcessingError
|
15
|
+
from .credentials import ExaCredentials
|
16
|
+
from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
|
17
|
+
|
18
|
+
|
19
|
+
# Module-level logger; ExaSearchSkill uses it to report request failures.
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API.

    The validated input schema is posted as JSON to ``EXA_API_ENDPOINT``, and
    the JSON response is converted into an ``ExaSearchOutputSchema``.
    Transient failures (timeouts, transport errors, HTTP 429 and 5xx) are
    retried up to ``max_retries`` times with exponential backoff. Every failure
    reaches the caller as ``ProcessingError``.
    """

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    # Status codes treated as transient: rate limiting and server-side errors.
    _RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})
    # Delay in seconds before the first retry; doubled for each later retry.
    _RETRY_BASE_DELAY = 1.0

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for each API request attempt, in seconds
            max_retries: Maximum number of retries for transient failures
                (timeouts, transport errors, HTTP 429/5xx); 0 disables retries
            **kwargs: Forwarded unchanged to the ``Skill`` base class
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: If the request times out, the API answers with a
                non-200 status, the response does not match the output schema,
                or any other error occurs while querying the API.
        """
        try:
            # Unset (None) options are dropped so only chosen parameters are sent.
            payload = input_data.model_dump(exclude_none=True)

            headers = {
                "content-type": "application/json",
                "Authorization": f"Bearer {self.credentials.api_key.get_secret_value()}",
            }

            response = await self._post_with_retries(payload, headers)

            if response.status_code != 200:
                error_message = f"Exa API returned status code {response.status_code}: {response.text}"
                logger.error(error_message)
                raise ProcessingError(error_message)

            result_data = response.json()
            return ExaSearchOutputSchema(
                results=result_data.get("results", []),
                query=input_data.query,
                autopromptString=result_data.get("autopromptString"),
                costDollars=result_data.get("costDollars"),
            )

        except ProcessingError:
            # Already logged with a specific message above. Re-raise it as-is
            # instead of letting the catch-all below re-wrap it as "unexpected".
            raise

        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except ValidationError as e:
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

    async def _post_with_retries(
        self, payload: Dict[str, Any], headers: Dict[str, str]
    ) -> "httpx.Response":
        """POST ``payload`` to the Exa endpoint, retrying transient failures.

        Returns the first response whose status is not retryable, or the last
        response once retries are exhausted. If the final attempt fails at the
        transport level (``httpx.TransportError``, which includes
        ``httpx.TimeoutException``), that exception is re-raised unchanged.
        """
        import asyncio

        # Tolerate None/negative values: both mean "no retries".
        max_retries = max(0, self.max_retries or 0)
        attempt = 0
        async with httpx.AsyncClient() as client:
            while True:
                attempt += 1
                can_retry = attempt <= max_retries
                try:
                    response = await client.post(
                        self.EXA_API_ENDPOINT,
                        headers=headers,
                        json=payload,
                        timeout=self.timeout,
                    )
                except httpx.TransportError as e:
                    if not can_retry:
                        raise
                    reason = f"{type(e).__name__}: {e}"
                else:
                    if (
                        not can_retry
                        or response.status_code not in self._RETRYABLE_STATUS_CODES
                    ):
                        return response
                    reason = f"status code {response.status_code}"

                delay = self._RETRY_BASE_DELAY * 2 ** (attempt - 1)
                logger.warning(
                    "Exa API request failed (%s); retry %d/%d in %.1fs",
                    reason,
                    attempt,
                    max_retries,
                    delay,
                )
                await asyncio.sleep(delay)
|
airtrain/tools/__init__.py
CHANGED
@@ -12,30 +12,34 @@ from .registry import (
|
|
12
12
|
ToolFactory,
|
13
13
|
ToolValidationError,
|
14
14
|
register_tool,
|
15
|
-
execute_tool_call
|
15
|
+
execute_tool_call,
|
16
16
|
)
|
17
17
|
|
18
18
|
# Import standard tools
|
19
19
|
from .filesystem import ListDirectoryTool, DirectoryTreeTool
|
20
20
|
from .network import ApiCallTool
|
21
|
-
from .command import ExecuteCommandTool, FindFilesTool
|
21
|
+
from .command import ExecuteCommandTool, FindFilesTool, TerminalNavigationTool
|
22
|
+
from .search import SearchTermTool, WebSearchTool
|
23
|
+
from .testing import RunPytestTool
|
22
24
|
|
23
25
|
# Public API of the airtrain.tools package: the names exported by
# ``from airtrain.tools import *``.
__all__ = [
    # Base classes
    "BaseTool",
    "StatelessTool",
    "StatefulTool",
    # Registry components
    "ToolFactory",
    "ToolValidationError",
    "register_tool",
    "execute_tool_call",
    # Standard tools
    "ListDirectoryTool",
    "DirectoryTreeTool",
    "ApiCallTool",
    "ExecuteCommandTool",
    "FindFilesTool",
    "TerminalNavigationTool",
    "SearchTermTool",
    "WebSearchTool",
    "RunPytestTool",
]
|