airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,279 @@
1
+ from typing import Dict, Any, List, Optional, Generator, Union
2
+ from pydantic import Field, validator
3
+ import requests
4
+
5
+ from airtrain.core.skills import Skill, ProcessingError
6
+ from airtrain.core.schemas import InputSchema, OutputSchema
7
+ from .credentials import PerplexityCredentials
8
+ from .models_config import get_model_config, get_default_model
9
+
10
+
11
class PerplexityInput(InputSchema):
    """Schema for Perplexity AI chat input.

    Only ``user_input`` is required. The sampling fields ``temperature``,
    ``top_p``, ``top_k``, ``presence_penalty`` and ``frequency_penalty`` may be
    set to ``None``, which the chat skill treats as "not provided".
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: Optional[str] = Field(
        default=None,
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="sonar-pro",
        description="Perplexity AI model to use",
    )
    # The Perplexity API accepts temperatures in the half-open range [0, 2);
    # an upper bound of 1 rejected valid settings.
    temperature: Optional[float] = Field(
        default=0.7, description="Temperature for response generation", ge=0, lt=2
    )
    max_tokens: Optional[int] = Field(
        default=500, description="Maximum tokens in response"
    )
    top_p: Optional[float] = Field(
        default=1.0, description="Top-p (nucleus) sampling parameter", ge=0, le=1
    )
    top_k: Optional[int] = Field(
        default=None,
        description="Top-k sampling parameter",
    )
    presence_penalty: Optional[float] = Field(
        default=None,
        description="Presence penalty parameter",
    )
    frequency_penalty: Optional[float] = Field(
        default=None,
        description="Frequency penalty parameter",
    )

    @validator("model")
    def validate_model(cls, v):
        """Ensure ``v`` names a model known to ``get_model_config``.

        Raises:
            ValueError: if ``get_model_config`` rejects the name; the original
                error is chained as the cause.
        """
        try:
            get_model_config(v)
        except ValueError as e:
            raise ValueError(f"Invalid Perplexity AI model: {v}. {e}") from e
        return v
57
+
58
+
59
class PerplexityCitation(OutputSchema):
    """One source cited by a Perplexity AI response.

    Only the URL is mandatory; title and snippet are filled in when known.
    """

    url: str = Field(..., description="URL of the citation source")
    title: Optional[str] = Field(default=None, description="Title of the cited source")
    snippet: Optional[str] = Field(
        default=None,
        description="Text snippet from the citation",
    )
65
+
66
+
67
class PerplexityOutput(OutputSchema):
    """Schema for Perplexity AI chat output."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # The API's usage object is passed through verbatim. Values are typed as
    # Any (as in SambanovaOutput) so that any non-integer entry the API may
    # include does not make the whole response fail validation.
    usage: Dict[str, Any] = Field(..., description="Usage statistics from the API")
    citations: Optional[List[PerplexityCitation]] = Field(
        default=None, description="Citations used in the response, if available"
    )
    search_queries: Optional[List[str]] = Field(
        default=None, description="Search queries used, if available"
    )
