airtrain-0.1.2-py3-none-any.whl → airtrain-0.1.4-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/openai/skills.py
@@ -0,0 +1,342 @@
from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
from pydantic import Field, BaseModel
from openai import OpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
import numpy as np

from airtrain.core.skills import Skill, ProcessingError
from airtrain.core.schemas import InputSchema, OutputSchema
from .credentials import OpenAICredentials


class OpenAIInput(InputSchema):
    """Schema for OpenAI chat input"""

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="gpt-4o",
        description="OpenAI model to use",
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    max_tokens: Optional[int] = Field(
        default=131072, description="Maximum tokens in response"
    )
    stream: bool = Field(
        default=False,
        description="Whether to stream the response token by token",
    )


class OpenAIOutput(OutputSchema):
    """Schema for OpenAI chat output"""

    response: str
    used_model: str
    usage: Dict[str, int]


class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
    """Skill for interacting with OpenAI models with async support"""

    input_schema = OpenAIInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials"""
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        self.async_client = AsyncOpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
        """Build messages list from input data including conversation history."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def process_stream(self, input_data: OpenAIInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token."""
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                if chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            raise ProcessingError(f"OpenAI streaming failed: {str(e)}")

    def process(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Process the input and return the complete response."""
        try:
            if input_data.stream:
                # For streaming, collect the entire response; the chunked API
                # does not report token usage on this path, so it is zeroed.
                response_chunks = []
                for chunk in self.process_stream(input_data):
                    response_chunks.append(chunk)
                response = "".join(response_chunks)
                usage = {
                    "total_tokens": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                }
            else:
                # For non-streaming, use regular completion
                messages = self._build_messages(input_data)
                completion = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=messages,
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                    stream=False,
                )
                response = completion.choices[0].message.content
                usage = {
                    "total_tokens": completion.usage.total_tokens,
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                }

            return OpenAIOutput(
                response=response,
                used_model=input_data.model,
                usage=usage,
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI chat failed: {str(e)}")

    async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Async version of process method"""
        try:
            messages = self._build_messages(input_data)
            completion = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            return OpenAIOutput(
                response=completion.choices[0].message.content,
                used_model=completion.model,
                usage={
                    "total_tokens": completion.usage.total_tokens,
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                },
            )
        except Exception as e:
            raise ProcessingError(f"OpenAI async chat failed: {str(e)}")

    async def process_stream_async(
        self, input_data: OpenAIInput
    ) -> AsyncGenerator[str, None]:
        """Async version of stream processor"""
        try:
            messages = self._build_messages(input_data)
            stream = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )
            async for chunk in stream:
                if chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            raise ProcessingError(f"OpenAI async streaming failed: {str(e)}")


ResponseT = TypeVar("ResponseT", bound=BaseModel)


class OpenAIParserInput(InputSchema):
    """Schema for OpenAI structured output input"""

    user_input: str
    system_prompt: str = "You are a helpful assistant that provides structured data."
    model: str = "gpt-4o"
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    response_model: Type[ResponseT]

    class Config:
        arbitrary_types_allowed = True


class OpenAIParserOutput(OutputSchema):
    """Schema for OpenAI structured output"""

    parsed_response: BaseModel
    used_model: str
    tokens_used: int


class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
    """Skill for getting structured responses from OpenAI"""

    input_schema = OpenAIParserInput
    output_schema = OpenAIParserOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials"""
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def process(self, input_data: OpenAIParserInput) -> OpenAIParserOutput:
        try:
            # Use parse method instead of create
            completion = self.client.beta.chat.completions.parse(
                model=input_data.model,
                messages=[
                    {"role": "system", "content": input_data.system_prompt},
                    {"role": "user", "content": input_data.user_input},
                ],
                response_format=input_data.response_model,
            )

            if completion.choices[0].message.parsed is None:
                raise ProcessingError("Failed to parse response")

            return OpenAIParserOutput(
                parsed_response=completion.choices[0].message.parsed,
                used_model=completion.model,
                tokens_used=completion.usage.total_tokens,
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI parsing failed: {str(e)}")


class OpenAIEmbeddingsInput(InputSchema):
    """Schema for OpenAI embeddings input"""

    texts: Union[str, List[str]] = Field(
        ..., description="Text or list of texts to generate embeddings for"
    )
    model: str = Field(
        default="text-embedding-3-large", description="OpenAI embeddings model to use"
    )
    encoding_format: str = Field(
        default="float", description="The format of the embeddings: 'float' or 'base64'"
    )
    dimensions: Optional[int] = Field(
        default=None, description="Optional number of dimensions for the embeddings"
    )


class OpenAIEmbeddingsOutput(OutputSchema):
    """Schema for OpenAI embeddings output"""

    embeddings: List[List[float]] = Field(..., description="List of embeddings vectors")
    used_model: str = Field(..., description="Model used for generating embeddings")
    tokens_used: int = Field(..., description="Number of tokens used")


class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
    """Skill for generating embeddings using OpenAI models"""

    input_schema = OpenAIEmbeddingsInput
    output_schema = OpenAIEmbeddingsOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials"""
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        self.async_client = AsyncOpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
        """Generate embeddings for the input text(s)"""
        try:
            # Handle single text input
            texts = (
                [input_data.texts]
                if isinstance(input_data.texts, str)
                else input_data.texts
            )

            # Create embeddings
            response = self.client.embeddings.create(
                model=input_data.model,
                input=texts,
                encoding_format=input_data.encoding_format,
                dimensions=input_data.dimensions,
            )

            # Extract embeddings
            embeddings = [data.embedding for data in response.data]

            return OpenAIEmbeddingsOutput(
                embeddings=embeddings,
                used_model=response.model,
                tokens_used=response.usage.total_tokens,
            )
        except Exception as e:
            raise ProcessingError(f"OpenAI embeddings generation failed: {str(e)}")

    async def process_async(
        self, input_data: OpenAIEmbeddingsInput
    ) -> OpenAIEmbeddingsOutput:
        """Async version of the embeddings generation"""
        try:
            # Handle single text input
            texts = (
                [input_data.texts]
                if isinstance(input_data.texts, str)
                else input_data.texts
            )

            # Create embeddings
            response = await self.async_client.embeddings.create(
                model=input_data.model,
                input=texts,
                encoding_format=input_data.encoding_format,
                dimensions=input_data.dimensions,
            )

            # Extract embeddings
            embeddings = [data.embedding for data in response.data]

            return OpenAIEmbeddingsOutput(
                embeddings=embeddings,
                used_model=response.model,
                tokens_used=response.usage.total_tokens,
            )
        except Exception as e:
            raise ProcessingError(
                f"OpenAI async embeddings generation failed: {str(e)}"
            )
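For orientation, a minimal usage sketch of the chat and embeddings skills above. It is not part of the diff; it assumes only that OpenAICredentials.from_env() can resolve valid credentials from the environment.

# Hypothetical usage sketch; assumes OpenAI credentials are available in the
# environment so the skills' from_env() fallback succeeds.
from airtrain.integrations.openai.skills import (
    OpenAIChatSkill,
    OpenAIInput,
    OpenAIEmbeddingsSkill,
    OpenAIEmbeddingsInput,
)

chat = OpenAIChatSkill()  # falls back to OpenAICredentials.from_env()
reply = chat.process(
    OpenAIInput(
        user_input="Summarize the airtrain package in one sentence.",
        temperature=0.2,
        max_tokens=256,
    )
)
print(reply.response, reply.usage)

embedder = OpenAIEmbeddingsSkill()
vectors = embedder.process(OpenAIEmbeddingsInput(texts=["hello", "world"]))
print(len(vectors.embeddings), vectors.tokens_used)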
airtrain/integrations/perplexity/__init__.py
@@ -0,0 +1,49 @@
"""Perplexity AI integration module"""

from .credentials import PerplexityCredentials
from .skills import (
    PerplexityInput,
    PerplexityOutput,
    PerplexityChatSkill,
    PerplexityCitation,
    PerplexityStreamingChatSkill,
    PerplexityStreamOutput,
)
from .list_models import (
    PerplexityListModelsSkill,
    StandalonePerplexityListModelsSkill,
    PerplexityListModelsInput,
    PerplexityListModelsOutput,
)
from .models_config import (
    get_model_config,
    get_default_model,
    supports_citations,
    supports_search,
    get_models_by_category,
    PERPLEXITY_MODELS_CONFIG,
)

__all__ = [
    # Credentials
    "PerplexityCredentials",
    # Skills
    "PerplexityInput",
    "PerplexityOutput",
    "PerplexityChatSkill",
    "PerplexityCitation",
    "PerplexityStreamingChatSkill",
    "PerplexityStreamOutput",
    # List Models
    "PerplexityListModelsSkill",
    "StandalonePerplexityListModelsSkill",
    "PerplexityListModelsInput",
    "PerplexityListModelsOutput",
    # Model Config
    "get_model_config",
    "get_default_model",
    "supports_citations",
    "supports_search",
    "get_models_by_category",
    "PERPLEXITY_MODELS_CONFIG",
]
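Because of these re-exports, downstream code can import directly from the package root rather than the individual submodules; a minimal sketch:

# Minimal sketch relying only on the re-exports declared above.
from airtrain.integrations.perplexity import (
    get_default_model,
    supports_citations,
)

model_id = get_default_model()
print(model_id, supports_citations(model_id))  # "sonar-pro", True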
airtrain/integrations/perplexity/credentials.py
@@ -0,0 +1,43 @@
from pydantic import Field, SecretStr
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
import requests


class PerplexityCredentials(BaseCredentials):
    """Perplexity AI API credentials"""

    perplexity_api_key: SecretStr = Field(..., description="Perplexity AI API key")

    _required_credentials = {"perplexity_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate Perplexity AI credentials by making a test API call"""
        try:
            headers = {
                "Authorization": f"Bearer {self.perplexity_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            }

            # Small API call to check if credentials are valid
            data = {
                "model": "sonar-pro",
                "messages": [{"role": "user", "content": "Test"}],
                "max_tokens": 1,
            }

            # Make a synchronous request for validation
            response = requests.post(
                "https://api.perplexity.ai/chat/completions", headers=headers, json=data
            )

            if response.status_code == 200:
                return True
            else:
                raise CredentialValidationError(
                    f"Invalid Perplexity AI credentials: {response.status_code} - {response.text}"
                )

        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Perplexity AI credentials: {str(e)}"
            )
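Note that validate_credentials is declared async (even though the HTTP call inside is a blocking requests.post), so callers need an event loop. A hedged sketch, assuming BaseCredentials accepts its fields through the usual pydantic constructor:

import asyncio

# Assumes pydantic-style construction of BaseCredentials subclasses.
from airtrain.core.credentials import CredentialValidationError
from airtrain.integrations.perplexity.credentials import PerplexityCredentials

creds = PerplexityCredentials(perplexity_api_key="pplx-...")  # placeholder key
try:
    asyncio.run(creds.validate_credentials())
    print("Perplexity credentials accepted")
except CredentialValidationError as e:
    print(f"Validation failed: {e}")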
airtrain/integrations/perplexity/list_models.py
@@ -0,0 +1,112 @@
from typing import Dict, Any, List, Optional
import requests

from airtrain.core.skills import Skill, ProcessingError
from airtrain.integrations.combined.list_models_factory import (
    BaseListModelsSkill,
    GenericListModelsInput,
    GenericListModelsOutput,
)
from airtrain.core.schemas import InputSchema, OutputSchema
from .credentials import PerplexityCredentials
from .models_config import PERPLEXITY_MODELS_CONFIG


class PerplexityListModelsInput(InputSchema):
    """Schema for listing Perplexity AI models"""

    api_models_only: bool = False


class PerplexityListModelsOutput(OutputSchema):
    """Schema for Perplexity AI models listing output"""

    models: List[Dict[str, Any]]
    provider: str = "perplexity"


class PerplexityListModelsSkill(BaseListModelsSkill):
    """Skill for listing Perplexity AI models"""

    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
        """Initialize the skill with optional credentials"""
        super().__init__(provider="perplexity", credentials=credentials)
        self.credentials = credentials

    def get_models(self) -> List[Dict[str, Any]]:
        """Return list of Perplexity AI models."""
        models = []

        # Add models from the configuration
        for model_id, config in PERPLEXITY_MODELS_CONFIG.items():
            models.append(
                {
                    "id": model_id,
                    "display_name": config["name"],
                    "description": config.get("description", ""),
                    "category": config.get("category", "unknown"),
                    "capabilities": {
                        "citations": config.get("citations", False),
                        "search": config.get("search", False),
                        "context_window": config.get("context_window", 8192),
                        "max_completion_tokens": config.get(
                            "max_completion_tokens", 4096
                        ),
                    },
                }
            )

        return models

    def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
        """Process the input and return a list of models."""
        try:
            models = self.get_models()
            return GenericListModelsOutput(models=models, provider="perplexity")
        except Exception as e:
            raise ProcessingError(f"Failed to list Perplexity AI models: {str(e)}")


# Standalone version directly using the Perplexity-specific schemas
class StandalonePerplexityListModelsSkill(
    Skill[PerplexityListModelsInput, PerplexityListModelsOutput]
):
    """Standalone skill for listing Perplexity AI models"""

    input_schema = PerplexityListModelsInput
    output_schema = PerplexityListModelsOutput

    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
        """Initialize the skill with optional credentials"""
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: PerplexityListModelsInput
    ) -> PerplexityListModelsOutput:
        """Process the input and return a list of models."""
        try:
            models = []

            # Add models from the configuration
            for model_id, config in PERPLEXITY_MODELS_CONFIG.items():
                models.append(
                    {
                        "id": model_id,
                        "display_name": config["name"],
                        "description": config.get("description", ""),
                        "category": config.get("category", "unknown"),
                        "capabilities": {
                            "citations": config.get("citations", False),
                            "search": config.get("search", False),
                            "context_window": config.get("context_window", 8192),
                            "max_completion_tokens": config.get(
                                "max_completion_tokens", 4096
                            ),
                        },
                    }
                )

            return PerplexityListModelsOutput(models=models, provider="perplexity")
        except Exception as e:
            raise ProcessingError(f"Failed to list Perplexity AI models: {str(e)}")
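A short sketch of the standalone skill, which reads only the static configuration and therefore needs no live credentials:

from airtrain.integrations.perplexity.list_models import (
    StandalonePerplexityListModelsSkill,
    PerplexityListModelsInput,
)

skill = StandalonePerplexityListModelsSkill()
output = skill.process(PerplexityListModelsInput())
for model in output.models:
    print(model["id"], "-", model["display_name"])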
airtrain/integrations/perplexity/models_config.py
@@ -0,0 +1,128 @@
"""Configuration of Perplexity AI model capabilities."""

from typing import Dict, Any


# Model configuration with capabilities for each Perplexity AI model
PERPLEXITY_MODELS_CONFIG = {
    # Search Models
    "sonar-pro": {
        "name": "Sonar Pro",
        "description": "Advanced search offering with grounding, supporting complex queries and follow-ups.",
        "category": "search",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    "sonar": {
        "name": "Sonar",
        "description": "Lightweight, cost-effective search model with grounding.",
        "category": "search",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    # Research Models
    "sonar-deep-research": {
        "name": "Sonar Deep Research",
        "description": "Expert-level research model conducting exhaustive searches and generating comprehensive reports.",
        "category": "research",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    # Reasoning Models
    "sonar-reasoning-pro": {
        "name": "Sonar Reasoning Pro",
        "description": "Premier reasoning offering powered by DeepSeek R1 with Chain of Thought (CoT).",
        "category": "reasoning",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
        "chain_of_thought": True,
    },
    "sonar-reasoning": {
        "name": "Sonar Reasoning",
        "description": "Fast, real-time reasoning model designed for quick problem-solving with search.",
        "category": "reasoning",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
        "chain_of_thought": True,
    },
    # Offline Models
    "r1-1776": {
        "name": "R1-1776",
        "description": "A version of DeepSeek R1 post-trained for uncensored, unbiased, and factual information.",
        "category": "offline",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": False,
        "search": False,
    },
}


def get_model_config(model_id: str) -> Dict[str, Any]:
    """
    Get the configuration for a specific Perplexity AI model.

    Args:
        model_id: The model ID to get configuration for

    Returns:
        Dict with model configuration

    Raises:
        ValueError: If model_id is not found in configuration
    """
    if model_id in PERPLEXITY_MODELS_CONFIG:
        return PERPLEXITY_MODELS_CONFIG[model_id]

    # Try to find a match with different format or case
    normalized_id = model_id.lower().replace("-", "").replace("_", "")
    for config_id, config in PERPLEXITY_MODELS_CONFIG.items():
        if normalized_id == config_id.lower().replace("-", "").replace("_", ""):
            return config

    # If model not found, raise an error
    raise ValueError(
        f"Model '{model_id}' not found in Perplexity AI models configuration"
    )


def get_default_model() -> str:
    """Get the default model ID for Perplexity AI."""
    return "sonar-pro"


def supports_citations(model_id: str) -> bool:
    """Check if a model supports citations."""
    return get_model_config(model_id).get("citations", False)


def supports_search(model_id: str) -> bool:
    """Check if a model uses search capabilities."""
    return get_model_config(model_id).get("search", False)


def get_models_by_category(category: str) -> Dict[str, Dict[str, Any]]:
    """
    Get all models belonging to a specific category.

    Args:
        category: Category to filter by ('search', 'research', 'reasoning', 'offline')

    Returns:
        Dict of model IDs and their configurations that match the category
    """
    return {
        model_id: config
        for model_id, config in PERPLEXITY_MODELS_CONFIG.items()
        if config.get("category") == category
    }
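These helpers compose naturally; an illustrative (hypothetical) selection routine that prefers the default model and falls back to any search-category model when citations are required:

from airtrain.integrations.perplexity.models_config import (
    get_default_model,
    get_models_by_category,
    supports_citations,
)

model_id = get_default_model()
if not supports_citations(model_id):
    # Fall back to the first search-capable model in the config.
    model_id = next(iter(get_models_by_category("search")))
print(model_id)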