airtrain 0.1.58__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,114 @@
1
+ """
2
+ Schemas for Exa Search API.
3
+
4
+ This module defines the input and output schemas for the Exa search API.
5
+ """
6
+
7
+ from typing import Dict, List, Optional, Any, Union
8
+ from pydantic import BaseModel, Field, HttpUrl
9
+
10
+ from airtrain.core.schemas import InputSchema, OutputSchema
11
+
12
+
13
class ExaContentConfig(BaseModel):
    """Configuration for the content to be returned by Exa search."""

    # Only ``text`` is enabled by default. Every other flag stays None until the
    # caller sets it, so a dump with ``exclude_none=True`` leaves it out.
    text: bool = Field(True, description="Whether to return text content.")
    extractedText: Union[bool, None] = Field(
        None, description="Whether to return extracted text content."
    )
    embedded: Union[bool, None] = Field(
        None, description="Whether to return embedded content."
    )
    links: Union[bool, None] = Field(
        None, description="Whether to return links from the content."
    )
    screenshot: Union[bool, None] = Field(
        None, description="Whether to return screenshots of the content."
    )
    highlighted: Union[bool, None] = Field(
        None, description="Whether to return highlighted text."
    )
32
+
33
+
34
class ExaSearchInputSchema(InputSchema):
    """Input schema for Exa search API."""

    # Attribute names are camelCase on purpose: ExaSearchSkill dumps this model
    # straight into the JSON request body, so each name doubles as a request key.
    query: str = Field(..., description="The search query to execute.")
    numResults: Union[int, None] = Field(
        None, description="Number of results to return."
    )
    # Built fresh per instance so instances never share one config object.
    contents: Union[ExaContentConfig, None] = Field(
        default_factory=ExaContentConfig,
        description="Configuration for the content to be returned.",
    )
    highlights: Union[dict, None] = Field(
        None, description="Highlighting configuration for search results."
    )
    useAutoprompt: Union[bool, None] = Field(
        None, description="Whether to use autoprompt for the search."
    )
    # Shadows the builtin ``type`` inside the class body, but the name must
    # match the request key sent to the API.
    type: Union[str, None] = Field(None, description="Type of search to perform.")
    includeDomains: Union[List[str], None] = Field(
        None, description="List of domains to include in the search."
    )
    excludeDomains: Union[List[str], None] = Field(
        None, description="List of domains to exclude from the search."
    )
58
+
59
+
60
class ExaModerationConfig(BaseModel):
    """Moderation configuration returned in search results."""

    # Each flag is optional and stays None when the response does not carry it.
    llamaguardS1: Union[bool, None] = Field(default=None)
    llamaguardS3: Union[bool, None] = Field(default=None)
    llamaguardS4: Union[bool, None] = Field(default=None)
    llamaguardS12: Union[bool, None] = Field(default=None)
    domainBlacklisted: Union[bool, None] = Field(default=None)
    domainBlacklistedMedia: Union[bool, None] = Field(default=None)
69
+
70
+
71
class ExaHighlight(BaseModel):
    """Highlight information for a search result."""

    # Both attributes are required (``...``); a highlight object missing
    # either one fails validation.
    text: str = Field(...)
    score: float = Field(...)
76
+
77
+
78
class ExaSearchResult(BaseModel):
    """Individual search result from Exa."""

    # NOTE(review): the Exa /search API reference shows each result with
    # ``publishedDate``, ``highlights`` as a list of plain strings, and a
    # parallel ``highlightScores`` list. The fields below accept that shape
    # while staying compatible with the original ``published`` / ExaHighlight
    # layout. Any failure here aborts the whole search output, so the model
    # leans permissive.
    id: str
    url: str
    title: Optional[str] = None
    text: Optional[str] = None
    extractedText: Optional[str] = None
    embedded: Optional[Dict[str, Any]] = None
    # Optional so a result without a relevance score still validates instead
    # of failing every other result in the response.
    score: Optional[float] = None
    published: Optional[str] = None
    author: Optional[str] = None
    # Accepts plain highlight strings as well as {text, score} objects.
    highlights: Optional[List[Union[str, ExaHighlight]]] = None
    robotsAllowed: Optional[bool] = None
    moderationConfig: Optional[ExaModerationConfig] = None
    urls: Optional[List[str]] = None
    # Appended last so the key order of the original fields in dumps is unchanged.
    publishedDate: Optional[str] = None
    highlightScores: Optional[List[float]] = None
94
+
95
+
96
class ExaCostDetails(BaseModel):
    """Cost details for an Exa search request."""

    # Only ``total`` is required. The per-category breakdowns are ancillary
    # metadata: if a response omits them, they default to empty dicts rather
    # than failing validation of the entire search output.
    # NOTE(review): confirm the breakdown keys against the current Exa
    # response format.
    total: float
    search: Dict[str, float] = Field(default_factory=dict)
    contents: Dict[str, float] = Field(default_factory=dict)
102
+
103
+
104
class ExaSearchOutputSchema(OutputSchema):
    """Output schema for Exa search API."""

    # ``results`` and ``query`` are required; the metadata fields below them
    # default to None when the response does not provide them.
    results: List[ExaSearchResult] = Field(..., description="List of search results.")
    query: str = Field(..., description="The original search query.")
    autopromptString: Union[str, None] = Field(
        None, description="Autoprompt string used for the search if enabled."
    )
    costDollars: Union[ExaCostDetails, None] = Field(
        None, description="Cost details for the search request."
    )
