airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/perplexity/skills.py
@@ -0,0 +1,279 @@
+from typing import Dict, Any, List, Optional, Generator, Union
+from pydantic import Field, validator
+import requests
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import PerplexityCredentials
+from .models_config import get_model_config, get_default_model
+
+
+class PerplexityInput(InputSchema):
+    """Schema for Perplexity AI chat input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: Optional[str] = Field(
+        default=None,
+        description="System prompt to guide the model's behavior",
+    )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
+    model: str = Field(
+        default="sonar-pro",
+        description="Perplexity AI model to use",
+    )
+    temperature: Optional[float] = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+    max_tokens: Optional[int] = Field(
+        default=500, description="Maximum tokens in response"
+    )
+    top_p: Optional[float] = Field(
+        default=1.0, description="Top-p (nucleus) sampling parameter", ge=0, le=1
+    )
+    top_k: Optional[int] = Field(
+        default=None,
+        description="Top-k sampling parameter",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter",
+    )
+
+    @validator("model")
+    def validate_model(cls, v):
+        """Validate that the model is supported by Perplexity AI."""
+        try:
+            get_model_config(v)
+            return v
+        except ValueError as e:
+            raise ValueError(f"Invalid Perplexity AI model: {v}. {str(e)}")
+
+
+class PerplexityCitation(OutputSchema):
+    """Schema for Perplexity AI citation information"""
+
+    url: str = Field(..., description="URL of the citation source")
+    title: Optional[str] = Field(None, description="Title of the cited source")
+    snippet: Optional[str] = Field(None, description="Text snippet from the citation")
+
+
+class PerplexityOutput(OutputSchema):
+    """Schema for Perplexity AI chat output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, int] = Field(..., description="Usage statistics from the API")
+    citations: Optional[List[PerplexityCitation]] = Field(
+        default=None, description="Citations used in the response, if available"
+    )
+    search_queries: Optional[List[str]] = Field(
+        default=None, description="Search queries used, if available"
+    )
+
+
+class PerplexityChatSkill(Skill[PerplexityInput, PerplexityOutput]):
+    """Skill for interacting with Perplexity AI models"""
+
+    input_schema = PerplexityInput
+    output_schema = PerplexityOutput
+
+    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials or PerplexityCredentials.from_env()
+        self.api_url = "https://api.perplexity.ai/chat/completions"
+
+    def _build_messages(self, input_data: PerplexityInput) -> List[Dict[str, str]]:
+        """Build messages list from input data including conversation history."""
+        messages = []
+
+        # Add system prompt if provided
+        if input_data.system_prompt:
+            messages.append({"role": "system", "content": input_data.system_prompt})
+
+        # Add conversation history
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
+
+    def _prepare_api_parameters(self, input_data: PerplexityInput) -> Dict[str, Any]:
+        """Prepare parameters for the API request."""
+        parameters = {
+            "model": input_data.model,
+            "messages": self._build_messages(input_data),
+            "max_tokens": input_data.max_tokens,
+        }
+
+        # Add optional parameters if provided
+        if input_data.temperature is not None:
+            parameters["temperature"] = input_data.temperature
+
+        if input_data.top_p is not None:
+            parameters["top_p"] = input_data.top_p
+
+        if input_data.top_k is not None:
+            parameters["top_k"] = input_data.top_k
+
+        if input_data.presence_penalty is not None:
+            parameters["presence_penalty"] = input_data.presence_penalty
+
+        if input_data.frequency_penalty is not None:
+            parameters["frequency_penalty"] = input_data.frequency_penalty
+
+        return parameters
+
+    def process(self, input_data: PerplexityInput) -> PerplexityOutput:
+        """Process the input and return the complete response."""
+        try:
+            # Prepare headers with API key
+            headers = {
+                "Authorization": f"Bearer {self.credentials.perplexity_api_key.get_secret_value()}",
+                "Content-Type": "application/json",
+            }
+
+            # Prepare parameters for the API request
+            data = self._prepare_api_parameters(input_data)
+
+            # Make the API request
+            response = requests.post(self.api_url, headers=headers, json=data)
+
+            # Check if request was successful
+            if response.status_code != 200:
+                raise ProcessingError(
+                    f"Perplexity AI API error: {response.status_code} - {response.text}"
+                )
+
+            # Parse the response
+            result = response.json()
+
+            # Extract content from the completion
+            content = result["choices"][0]["message"]["content"]
+
+            # Extract and process citations if available
+            citations = None
+            if "citations" in result:
+                citations = [
+                    PerplexityCitation(
+                        url=citation.get("url", ""),
+                        title=citation.get("title"),
+                        snippet=citation.get("snippet"),
+                    )
+                    for citation in result.get("citations", [])
+                ]
+
+            # Extract search queries if available
+            search_queries = None
+            if "usage" in result and "num_search_queries" in result["usage"]:
+                search_queries = result.get("search_queries", [])
+
+            # Create and return output
+            return PerplexityOutput(
+                response=content,
+                used_model=input_data.model,
+                usage=result.get("usage", {}),
+                citations=citations,
+                search_queries=search_queries,
+            )
+
+        except Exception as e:
+            if isinstance(e, ProcessingError):
+                raise e
+            raise ProcessingError(f"Perplexity AI processing failed: {str(e)}")
+
+
+class PerplexityProcessStreamError(Exception):
+    """Error raised during stream processing"""
+
+    pass
+
+
+class PerplexityStreamOutput(OutputSchema):
+    """Schema for streaming output tokens"""
+
+    token: str = Field(..., description="Text token")
+    finish_reason: Optional[str] = Field(
+        None, description="Why the completion finished"
+    )
+
+
+class PerplexityStreamingChatSkill(PerplexityChatSkill):
+    """Extension of PerplexityChatSkill that supports streaming responses"""
+
+    def process_stream(
+        self, input_data: PerplexityInput
+    ) -> Generator[PerplexityStreamOutput, None, None]:
+        """
+        Process the input and stream the response tokens.
+
+        Note: Perplexity AI API may not support true streaming. In that case, this
+        method will make a regular API call and yield the entire response at once.
+        """
+        try:
+            # Prepare headers with API key
+            headers = {
+                "Authorization": f"Bearer {self.credentials.perplexity_api_key.get_secret_value()}",
+                "Content-Type": "application/json",
+            }
+
+            # Prepare parameters for the API request, including stream=true if possible
+            data = self._prepare_api_parameters(input_data)
+            data["stream"] = True
+
+            # Make the API request
+            response = requests.post(
+                self.api_url, headers=headers, json=data, stream=True
+            )
+
+            # Check if request was successful
+            if response.status_code != 200:
+                raise PerplexityProcessStreamError(
+                    f"Perplexity AI API error: {response.status_code} - {response.text}"
+                )
+
+            # Process the streaming response if supported
+            for line in response.iter_lines():
+                if line:
+                    # Parse the response line
+                    try:
+                        # Remove 'data: ' prefix if present
+                        if line.startswith(b"data: "):
+                            line = line[6:]
+
+                        # Parse JSON
+                        import json
+
+                        chunk = json.loads(line)
+
+                        # Extract content
+                        if "choices" in chunk and len(chunk["choices"]) > 0:
+                            choice = chunk["choices"][0]
+                            if "delta" in choice and "content" in choice["delta"]:
+                                content = choice["delta"]["content"]
+                                if content:
+                                    yield PerplexityStreamOutput(
+                                        token=content,
+                                        finish_reason=choice.get("finish_reason"),
+                                    )
+                    except json.JSONDecodeError:
+                        # Skip non-JSON lines
+                        continue
+                    except Exception as e:
+                        raise PerplexityProcessStreamError(
+                            f"Error processing stream chunk: {str(e)}"
+                        )
+
+        except Exception as e:
+            if isinstance(e, PerplexityProcessStreamError):
+                raise ProcessingError(str(e))
+            raise ProcessingError(f"Perplexity AI streaming failed: {str(e)}")
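The hunk above adds both a blocking skill and a streaming subclass. A minimal usage sketch, assuming PerplexityCredentials.from_env() can resolve a Perplexity API key from the environment (the credential class lives in credentials.py, outside this hunk):

from airtrain.integrations.perplexity.skills import (
    PerplexityChatSkill,
    PerplexityInput,
    PerplexityStreamingChatSkill,
)

# Blocking call: returns a PerplexityOutput with response, usage, and citations.
skill = PerplexityChatSkill()  # falls back to PerplexityCredentials.from_env()
result = skill.process(
    PerplexityInput(user_input="What is the tallest mountain in Europe?")
)
print(result.response)
for citation in result.citations or []:
    print(citation.url)

# Streaming variant: yields PerplexityStreamOutput tokens as SSE chunks arrive.
streaming_skill = PerplexityStreamingChatSkill()
for part in streaming_skill.process_stream(PerplexityInput(user_input="Say hello")):
    print(part.token, end="", flush=True)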
airtrain/integrations/sambanova/credentials.py
@@ -0,0 +1,20 @@
+from pydantic import Field, SecretStr, HttpUrl
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+
+
+class SambanovaCredentials(BaseCredentials):
+    """SambaNova credentials"""
+
+    sambanova_api_key: SecretStr = Field(..., description="SambaNova API key")
+    sambanova_endpoint_url: HttpUrl = Field(..., description="SambaNova API endpoint")
+
+    _required_credentials = {"sambanova_api_key", "sambanova_endpoint_url"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate SambaNova credentials"""
+        try:
+            # Implement SambaNova-specific validation
+            # This would depend on their API client implementation
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid SambaNova credentials: {str(e)}")
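A quick construction sketch for these credentials (hypothetical values); pydantic coerces the plain string to SecretStr and validates the endpoint as an HttpUrl:

from airtrain.integrations.sambanova.credentials import SambanovaCredentials

creds = SambanovaCredentials(
    sambanova_api_key="sn-example-key",                    # hypothetical key
    sambanova_endpoint_url="https://api.sambanova.ai/v1",  # validated as HttpUrl
)
print(creds.sambanova_api_key)  # SecretStr masks the value: '**********'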
airtrain/integrations/sambanova/skills.py
@@ -0,0 +1,129 @@
+from typing import Optional, Dict, Any, List, Generator
+from pydantic import Field
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import SambanovaCredentials
+import openai
+
+
+class SambanovaInput(InputSchema):
+    """Schema for Sambanova input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
+    model: str = Field(
+        default="DeepSeek-R1-Distill-Llama-70B", description="Sambanova model to use"
+    )
+    max_tokens: int = Field(default=1024, description="Maximum tokens in response")
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+    top_p: float = Field(
+        default=0.1, description="Top p sampling parameter", ge=0, le=1
+    )
+    stream: bool = Field(
+        default=False, description="Whether to stream the response progressively"
+    )
+
+
+class SambanovaOutput(OutputSchema):
+    """Schema for Sambanova output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+
+
+class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
+    """Skill for Sambanova chat"""
+
+    input_schema = SambanovaInput
+    output_schema = SambanovaOutput
+
+    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
+        super().__init__()
+        self.credentials = credentials or SambanovaCredentials.from_env()
+        self.client = openai.OpenAI(
+            api_key=self.credentials.sambanova_api_key.get_secret_value(),
+            base_url="https://api.sambanova.ai/v1",
+        )
+
+    def _build_messages(self, input_data: SambanovaInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Sambanova
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
+
+    def process_stream(self, input_data: SambanovaInput) -> Generator[str, None, None]:
+        """Process the input and stream the response token by token."""
+        try:
+            messages = self._build_messages(input_data)
+
+            stream = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                temperature=input_data.temperature,
+                max_tokens=input_data.max_tokens,
+                top_p=input_data.top_p,
+                stream=True,
+            )
+
+            for chunk in stream:
+                if chunk.choices[0].delta.content is not None:
+                    yield chunk.choices[0].delta.content
+
+        except Exception as e:
+            raise ProcessingError(f"Sambanova streaming failed: {str(e)}")
+
+    def process(self, input_data: SambanovaInput) -> SambanovaOutput:
+        """Process the input and return the complete response."""
+        try:
+            if input_data.stream:
+                response_chunks = []
+                for chunk in self.process_stream(input_data):
+                    response_chunks.append(chunk)
+                response = "".join(response_chunks)
+                usage = {}  # Usage stats not available in streaming
+            else:
+                messages = self._build_messages(input_data)
+                completion = self.client.chat.completions.create(
+                    model=input_data.model,
+                    messages=messages,
+                    temperature=input_data.temperature,
+                    max_tokens=input_data.max_tokens,
+                    top_p=input_data.top_p,
+                )
+                usage = (
+                    completion.usage.model_dump() if hasattr(completion, "usage") else {}
+                )
+                # Extract the text here so both branches yield a plain string
+                response = completion.choices[0].message.content
+
+            return SambanovaOutput(
+                response=response,
+                used_model=input_data.model,
+                usage=usage,
+            )
+
+        except Exception as e:
+            raise ProcessingError(f"Sambanova processing failed: {str(e)}")
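Both call paths of SambanovaChatSkill in one sketch, assuming the SambaNova key and endpoint are resolvable from the environment:

from airtrain.integrations.sambanova.skills import SambanovaChatSkill, SambanovaInput

skill = SambanovaChatSkill()

# Blocking call: returns a SambanovaOutput with the full response text and usage stats.
output = skill.process(
    SambanovaInput(user_input="Explain top-p sampling in one sentence.")
)
print(output.used_model, output.usage)
print(output.response)

# Streaming call: yields plain text chunks as the model produces them.
for token in skill.process_stream(SambanovaInput(user_input="Count to five.")):
    print(token, end="", flush=True)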
airtrain/integrations/search/__init__.py
@@ -0,0 +1,21 @@
+"""
+Search integrations for AirTrain.
+
+This package provides integrations with various search providers.
+"""
+
+# Import specific search integrations as needed
+from .exa import (
+    ExaCredentials,
+    ExaSearchInputSchema,
+    ExaSearchOutputSchema,
+    ExaSearchSkill,
+)
+
+__all__ = [
+    # Exa Search
+    "ExaCredentials",
+    "ExaSearchInputSchema",
+    "ExaSearchOutputSchema",
+    "ExaSearchSkill",
+]
airtrain/integrations/search/exa/__init__.py
@@ -0,0 +1,23 @@
+"""
+Exa Search API integration.
+
+This module provides integration with the Exa search API for web searching capabilities.
+"""
+
+from .credentials import ExaCredentials
+from .schemas import (
+    ExaSearchInputSchema,
+    ExaSearchOutputSchema,
+    ExaContentConfig,
+    ExaSearchResult,
+)
+from .skills import ExaSearchSkill
+
+__all__ = [
+    "ExaCredentials",
+    "ExaSearchInputSchema",
+    "ExaSearchOutputSchema",
+    "ExaContentConfig",
+    "ExaSearchResult",
+    "ExaSearchSkill",
+]
airtrain/integrations/search/exa/credentials.py
@@ -0,0 +1,30 @@
+"""
+Credentials for Exa Search API.
+
+This module provides credential management for the Exa search API.
+"""
+
+from typing import Optional
+from pydantic import Field, SecretStr
+
+from airtrain.core.credentials import BaseCredentials
+
+
+class ExaCredentials(BaseCredentials):
+    """Credentials for accessing the Exa search API."""
+
+    exa_api_key: SecretStr = Field(
+        description="Exa search API key",
+    )
+
+    _required_credentials = {"exa_api_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate that the required credentials are present and valid."""
+        # First check that required credentials are present
+        await super().validate_credentials()
+
+        # In a production environment, we might want to make a test API call here
+        # to verify the API key is actually valid, but for now we'll just check
+        # that it's present
+        return True
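Constructing the credentials directly instead of from the environment (hypothetical key); note that validate_credentials only checks presence, per the comment in the hunk above:

import asyncio
from airtrain.integrations.search.exa.credentials import ExaCredentials

creds = ExaCredentials(exa_api_key="exa-example-key")  # hypothetical key, coerced to SecretStr
assert asyncio.run(creds.validate_credentials())       # presence check only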
airtrain/integrations/search/exa/schemas.py
@@ -0,0 +1,114 @@
+"""
+Schemas for Exa Search API.
+
+This module defines the input and output schemas for the Exa search API.
+"""
+
+from typing import Dict, List, Optional, Any, Union
+from pydantic import BaseModel, Field, HttpUrl
+
+from airtrain.core.schemas import InputSchema, OutputSchema
+
+
+class ExaContentConfig(BaseModel):
+    """Configuration for the content to be returned by Exa search."""
+
+    text: bool = Field(default=True, description="Whether to return text content.")
+    extractedText: Optional[bool] = Field(
+        default=None, description="Whether to return extracted text content."
+    )
+    embedded: Optional[bool] = Field(
+        default=None, description="Whether to return embedded content."
+    )
+    links: Optional[bool] = Field(
+        default=None, description="Whether to return links from the content."
+    )
+    screenshot: Optional[bool] = Field(
+        default=None, description="Whether to return screenshots of the content."
+    )
+    highlighted: Optional[bool] = Field(
+        default=None, description="Whether to return highlighted text."
+    )
+
+
+class ExaSearchInputSchema(InputSchema):
+    """Input schema for Exa search API."""
+
+    query: str = Field(description="The search query to execute.")
+    numResults: Optional[int] = Field(
+        default=None, description="Number of results to return."
+    )
+    contents: Optional[ExaContentConfig] = Field(
+        default_factory=ExaContentConfig,
+        description="Configuration for the content to be returned.",
+    )
+    highlights: Optional[dict] = Field(
+        default=None, description="Highlighting configuration for search results."
+    )
+    useAutoprompt: Optional[bool] = Field(
+        default=None, description="Whether to use autoprompt for the search."
+    )
+    type: Optional[str] = Field(default=None, description="Type of search to perform.")
+    includeDomains: Optional[List[str]] = Field(
+        default=None, description="List of domains to include in the search."
+    )
+    excludeDomains: Optional[List[str]] = Field(
+        default=None, description="List of domains to exclude from the search."
+    )
+
+
+class ExaModerationConfig(BaseModel):
+    """Moderation configuration returned in search results."""
+
+    llamaguardS1: Optional[bool] = None
+    llamaguardS3: Optional[bool] = None
+    llamaguardS4: Optional[bool] = None
+    llamaguardS12: Optional[bool] = None
+    domainBlacklisted: Optional[bool] = None
+    domainBlacklistedMedia: Optional[bool] = None
+
+
+class ExaHighlight(BaseModel):
+    """Highlight information for a search result."""
+
+    text: str
+    score: float
+
+
+class ExaSearchResult(BaseModel):
+    """Individual search result from Exa."""
+
+    id: str
+    url: str
+    title: Optional[str] = None
+    text: Optional[str] = None
+    extractedText: Optional[str] = None
+    embedded: Optional[Dict[str, Any]] = None
+    score: float
+    published: Optional[str] = None
+    author: Optional[str] = None
+    highlights: Optional[List[ExaHighlight]] = None
+    robotsAllowed: Optional[bool] = None
+    moderationConfig: Optional[ExaModerationConfig] = None
+    urls: Optional[List[str]] = None
+
+
+class ExaCostDetails(BaseModel):
+    """Cost details for an Exa search request."""
+
+    total: float
+    search: Dict[str, float]
+    contents: Dict[str, float]
+
+
+class ExaSearchOutputSchema(OutputSchema):
+    """Output schema for Exa search API."""
+
+    results: List[ExaSearchResult] = Field(description="List of search results.")
+    query: str = Field(description="The original search query.")
+    autopromptString: Optional[str] = Field(
+        default=None, description="Autoprompt string used for the search if enabled."
+    )
+    costDollars: Optional[ExaCostDetails] = Field(
+        default=None, description="Cost details for the search request."
+    )
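To show how these schemas compose into a request payload, a hypothetical query (field names mirror Exa's camelCase API; model_dump assumes pydantic v2, use .dict() on v1):

from airtrain.integrations.search.exa.schemas import (
    ExaContentConfig,
    ExaSearchInputSchema,
)

query = ExaSearchInputSchema(
    query="open-source LLM agent frameworks",
    numResults=5,
    contents=ExaContentConfig(text=True, links=True),
    includeDomains=["github.com"],
)
# Drop unset fields to get the JSON body for the search request.
print(query.model_dump(exclude_none=True))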