78
+ )
79
+
80
+
81
class PerplexityChatSkill(Skill[PerplexityInput, PerplexityOutput]):
    """Skill for interacting with Perplexity AI chat-completion models."""

    input_schema = PerplexityInput
    output_schema = PerplexityOutput

    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Perplexity credentials; when omitted they are loaded
                with ``PerplexityCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or PerplexityCredentials.from_env()
        self.api_url = "https://api.perplexity.ai/chat/completions"

    def _build_headers(self) -> Dict[str, str]:
        """Return bearer-auth and JSON content-type headers for a request."""
        api_key = self.credentials.perplexity_api_key.get_secret_value()
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def _build_messages(self, input_data: PerplexityInput) -> List[Dict[str, str]]:
        """Build the message list: optional system prompt, history, then the user turn."""
        messages: List[Dict[str, str]] = []

        # The system prompt is only sent when it is a non-empty string.
        if input_data.system_prompt:
            messages.append({"role": "system", "content": input_data.system_prompt})

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _prepare_api_parameters(self, input_data: PerplexityInput) -> Dict[str, Any]:
        """Build the JSON request payload.

        ``model``, ``messages`` and ``max_tokens`` are always included; the
        optional sampling parameters are included only when not ``None``.
        """
        parameters: Dict[str, Any] = {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "max_tokens": input_data.max_tokens,
        }
        for name in (
            "temperature",
            "top_p",
            "top_k",
            "presence_penalty",
            "frequency_penalty",
        ):
            value = getattr(input_data, name)
            if value is not None:
                parameters[name] = value
        return parameters

    @staticmethod
    def _parse_citations(raw_citations: List[Any]) -> List[PerplexityCitation]:
        """Convert the API's ``citations`` list into PerplexityCitation objects.

        Perplexity documents ``citations`` as a list of URL strings. Dicts with
        ``url``/``title``/``snippet`` keys are accepted as well; entries of any
        other type are ignored.
        """
        citations: List[PerplexityCitation] = []
        for item in raw_citations:
            if isinstance(item, str):
                citations.append(PerplexityCitation(url=item))
            elif isinstance(item, dict):
                citations.append(
                    PerplexityCitation(
                        url=item.get("url") or "",
                        title=item.get("title"),
                        snippet=item.get("snippet"),
                    )
                )
        return citations

    def process(self, input_data: PerplexityInput) -> PerplexityOutput:
        """Send a chat request and return the complete response.

        Args:
            input_data: Validated chat input.

        Returns:
            The response text, the requested model name, the raw usage dict,
            and citations / search queries when the API provides them.

        Raises:
            ProcessingError: on a non-200 status or any failure while building
                the request or reading the response.
        """
        try:
            response = requests.post(
                self.api_url,
                headers=self._build_headers(),
                json=self._prepare_api_parameters(input_data),
            )

            if response.status_code != 200:
                raise ProcessingError(
                    f"Perplexity AI API error: {response.status_code} - {response.text}"
                )

            result = response.json()
            content = result["choices"][0]["message"]["content"]

            # A missing or null "citations" key means no citations.
            raw_citations = result.get("citations")
            citations = (
                self._parse_citations(raw_citations)
                if raw_citations is not None
                else None
            )

            # A null "usage" is treated like an absent one.
            usage = result.get("usage") or {}

            search_queries = None
            if "num_search_queries" in usage:
                search_queries = result.get("search_queries", [])

            return PerplexityOutput(
                response=content,
                used_model=input_data.model,
                usage=usage,
                citations=citations,
                search_queries=search_queries,
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(f"Perplexity AI processing failed: {str(e)}") from e
193
+
194
+
195
class PerplexityProcessStreamError(Exception):
    """Signals a failure while requesting or decoding a Perplexity stream.

    ``PerplexityStreamingChatSkill.process_stream`` converts it into a
    ``ProcessingError`` before it reaches callers.
    """
199
+
200
+
201
class PerplexityStreamOutput(OutputSchema):
    """One item produced while streaming a Perplexity AI response."""

    # Text fragment carried by this item.
    token: str = Field(..., description="Text token")
    # Finish reason copied from the stream chunk, when one is present.
    finish_reason: Optional[str] = Field(
        default=None,
        description="Why the completion finished",
    )
208
+
209
+
210
class PerplexityStreamingChatSkill(PerplexityChatSkill):
    """Extension of PerplexityChatSkill that supports streaming responses."""

    @staticmethod
    def _chunk_to_output(chunk: Any) -> Optional[PerplexityStreamOutput]:
        """Convert one decoded stream chunk into an output item.

        Returns ``None`` for chunks that carry neither text nor a finish
        reason, or that are not JSON objects, so the caller can skip them.
        """
        if not isinstance(chunk, dict):
            return None
        choices = chunk.get("choices") or []
        if not choices:
            return None

        choice = choices[0]
        if "delta" in choice:
            content = (choice["delta"] or {}).get("content") or ""
        else:
            # A non-streamed reply carries its full text under "message".
            content = (choice.get("message") or {}).get("content") or ""
        finish_reason = choice.get("finish_reason")

        # The terminal chunk is emitted even without text so that callers can
        # observe why generation stopped.
        if not content and finish_reason is None:
            return None
        return PerplexityStreamOutput(token=content, finish_reason=finish_reason)

    def process_stream(
        self, input_data: PerplexityInput
    ) -> Generator[PerplexityStreamOutput, None, None]:
        """
        Process the input and stream the response tokens.

        The request is sent with ``"stream": true`` and the response body is
        read line by line:

        - An optional ``data:`` prefix is stripped from each line.
        - Blank and non-JSON lines are skipped.
        - A ``[DONE]`` line ends the stream.

        Each chunk's text is taken from ``delta.content``. If a chunk has no
        ``delta``, its ``message.content`` is used instead, so a single
        non-streamed JSON reply is yielded as one token. An item with an empty
        ``token`` is yielded when a chunk reports a finish reason but no text.

        Raises:
            ProcessingError: on a non-200 status, a malformed chunk, or any
                other failure during the request.
        """
        import json  # local import: json is not imported at module level

        try:
            headers = {
                "Authorization": f"Bearer {self.credentials.perplexity_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            }

            data = self._prepare_api_parameters(input_data)
            data["stream"] = True

            # Using the response as a context manager releases the connection
            # even if the consumer stops iterating part-way through.
            with requests.post(
                self.api_url, headers=headers, json=data, stream=True
            ) as response:
                if response.status_code != 200:
                    raise PerplexityProcessStreamError(
                        f"Perplexity AI API error: {response.status_code} - {response.text}"
                    )

                for line in response.iter_lines():
                    if not line:
                        continue
                    if line.startswith(b"data:"):
                        line = line[len(b"data:"):].strip()
                    if line == b"[DONE]":
                        break

                    try:
                        output = self._chunk_to_output(json.loads(line))
                    except json.JSONDecodeError:
                        # Skip non-JSON lines (e.g. SSE comments / keep-alives).
                        continue
                    except Exception as e:
                        raise PerplexityProcessStreamError(
                            f"Error processing stream chunk: {str(e)}"
                        ) from e

                    if output is not None:
                        yield output

        except PerplexityProcessStreamError as e:
            raise ProcessingError(str(e)) from e
        except Exception as e:
            raise ProcessingError(f"Perplexity AI streaming failed: {str(e)}") from e
@@ -0,0 +1,6 @@
1
+ """Sambanova integration module"""
2
+
3
+ from .credentials import SambanovaCredentials
4
+ from .skills import SambanovaChatSkill
5
+
6
# Public names exported by a star-import of the Sambanova integration.
__all__ = [
    "SambanovaCredentials",
    "SambanovaChatSkill",
]
@@ -0,0 +1,20 @@
1
+ from pydantic import Field, SecretStr, HttpUrl
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+
4
+
5
class SambanovaCredentials(BaseCredentials):
    """Credentials for the SambaNova API: an API key plus an endpoint URL."""

    sambanova_api_key: SecretStr = Field(..., description="SambaNova API key")
    sambanova_endpoint_url: HttpUrl = Field(..., description="SambaNova API endpoint")

    _required_credentials = {"sambanova_api_key", "sambanova_endpoint_url"}

    async def validate_credentials(self) -> bool:
        """Validate SambaNova credentials.

        No SambaNova-specific check is implemented yet, so this always
        reports the credentials as valid.
        """
        # NOTE(review): SambanovaChatSkill hard-codes its base URL and does not
        # read sambanova_endpoint_url — confirm whether this field should be
        # required.
        return True
@@ -0,0 +1,129 @@
1
+ from typing import Optional, Dict, Any, List, Generator
2
+ from pydantic import Field
3
+ from airtrain.core.skills import Skill, ProcessingError
4
+ from airtrain.core.schemas import InputSchema, OutputSchema
5
+ from .credentials import SambanovaCredentials
6
+ import openai
7
+
8
+
9
class SambanovaInput(InputSchema):
    """Schema for Sambanova input.

    Only ``user_input`` is required; every other field has a default.
    """

    # --- prompt content ---
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )

    # --- generation settings ---
    model: str = Field(
        description="Sambanova model to use",
        default="DeepSeek-R1-Distill-Llama-70B",
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
    top_p: float = Field(
        description="Top p sampling parameter",
        default=0.1,
        ge=0,
        le=1,
    )

    # --- delivery mode ---
    stream: bool = Field(
        description="Whether to stream the response progressively",
        default=False,
    )
34
+
35
+
36
class SambanovaOutput(OutputSchema):
    """Schema for Sambanova output: response text, model name and usage."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # Empty when no usage statistics were collected.
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
42
+
43
+
44
class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
    """Skill for Sambanova chat, accessed through the OpenAI client.

    The OpenAI SDK is pointed at SambaNova's base URL, so requests and
    responses use the OpenAI chat-completions shape.
    """

    input_schema = SambanovaInput
    output_schema = SambanovaOutput

    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
        """Create the skill and its OpenAI-compatible client.

        Args:
            credentials: SambaNova credentials; loaded with
                ``SambanovaCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or SambanovaCredentials.from_env()
        # NOTE(review): the base URL is hard-coded, so
        # credentials.sambanova_endpoint_url is never used here — confirm
        # whether it should be passed as base_url.
        self.client = openai.OpenAI(
            api_key=self.credentials.sambanova_api_key.get_secret_value(),
            base_url="https://api.sambanova.ai/v1",
        )

    def _build_messages(self, input_data: SambanovaInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: System prompt first, then the history, then the current user turn
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _completion_kwargs(self, input_data: SambanovaInput) -> Dict[str, Any]:
        """Return the request arguments shared by streaming and non-streaming calls."""
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
        }

    def process_stream(self, input_data: SambanovaInput) -> Generator[str, None, None]:
        """Stream the response, yielding every text delta that is not None.

        Raises:
            ProcessingError: if the request or stream iteration fails.
        """
        try:
            stream = self.client.chat.completions.create(
                **self._completion_kwargs(input_data), stream=True
            )
            for chunk in stream:
                # A chunk may arrive with an empty ``choices`` list (e.g. a
                # trailing usage-only chunk); indexing [0] would then raise.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            raise ProcessingError(f"Sambanova streaming failed: {str(e)}") from e

    def process(self, input_data: SambanovaInput) -> SambanovaOutput:
        """Return the complete response.

        When ``input_data.stream`` is true, the streamed deltas are joined into
        one string and ``usage`` is left empty, since usage is not collected
        from the stream.

        Raises:
            ProcessingError: if the request fails or the response cannot be read.
        """
        try:
            if input_data.stream:
                content = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}
            else:
                completion = self.client.chat.completions.create(
                    **self._completion_kwargs(input_data)
                )
                content = completion.choices[0].message.content
                # ``usage`` may be missing or None on the response object.
                completion_usage = getattr(completion, "usage", None)
                usage = completion_usage.model_dump() if completion_usage else {}

            return SambanovaOutput(
                response=content,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Already wrapped by process_stream; do not double-prefix it.
            raise
        except Exception as e:
            raise ProcessingError(f"Sambanova processing failed: {str(e)}") from e
@@ -0,0 +1,21 @@
1
+ """
2
+ Search integrations for AirTrain.
3
+
4
+ This package provides integrations with various search providers.
5
+ """
6
+
7
+ # Import specific search integrations as needed
8
+ from .exa import (
9
+ ExaCredentials,
10
+ ExaSearchInputSchema,
11
+ ExaSearchOutputSchema,
12
+ ExaSearchSkill,
13
+ )
14
+
15
# Names re-exported by a star-import of this package; currently only the
# Exa search integration.
__all__ = [
    "ExaCredentials", "ExaSearchInputSchema", "ExaSearchOutputSchema", "ExaSearchSkill",
]
@@ -0,0 +1,23 @@
1
+ """
2
+ Exa Search API integration.
3
+
4
+ This module provides integration with the Exa search API for web searching capabilities.
5
+ """
6
+
7
+ from .credentials import ExaCredentials
8
+ from .schemas import (
9
+ ExaSearchInputSchema,
10
+ ExaSearchOutputSchema,
11
+ ExaContentConfig,
12
+ ExaSearchResult,
13
+ )
14
+ from .skills import ExaSearchSkill
15
+
16
# Public surface of the Exa integration: credentials, request/response
# schemas, and the search skill.
__all__ = [
    "ExaCredentials", "ExaSearchInputSchema", "ExaSearchOutputSchema",
    "ExaContentConfig", "ExaSearchResult", "ExaSearchSkill",
]
@@ -0,0 +1,30 @@
1
+ """
2
+ Credentials for Exa Search API.
3
+
4
+ This module provides credential management for the Exa search API.
5
+ """
6
+
7
+ from typing import Optional
8
+ from pydantic import Field, SecretStr
9
+
10
+ from airtrain.core.credentials import BaseCredentials
11
+
12
+
13
class ExaCredentials(BaseCredentials):
    """Credentials for the Exa search API, consisting of a single API key."""

    exa_api_key: SecretStr = Field(description="Exa search API key")

    _required_credentials = {"exa_api_key"}

    async def validate_credentials(self) -> bool:
        """Check that the required credentials are present.

        The presence check is delegated to ``BaseCredentials``. No live API
        call is made, so a present but revoked key still passes.
        """
        await super().validate_credentials()
        return True
@@ -0,0 +1,114 @@
1
+ """
2
+ Schemas for Exa Search API.
3
+
4
+ This module defines the input and output schemas for the Exa search API.
5
+ """
6
+
7
from typing import Dict, List, Optional, Any, Union
from pydantic import BaseModel, Field, HttpUrl

from airtrain.core.schemas import InputSchema, OutputSchema


class ExaContentConfig(BaseModel):
    """Configuration for the content to be returned by Exa search.

    ``text`` defaults to True; every other flag defaults to None (not set).
    """

    text: bool = Field(default=True, description="Whether to return text content.")
    extractedText: Optional[bool] = Field(
        default=None, description="Whether to return extracted text content."
    )
    embedded: Optional[bool] = Field(
        default=None, description="Whether to return embedded content."
    )
    links: Optional[bool] = Field(
        default=None, description="Whether to return links from the content."
    )
    screenshot: Optional[bool] = Field(
        default=None, description="Whether to return screenshots of the content."
    )
    highlighted: Optional[bool] = Field(
        default=None, description="Whether to return highlighted text."
    )
32
+
33
+
34
class ExaSearchInputSchema(InputSchema):
    """Input schema for Exa search API.

    Only ``query`` is required. ``contents`` defaults to a fresh
    ``ExaContentConfig()``; every other option defaults to None.
    """

    # NOTE(review): the camelCase field names (and ``type``) presumably mirror
    # Exa's request parameter names — confirm before renaming any of them.
    query: str = Field(description="The search query to execute.")
    numResults: Optional[int] = Field(
        description="Number of results to return.", default=None
    )
    contents: Optional[ExaContentConfig] = Field(
        description="Configuration for the content to be returned.",
        default_factory=ExaContentConfig,
    )
    highlights: Optional[dict] = Field(
        description="Highlighting configuration for search results.", default=None
    )
    useAutoprompt: Optional[bool] = Field(
        description="Whether to use autoprompt for the search.", default=None
    )
    type: Optional[str] = Field(description="Type of search to perform.", default=None)

    # Domain filters.
    includeDomains: Optional[List[str]] = Field(
        description="List of domains to include in the search.", default=None
    )
    excludeDomains: Optional[List[str]] = Field(
        description="List of domains to exclude from the search.", default=None
    )
58
+
59
+
60
class ExaModerationConfig(BaseModel):
    """Moderation flags attached to a search result; each one is optional."""

    # LlamaGuard category flags.
    llamaguardS1: Optional[bool] = None
    llamaguardS3: Optional[bool] = None
    llamaguardS4: Optional[bool] = None
    llamaguardS12: Optional[bool] = None

    # Domain blacklist flags.
    domainBlacklisted: Optional[bool] = None
    domainBlacklistedMedia: Optional[bool] = None
69
+
70
+
71
class ExaHighlight(BaseModel):
    """A highlighted excerpt of a search result together with its score."""

    text: str  # the excerpt itself
    score: float  # score attached to the excerpt
76
+
77
+
78
class ExaSearchResult(BaseModel):
    """Individual search result from Exa.

    Only ``id`` and ``url`` are required; every other field is optional so
    that results from different search types and content settings validate.
    """

    id: str
    url: str
    title: Optional[str] = None
    text: Optional[str] = None
    extractedText: Optional[str] = None
    embedded: Optional[Dict[str, Any]] = None
    # Similarity score; optional because not every result carries one.
    score: Optional[float] = None
    published: Optional[str] = None
    # Exa's response uses ``publishedDate``; ``published`` is kept for
    # backward compatibility.
    publishedDate: Optional[str] = None
    author: Optional[str] = None
    # Exa returns highlights as plain strings, with their scores in the
    # parallel ``highlightScores`` list; structured ExaHighlight objects are
    # still accepted.
    highlights: Optional[List[Union[ExaHighlight, str]]] = None
    highlightScores: Optional[List[float]] = None
    robotsAllowed: Optional[bool] = None
    moderationConfig: Optional[ExaModerationConfig] = None
    urls: Optional[List[str]] = None
94
+
95
+
96
class ExaCostDetails(BaseModel):
    """Cost details for an Exa search request.

    Only ``total`` is required. The per-category breakdowns default to empty
    dicts, so a cost payload that omits them no longer fails validation of
    the whole search output.
    """

    total: float
    search: Dict[str, float] = Field(default_factory=dict)
    contents: Dict[str, float] = Field(default_factory=dict)
102
+
103
+
104
class ExaSearchOutputSchema(OutputSchema):
    """Output schema for Exa search API.

    ``results`` and ``query`` are required; the autoprompt string and cost
    details are optional.
    """

    # Required payload.
    results: List[ExaSearchResult] = Field(description="List of search results.")
    query: str = Field(description="The original search query.")

    # Optional metadata.
    autopromptString: Optional[str] = Field(
        description="Autoprompt string used for the search if enabled.",
        default=None,
    )
    costDollars: Optional[ExaCostDetails] = Field(
        description="Cost details for the search request.",
        default=None,
    )