@@ -0,0 +1,114 @@
1
+ """
2
+ Skills for Exa Search API.
3
+
4
+ This module provides skills for using the Exa search API.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import httpx
10
+ from typing import Optional, Dict, Any, List, cast
11
+
12
+ from pydantic import ValidationError
13
+
14
+ from airtrain.core.skills import Skill, ProcessingError
15
+ from .credentials import ExaCredentials
16
+ from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API."""

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    # HTTP statuses treated as transient and worth retrying: rate limiting
    # (429) and server-side / gateway errors.
    RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    # Base delay in seconds for exponential backoff (base, 2*base, 4*base, ...).
    RETRY_BACKOFF_BASE = 0.5

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for API requests in seconds (applied per attempt)
            max_retries: Maximum number of retries for transient failures
                (timeouts, network errors, HTTP 429 and 5xx); 0 disables retries
            **kwargs: Forwarded unchanged to the ``Skill`` base class
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    def _build_headers(self) -> Dict[str, str]:
        """Return the JSON content-type and bearer-token headers for a request."""
        return {
            "content-type": "application/json",
            "Authorization": f"Bearer {self.credentials.api_key.get_secret_value()}",
        }

    async def _post_with_retries(self, payload: Dict[str, Any]) -> httpx.Response:
        """POST ``payload`` to the search endpoint, retrying transient failures.

        Args:
            payload: JSON-serialisable request body

        Returns:
            The first response whose status is not retryable, or the last
            response once retries are exhausted. The caller decides whether
            that status is an error.

        Raises:
            httpx.TimeoutException, httpx.NetworkError: if the final attempt
                fails before any response is received
        """
        import asyncio

        headers = self._build_headers()
        retries = max(0, self.max_retries or 0)
        attempt = 0
        # One client for all attempts so connections can be reused.
        async with httpx.AsyncClient() as client:
            while True:
                try:
                    response = await client.post(
                        self.EXA_API_ENDPOINT,
                        headers=headers,
                        json=payload,
                        timeout=self.timeout,
                    )
                except (httpx.TimeoutException, httpx.NetworkError) as e:
                    if attempt >= retries:
                        raise
                    logger.warning(
                        "Exa API request failed (%s: %s); retry %d of %d",
                        type(e).__name__,
                        e,
                        attempt + 1,
                        retries,
                    )
                else:
                    if (
                        attempt >= retries
                        or response.status_code not in self.RETRYABLE_STATUS_CODES
                    ):
                        return response
                    logger.warning(
                        "Exa API returned status code %d; retry %d of %d",
                        response.status_code,
                        attempt + 1,
                        retries,
                    )
                await asyncio.sleep(self.RETRY_BACKOFF_BASE * 2**attempt)
                attempt += 1

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Transient failures are retried up to ``self.max_retries`` times before
        an error is reported.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: If the request times out or fails, the API returns
                a non-200 status, or the response does not match the output
                schema. Any underlying exception is chained as ``__cause__``.
        """
        try:
            # Fields left as None are dropped so they are not sent in the body.
            payload = input_data.model_dump(exclude_none=True)
            response = await self._post_with_retries(payload)
        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        # Checked outside any try block so this error is not caught and
        # re-wrapped as an "Unexpected error" by a broad handler.
        if response.status_code != 200:
            error_message = (
                f"Exa API returned status code {response.status_code}: {response.text}"
            )
            logger.error(error_message)
            raise ProcessingError(error_message)

        try:
            result_data = response.json()
            return ExaSearchOutputSchema(
                results=result_data.get("results", []),
                query=input_data.query,
                autopromptString=result_data.get("autopromptString"),
                costDollars=result_data.get("costDollars"),
            )
        except ValidationError as e:
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
        except Exception as e:
            # e.g. a non-JSON body or a JSON body that is not an object.
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
@@ -12,30 +12,34 @@ from .registry import (
12
12
  ToolFactory,
13
13
  ToolValidationError,
14
14
  register_tool,
15
- execute_tool_call
15
+ execute_tool_call,
16
16
  )
17
17
 
18
18
  # Import standard tools
19
19
  from .filesystem import ListDirectoryTool, DirectoryTreeTool
20
20
  from .network import ApiCallTool
21
- from .command import ExecuteCommandTool, FindFilesTool
21
+ from .command import ExecuteCommandTool, FindFilesTool, TerminalNavigationTool
22
+ from .search import SearchTermTool, WebSearchTool
23
+ from .testing import RunPytestTool
22
24
 
23
25
  __all__ = [
24
26
  # Base classes
25
27
  "BaseTool",
26
28
  "StatelessTool",
27
29
  "StatefulTool",
28
-
29
30
  # Registry components
30
31
  "ToolFactory",
31
32
  "ToolValidationError",
32
33
  "register_tool",
33
34
  "execute_tool_call",
34
-
35
35
  # Standard tools
36
36
  "ListDirectoryTool",
37
37
  "DirectoryTreeTool",
38
38
  "ApiCallTool",
39
39
  "ExecuteCommandTool",
40
40
  "FindFilesTool",
41
- ]
41
+ "TerminalNavigationTool",
42
+ "SearchTermTool",
43
+ "WebSearchTool",
44
+ "RunPytestTool",
45
+ ]