airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,342 @@
1
+ from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
2
+ from pydantic import Field, BaseModel
3
+ from openai import OpenAI, AsyncOpenAI
4
+ from openai.types.chat import ChatCompletionChunk
5
+ import numpy as np
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import OpenAICredentials
10
+
11
+
12
class OpenAIInput(InputSchema):
    """Schema for OpenAI chat input.

    Attributes:
        user_input: The new user message, appended as the last chat turn.
        system_prompt: System message placed first in the conversation.
        conversation_history: Earlier turns, inserted between the system prompt
            and ``user_input``.
        model: OpenAI model name.
        temperature: Sampling temperature; the OpenAI API accepts 0 to 2.
        max_tokens: Upper bound on generated tokens; ``None`` leaves the limit
            to the API.
        stream: Whether to stream the response token by token.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="gpt-4o",
        description="OpenAI model to use",
    )
    # The Chat Completions API accepts temperatures in [0, 2]; an upper bound
    # of 1 would reject valid settings.
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=2
    )
    # A fixed default such as 131072 exceeds the completion-token limit of
    # gpt-4o (the default model), and the API rejects such requests. None
    # lets the API apply the model's own limit.
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum tokens in response"
    )
    stream: bool = Field(
        default=False,
        description="Whether to stream the response token by token",
    )
38
+
39
+
40
class OpenAIOutput(OutputSchema):
    """Result of an OpenAI chat call.

    Attributes:
        response: The assistant's reply text.
        used_model: Name of the model reported for the request.
        usage: Token counts; the chat skill fills ``total_tokens``,
            ``prompt_tokens`` and ``completion_tokens``.
    """

    # ``Field(...)`` marks each field as required, exactly like a bare annotation.
    response: str = Field(...)
    used_model: str = Field(...)
    usage: Dict[str, int] = Field(...)
46
+
47
+
48
class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
    """Skill for interacting with OpenAI chat models.

    Provides blocking (``process``), streaming (``process_stream``) and asyncio
    (``process_async``, ``process_stream_async``) entry points over the same
    ``OpenAIInput`` schema.
    """

    input_schema = OpenAIInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: OpenAI credentials. When omitted they are loaded with
                ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        self.async_client = AsyncOpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
        """Build the message list: system prompt, prior history, then the new user turn."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    @staticmethod
    def _usage_to_dict(usage) -> Dict[str, int]:
        """Convert an OpenAI usage object into the ``OpenAIOutput.usage`` dict.

        A missing usage block (``None``) is reported as zero counts instead of
        failing with AttributeError.
        """
        if usage is None:
            return {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
        return {
            "total_tokens": usage.total_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
        }

    def _collect_stream(self, input_data: OpenAIInput):
        """Stream a completion and return ``(joined_text, usage)``.

        ``stream_options={"include_usage": True}`` asks the API to send one
        final chunk, with an empty ``choices`` list, that carries token usage
        for the whole request. ``usage`` stays ``None`` if no such chunk arrives.
        """
        stream = self.client.chat.completions.create(
            model=input_data.model,
            messages=self._build_messages(input_data),
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
            stream_options={"include_usage": True},
        )

        parts: List[str] = []
        usage = None
        for chunk in stream:
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage = chunk_usage
            if chunk.choices and chunk.choices[0].delta.content is not None:
                parts.append(chunk.choices[0].delta.content)
        return "".join(parts), usage

    def process_stream(self, input_data: OpenAIInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        Yields:
            Each non-None content delta, in arrival order.

        Raises:
            ProcessingError: If the request or the stream iteration fails.
        """
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # Skip chunks with an empty ``choices`` list instead of failing
                # with IndexError.
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            raise ProcessingError(f"OpenAI streaming failed: {str(e)}") from e

    def process(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is true, the response is streamed and
        concatenated (with usage requested from the stream). Otherwise a
        single non-streaming request is made.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                # The streaming path never creates a ``completion`` object, so
                # text and usage are both collected from the stream itself.
                response, usage = self._collect_stream(input_data)
            else:
                messages = self._build_messages(input_data)
                completion = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=messages,
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                    stream=False,
                )
                # ``content`` can be None (e.g. tool-call or refusal replies);
                # OpenAIOutput.response requires a str.
                response = completion.choices[0].message.content or ""
                usage = completion.usage

            return OpenAIOutput(
                response=response,
                used_model=input_data.model,
                usage=self._usage_to_dict(usage),
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI chat failed: {str(e)}") from e

    async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Async version of ``process`` (always non-streaming).

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            messages = self._build_messages(input_data)
            completion = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            return OpenAIOutput(
                response=completion.choices[0].message.content or "",
                used_model=completion.model,
                usage=self._usage_to_dict(completion.usage),
            )
        except Exception as e:
            raise ProcessingError(f"OpenAI async chat failed: {str(e)}") from e

    async def process_stream_async(
        self, input_data: OpenAIInput
    ) -> AsyncGenerator[str, None]:
        """Async version of ``process_stream``.

        Yields:
            Each non-None content delta, in arrival order.

        Raises:
            ProcessingError: If the request or the stream iteration fails.
        """
        try:
            messages = self._build_messages(input_data)
            stream = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            raise ProcessingError(f"OpenAI async streaming failed: {str(e)}") from e
171
+
172
+
173
# Type variable for the Pydantic model class a structured response is parsed into.
ResponseT = TypeVar("ResponseT", bound=BaseModel)


class OpenAIParserInput(InputSchema):
    """Input for a structured-output (parsed) OpenAI request.

    ``response_model`` is the Pydantic model class the reply must conform to.
    All other fields have defaults except ``user_input``.
    """

    user_input: str = Field(...)
    system_prompt: str = Field(
        default="You are a helpful assistant that provides structured data."
    )
    model: str = Field(default="gpt-4o")
    temperature: float = Field(default=0.7)
    max_tokens: Optional[int] = Field(default=None)
    response_model: Type[ResponseT] = Field(...)

    class Config:
        # Presumably needed so the ``Type[ResponseT]`` field validates — confirm.
        arbitrary_types_allowed = True
188
+
189
+
190
class OpenAIParserOutput(OutputSchema):
    """Result of a structured-output OpenAI request.

    Attributes:
        parsed_response: Instance of the requested response model.
        used_model: Model name reported by the API.
        tokens_used: Total tokens consumed by the request.
    """

    # Every field is required; ``Field(...)`` is equivalent to a bare annotation.
    parsed_response: BaseModel = Field(...)
    used_model: str = Field(...)
    tokens_used: int = Field(...)
196
+
197
+
198
class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
    """Skill for getting structured (Pydantic-parsed) responses from OpenAI."""

    input_schema = OpenAIParserInput
    output_schema = OpenAIParserOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: OpenAI credentials; loaded with
                ``OpenAICredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def process(self, input_data: OpenAIParserInput) -> OpenAIParserOutput:
        """Request a reply matching ``input_data.response_model`` and parse it.

        Returns:
            The parsed model instance, the model name reported by the API and
            the total token count (0 if the API omits usage).

        Raises:
            ProcessingError: If the request fails or the reply cannot be parsed.
        """
        try:
            # ``parse`` (rather than ``create``) validates the reply into
            # ``response_model``.
            completion = self.client.beta.chat.completions.parse(
                model=input_data.model,
                messages=[
                    {"role": "system", "content": input_data.system_prompt},
                    {"role": "user", "content": input_data.user_input},
                ],
                response_format=input_data.response_model,
                # Forward the sampling settings the input schema declares, as
                # OpenAIChatSkill does, instead of silently dropping them.
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            message = completion.choices[0].message
            if message.parsed is None:
                # Surface the model's refusal text, when present, to aid debugging.
                refusal = getattr(message, "refusal", None)
                detail = f" (model refusal: {refusal})" if refusal else ""
                raise ProcessingError(f"Failed to parse response{detail}")

            usage = completion.usage
            return OpenAIParserOutput(
                parsed_response=message.parsed,
                used_model=completion.model,
                tokens_used=usage.total_tokens if usage is not None else 0,
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI parsing failed: {str(e)}") from e
236
+
237
+
238
class OpenAIEmbeddingsInput(InputSchema):
    """Input for an OpenAI embeddings request.

    ``texts`` may be a single string or a list of strings. ``dimensions`` is
    optional and left unset (``None``) by default.
    """

    texts: Union[str, List[str]] = Field(
        default=...,
        description="Text or list of texts to generate embeddings for",
    )
    model: str = Field(
        description="OpenAI embeddings model to use",
        default="text-embedding-3-large",
    )
    encoding_format: str = Field(
        description="The format of the embeddings: 'float' or 'base64'",
        default="float",
    )
    dimensions: Optional[int] = Field(
        description="Optional number of dimensions for the embeddings",
        default=None,
    )
253
+
254
+
255
class OpenAIEmbeddingsOutput(OutputSchema):
    """Result of an OpenAI embeddings request (one vector per input text)."""

    embeddings: List[List[float]] = Field(
        default=..., description="List of embeddings vectors"
    )
    used_model: str = Field(
        default=..., description="Model used for generating embeddings"
    )
    tokens_used: int = Field(default=..., description="Number of tokens used")
261
+
262
+
263
class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
    """Skill for generating embeddings using OpenAI models (sync and async)."""

    input_schema = OpenAIEmbeddingsInput
    output_schema = OpenAIEmbeddingsOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: OpenAI credentials; loaded with
                ``OpenAICredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        self.async_client = AsyncOpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    @staticmethod
    def _request_kwargs(input_data: OpenAIEmbeddingsInput) -> Dict[str, object]:
        """Build the keyword arguments for ``embeddings.create``.

        A single string is wrapped in a list so the request is always a batch.
        """
        texts = (
            [input_data.texts]
            if isinstance(input_data.texts, str)
            else input_data.texts
        )
        kwargs: Dict[str, object] = {
            "model": input_data.model,
            "input": texts,
            "encoding_format": input_data.encoding_format,
        }
        # Send ``dimensions`` only when requested: OpenAI documents it as
        # supported by text-embedding-3 and later models only.
        if input_data.dimensions is not None:
            kwargs["dimensions"] = input_data.dimensions
        return kwargs

    @staticmethod
    def _to_float_list(embedding) -> List[float]:
        """Return one embedding as a list of floats.

        If the caller explicitly requests ``encoding_format="base64"``, the
        OpenAI SDK returns the embedding as a base64 string of packed float32
        values. Decode it so it fits ``OpenAIEmbeddingsOutput.embeddings``.
        """
        if isinstance(embedding, str):
            import base64

            # NOTE(review): little-endian float32 is assumed here — confirm if
            # big-endian hosts must be supported.
            return np.frombuffer(base64.b64decode(embedding), dtype="<f4").tolist()
        return embedding

    @classmethod
    def _build_output(cls, response) -> OpenAIEmbeddingsOutput:
        """Convert an embeddings API response into ``OpenAIEmbeddingsOutput``."""
        return OpenAIEmbeddingsOutput(
            embeddings=[cls._to_float_list(data.embedding) for data in response.data],
            used_model=response.model,
            tokens_used=response.usage.total_tokens,
        )

    def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
        """Generate embeddings for the input text(s).

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            response = self.client.embeddings.create(
                **self._request_kwargs(input_data)
            )
            return self._build_output(response)
        except Exception as e:
            raise ProcessingError(
                f"OpenAI embeddings generation failed: {str(e)}"
            ) from e

    async def process_async(
        self, input_data: OpenAIEmbeddingsInput
    ) -> OpenAIEmbeddingsOutput:
        """Async version of ``process``.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            response = await self.async_client.embeddings.create(
                **self._request_kwargs(input_data)
            )
            return self._build_output(response)
        except Exception as e:
            raise ProcessingError(
                f"OpenAI async embeddings generation failed: {str(e)}"
            ) from e
@@ -0,0 +1,49 @@
1
+ """Perplexity AI integration module"""
2
+
3
+ from .credentials import PerplexityCredentials
4
+ from .skills import (
5
+ PerplexityInput,
6
+ PerplexityOutput,
7
+ PerplexityChatSkill,
8
+ PerplexityCitation,
9
+ PerplexityStreamingChatSkill,
10
+ PerplexityStreamOutput,
11
+ )
12
+ from .list_models import (
13
+ PerplexityListModelsSkill,
14
+ StandalonePerplexityListModelsSkill,
15
+ PerplexityListModelsInput,
16
+ PerplexityListModelsOutput,
17
+ )
18
+ from .models_config import (
19
+ get_model_config,
20
+ get_default_model,
21
+ supports_citations,
22
+ supports_search,
23
+ get_models_by_category,
24
+ PERPLEXITY_MODELS_CONFIG,
25
+ )
26
+
27
# Public API of the Perplexity integration package: the names imported above
# from the credentials, skills, list_models and models_config submodules.
# Kept as a single list literal so static tools can read the re-exports.
__all__ = [
    # Credentials
    "PerplexityCredentials",
    # Skills
    "PerplexityInput",
    "PerplexityOutput",
    "PerplexityChatSkill",
    "PerplexityCitation",
    "PerplexityStreamingChatSkill",
    "PerplexityStreamOutput",
    # List Models
    "PerplexityListModelsSkill",
    "StandalonePerplexityListModelsSkill",
    "PerplexityListModelsInput",
    "PerplexityListModelsOutput",
    # Model Config
    "get_model_config",
    "get_default_model",
    "supports_citations",
    "supports_search",
    "get_models_by_category",
    "PERPLEXITY_MODELS_CONFIG",
]
@@ -0,0 +1,43 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import requests
4
+
5
+
6
class PerplexityCredentials(BaseCredentials):
    """Perplexity AI API credentials."""

    perplexity_api_key: SecretStr = Field(..., description="Perplexity AI API key")

    # Field names that must be provided; presumably consumed by BaseCredentials
    # (not visible here) — confirm.
    _required_credentials = {"perplexity_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate the API key by sending a minimal chat completion request.

        The blocking ``requests`` call runs in a worker thread so it does not
        stall the event loop, and it is bounded by a timeout.

        Returns:
            True if the API answers with HTTP 200.

        Raises:
            CredentialValidationError: If the request fails or the API answers
                with any other status code.
        """
        import asyncio

        try:
            headers = {
                "Authorization": f"Bearer {self.perplexity_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            }

            # Small API call to check if credentials are valid
            data = {
                "model": "sonar-pro",
                "messages": [{"role": "user", "content": "Test"}],
                "max_tokens": 1,
            }

            response = await asyncio.to_thread(
                requests.post,
                "https://api.perplexity.ai/chat/completions",
                headers=headers,
                json=data,
                timeout=30,
            )
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Perplexity AI credentials: {str(e)}"
            ) from e

        # Checked outside the ``try`` so this error is not caught and
        # re-wrapped (which would duplicate the message prefix).
        if response.status_code != 200:
            raise CredentialValidationError(
                f"Invalid Perplexity AI credentials: {response.status_code} - {response.text}"
            )
        return True
@@ -0,0 +1,112 @@
1
+ from typing import Dict, Any, List, Optional
2
+ import requests
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.integrations.combined.list_models_factory import (
6
+ BaseListModelsSkill,
7
+ GenericListModelsInput,
8
+ GenericListModelsOutput,
9
+ )
10
+ from airtrain.core.schemas import InputSchema, OutputSchema
11
+ from .credentials import PerplexityCredentials
12
+ from .models_config import PERPLEXITY_MODELS_CONFIG
13
+
14
+
15
class PerplexityListModelsInput(InputSchema):
    """Input for listing Perplexity AI models.

    Attributes:
        api_models_only: Flag accepted by the schema; defaults to False.
    """

    api_models_only: bool = Field(default=False)
19
+
20
+
21
class PerplexityListModelsOutput(OutputSchema):
    """Output of the Perplexity AI model listing.

    Attributes:
        models: One descriptive dict per model.
        provider: Provider name; always ``"perplexity"`` unless overridden.
    """

    models: List[Dict[str, Any]] = Field(...)
    provider: str = Field(default="perplexity")
26
+
27
+
28
class PerplexityListModelsSkill(BaseListModelsSkill):
    """List the Perplexity AI models described in ``PERPLEXITY_MODELS_CONFIG``.

    Uses the generic list-models input/output schemas shared across providers.
    """

    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
        """Create the skill.

        Args:
            credentials: Optional Perplexity credentials, stored on the
                instance; the listing methods below do not use them.
        """
        super().__init__(provider="perplexity", credentials=credentials)
        self.credentials = credentials

    def get_models(self) -> List[Dict[str, Any]]:
        """Describe every configured Perplexity AI model as a plain dict."""

        def describe(model_id: str, spec: Dict[str, Any]) -> Dict[str, Any]:
            # Fallback values mirror the defaults used for missing config keys.
            capabilities = {
                "citations": spec.get("citations", False),
                "search": spec.get("search", False),
                "context_window": spec.get("context_window", 8192),
                "max_completion_tokens": spec.get("max_completion_tokens", 4096),
            }
            return {
                "id": model_id,
                "display_name": spec["name"],
                "description": spec.get("description", ""),
                "category": spec.get("category", "unknown"),
                "capabilities": capabilities,
            }

        return [describe(mid, spec) for mid, spec in PERPLEXITY_MODELS_CONFIG.items()]

    def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
        """Return the configured models wrapped in the generic output schema.

        Raises:
            ProcessingError: If building the model list fails.
        """
        try:
            catalogue = self.get_models()
            return GenericListModelsOutput(models=catalogue, provider="perplexity")
        except Exception as exc:
            raise ProcessingError(f"Failed to list Perplexity AI models: {str(exc)}")
68
+
69
+
70
# Standalone version directly using the Perplexity-specific schemas
class StandalonePerplexityListModelsSkill(
    Skill[PerplexityListModelsInput, PerplexityListModelsOutput]
):
    """List Perplexity AI models using the Perplexity-specific schemas.

    The model catalogue comes entirely from ``PERPLEXITY_MODELS_CONFIG``; no
    API request is made.
    """

    input_schema = PerplexityListModelsInput
    output_schema = PerplexityListModelsOutput

    def __init__(self, credentials: Optional[PerplexityCredentials] = None):
        """Create the skill; ``credentials`` are stored but not used for listing."""
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: PerplexityListModelsInput
    ) -> PerplexityListModelsOutput:
        """Return one descriptive dict per configured model.

        ``input_data`` is accepted for interface compatibility; its fields are
        not consulted.

        Raises:
            ProcessingError: If building the model list fails.
        """
        try:
            listing = [
                {
                    "id": model_id,
                    "display_name": spec["name"],
                    "description": spec.get("description", ""),
                    "category": spec.get("category", "unknown"),
                    "capabilities": {
                        "citations": spec.get("citations", False),
                        "search": spec.get("search", False),
                        "context_window": spec.get("context_window", 8192),
                        "max_completion_tokens": spec.get(
                            "max_completion_tokens", 4096
                        ),
                    },
                }
                for model_id, spec in PERPLEXITY_MODELS_CONFIG.items()
            ]
            return PerplexityListModelsOutput(models=listing, provider="perplexity")
        except Exception as exc:
            raise ProcessingError(f"Failed to list Perplexity AI models: {str(exc)}")
@@ -0,0 +1,128 @@
1
+ """Configuration of Perplexity AI model capabilities."""
2
+
3
+ from typing import Dict, Any
4
+
5
+
6
# Model configuration with capabilities for each Perplexity AI model.
#
# Maps a model ID to its display name, description, category, token limits
# ("context_window", "max_completion_tokens") and capability flags
# ("citations", "search", and "chain_of_thought" for the reasoning models).
# Insertion order is the order in which models are listed.
# NOTE(review): every entry uses context_window=8192 and
# max_completion_tokens=4096; these look like placeholders — confirm against
# Perplexity's current model documentation.
PERPLEXITY_MODELS_CONFIG = {
    # Search Models
    "sonar-pro": {
        "name": "Sonar Pro",
        "description": "Advanced search offering with grounding, supporting complex queries and follow-ups.",
        "category": "search",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    "sonar": {
        "name": "Sonar",
        "description": "Lightweight, cost-effective search model with grounding.",
        "category": "search",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    # Research Models
    "sonar-deep-research": {
        "name": "Sonar Deep Research",
        "description": "Expert-level research model conducting exhaustive searches and generating comprehensive reports.",
        "category": "research",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
    },
    # Reasoning Models
    "sonar-reasoning-pro": {
        "name": "Sonar Reasoning Pro",
        "description": "Premier reasoning offering powered by DeepSeek R1 with Chain of Thought (CoT).",
        "category": "reasoning",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
        "chain_of_thought": True,
    },
    "sonar-reasoning": {
        "name": "Sonar Reasoning",
        "description": "Fast, real-time reasoning model designed for quick problem-solving with search.",
        "category": "reasoning",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": True,
        "search": True,
        "chain_of_thought": True,
    },
    # Offline Models (no search, no citations)
    "r1-1776": {
        "name": "R1-1776",
        "description": "A version of DeepSeek R1 post-trained for uncensored, unbiased, and factual information.",
        "category": "offline",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "citations": False,
        "search": False,
    },
}
69
+
70
+
71
def get_model_config(model_id: str) -> Dict[str, Any]:
    """
    Look up the configuration for a Perplexity AI model.

    An exact ID match is tried first. If that fails, IDs are compared
    case-insensitively with ``-`` and ``_`` ignored, so ``"Sonar_Pro"``
    resolves to ``"sonar-pro"``.

    Args:
        model_id: The model ID to get configuration for

    Returns:
        The configuration dict stored in ``PERPLEXITY_MODELS_CONFIG``

    Raises:
        ValueError: If model_id is not found in configuration
    """
    exact = PERPLEXITY_MODELS_CONFIG.get(model_id)
    if exact is not None:
        return exact

    def squash(name: str) -> str:
        # Canonical form used for the fuzzy comparison.
        return name.lower().replace("-", "").replace("_", "")

    wanted = squash(model_id)
    fuzzy = next(
        (
            spec
            for known_id, spec in PERPLEXITY_MODELS_CONFIG.items()
            if squash(known_id) == wanted
        ),
        None,
    )
    if fuzzy is not None:
        return fuzzy

    raise ValueError(
        f"Model '{model_id}' not found in Perplexity AI models configuration"
    )
97
+
98
+
99
def get_default_model() -> str:
    """Return the ID of the Perplexity AI model used when none is specified."""
    default_model_id = "sonar-pro"
    return default_model_id
102
+
103
+
104
def supports_citations(model_id: str) -> bool:
    """Return the model's ``citations`` flag (False if absent).

    Raises ValueError, via ``get_model_config``, for unknown models.
    """
    spec = get_model_config(model_id)
    return spec.get("citations", False)
107
+
108
+
109
def supports_search(model_id: str) -> bool:
    """Return the model's ``search`` flag (False if absent).

    Raises ValueError, via ``get_model_config``, for unknown models.
    """
    spec = get_model_config(model_id)
    return spec.get("search", False)
112
+
113
+
114
def get_models_by_category(category: str) -> Dict[str, Dict[str, Any]]:
    """
    Collect the models whose ``category`` equals the given value.

    Args:
        category: Category to filter by ('search', 'research', 'reasoning', 'offline')

    Returns:
        Dict of model IDs to their configurations, in configuration order;
        empty if no model matches.
    """
    selected: Dict[str, Dict[str, Any]] = {}
    for model_id, spec in PERPLEXITY_MODELS_CONFIG.items():
        if spec.get("category") != category:
            continue
        selected[model_id] = spec
    return selected