airtrain 0.1.58__py3-none-any.whl → 0.1.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
airtrain/__init__.py CHANGED
@@ -1,43 +1,65 @@
1
- """Airtrain - A platform for building and deploying AI agents with structured skills"""
1
+ """
2
+ Airtrain: AI Agent Framework
2
3
 
3
- __version__ = "0.1.58"
4
+ This library provides a flexible framework for building AI agents
5
+ that can complete complex tasks using AI models, skills, and tools.
6
+ """
7
+
8
+ __version__ = "0.1.61"
4
9
 
5
10
  import sys
6
11
 
7
12
  # Core imports
8
- from .core.skills import Skill, ProcessingError
9
- from .core.schemas import InputSchema, OutputSchema
10
- from .core.credentials import BaseCredentials
13
+ from .core import Skill, ProcessingError, InputSchema, OutputSchema, BaseCredentials
11
14
 
12
15
  # Integration imports - Credentials
13
- from .integrations.openai.credentials import OpenAICredentials
14
- from .integrations.aws.credentials import AWSCredentials
15
- from .integrations.google.credentials import GoogleCloudCredentials
16
- from .integrations.anthropic.credentials import AnthropicCredentials
17
- from .integrations.groq.credentials import GroqCredentials
18
- from .integrations.together.credentials import TogetherAICredentials
19
- from .integrations.ollama.credentials import OllamaCredentials
20
- from .integrations.sambanova.credentials import SambanovaCredentials
21
- from .integrations.cerebras.credentials import CerebrasCredentials
16
+ from .integrations import (
17
+ # OpenAI
18
+ OpenAICredentials,
19
+ OpenAIChatSkill,
20
+ OpenAICompletionSkill,
21
+ # Anthropic
22
+ AnthropicCredentials,
23
+ AnthropicChatSkill,
24
+ # Together.ai
25
+ TogetherAICredentials,
26
+ TogetherChatSkill,
27
+ # Fireworks
28
+ FireworksCredentials,
29
+ FireworksChatSkill,
30
+ # Google
31
+ GeminiCredentials,
32
+ GeminiChatSkill,
33
+ # Search
34
+ ExaCredentials,
35
+ ExaSearchSkill,
36
+ ExaSearchInputSchema,
37
+ ExaSearchOutputSchema,
38
+ )
22
39
 
23
40
  # Integration imports - Skills
24
- from .integrations.openai.skills import OpenAIChatSkill, OpenAIParserSkill
25
- from .integrations.anthropic.skills import AnthropicChatSkill
26
41
  from .integrations.aws.skills import AWSBedrockSkill
27
42
  from .integrations.google.skills import GoogleChatSkill
28
43
  from .integrations.groq.skills import GroqChatSkill
29
- from .integrations.together.skills import TogetherAIChatSkill
30
44
  from .integrations.ollama.skills import OllamaChatSkill
31
45
  from .integrations.sambanova.skills import SambanovaChatSkill
32
46
  from .integrations.cerebras.skills import CerebrasChatSkill
33
47
 
34
48
  # Tool imports
35
49
  from .tools import (
36
- StatefulTool,
37
- StatelessTool,
38
- register_tool,
39
50
  ToolFactory,
40
- execute_tool_call
51
+ register_tool,
52
+ StatelessTool,
53
+ StatefulTool,
54
+ BaseTool,
55
+ ListDirectoryTool,
56
+ DirectoryTreeTool,
57
+ ApiCallTool,
58
+ ExecuteCommandTool,
59
+ FindFilesTool,
60
+ TerminalNavigationTool,
61
+ SearchTermTool,
62
+ RunPytestTool,
41
63
  )
42
64
 
43
65
  # Agent imports
@@ -48,7 +70,7 @@ from .agents import (
48
70
  BaseMemory,
49
71
  ShortTermMemory,
50
72
  LongTermMemory,
51
- SharedMemory
73
+ SharedMemory,
52
74
  )
53
75
 
54
76
  # Telemetry import - must be imported after version is defined
@@ -63,33 +85,41 @@ __all__ = [
63
85
  "InputSchema",
64
86
  "OutputSchema",
65
87
  "BaseCredentials",
66
- # Credentials
88
+ # OpenAI Integration
67
89
  "OpenAICredentials",
68
- "AWSCredentials",
69
- "GoogleCloudCredentials",
70
- "AnthropicCredentials",
71
- "GroqCredentials",
72
- "TogetherAICredentials",
73
- "OllamaCredentials",
74
- "SambanovaCredentials",
75
- "CerebrasCredentials",
76
- # Skills
77
90
  "OpenAIChatSkill",
78
- "OpenAIParserSkill",
91
+ "OpenAICompletionSkill",
92
+ # Anthropic Integration
93
+ "AnthropicCredentials",
79
94
  "AnthropicChatSkill",
80
- "AWSBedrockSkill",
81
- "GoogleChatSkill",
82
- "GroqChatSkill",
83
- "TogetherAIChatSkill",
84
- "OllamaChatSkill",
85
- "SambanovaChatSkill",
86
- "CerebrasChatSkill",
95
+ # Together Integration
96
+ "TogetherAICredentials",
97
+ "TogetherChatSkill",
98
+ # Fireworks Integration
99
+ "FireworksCredentials",
100
+ "FireworksChatSkill",
101
+ # Google Integration
102
+ "GeminiCredentials",
103
+ "GeminiChatSkill",
104
+ # Search Integration
105
+ "ExaCredentials",
106
+ "ExaSearchSkill",
107
+ "ExaSearchInputSchema",
108
+ "ExaSearchOutputSchema",
87
109
  # Tools
88
- "StatefulTool",
89
- "StatelessTool",
90
- "register_tool",
91
110
  "ToolFactory",
92
- "execute_tool_call",
111
+ "register_tool",
112
+ "StatelessTool",
113
+ "StatefulTool",
114
+ "BaseTool",
115
+ "ListDirectoryTool",
116
+ "DirectoryTreeTool",
117
+ "ApiCallTool",
118
+ "ExecuteCommandTool",
119
+ "FindFilesTool",
120
+ "TerminalNavigationTool",
121
+ "SearchTermTool",
122
+ "RunPytestTool",
93
123
  # Agents
94
124
  "BaseAgent",
95
125
  "AgentFactory",
@@ -111,7 +141,7 @@ try:
111
141
  f"{sys.version_info.major}."
112
142
  f"{sys.version_info.minor}."
113
143
  f"{sys.version_info.micro}"
114
- )
144
+ ),
115
145
  )
116
146
  )
117
147
  except Exception:
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Optional, Set
1
+ from typing import Dict, List, Optional, Set, Union
2
2
  import os
3
3
  import json
4
4
  from pathlib import Path
@@ -53,30 +53,70 @@ class BaseCredentials(BaseModel):
53
53
  return cls(**field_values)
54
54
 
55
55
  @classmethod
56
- def from_file(cls, file_path: Path) -> "BaseCredentials":
57
- """Load credentials from a file (supports .env, .json, .yaml)"""
56
+ def from_file(cls, file_path: str | Path) -> "BaseCredentials":
57
+ """Load credentials from a file (supports .env, .json, .yaml).
58
+
59
+ Args:
60
+ file_path: Path to load credentials from. Can be a string or Path object.
61
+ Supported formats: .env, .json, .yaml/.yml
62
+
63
+ Returns:
64
+ BaseCredentials: Initialized credentials object
65
+
66
+ Raises:
67
+ FileNotFoundError: If the credentials file does not exist
68
+ ValueError: If the file format is not supported
69
+ """
70
+ # Convert to Path object if string
71
+ if isinstance(file_path, str):
72
+ file_path = Path(file_path)
73
+
58
74
  if not file_path.exists():
59
75
  raise FileNotFoundError(f"Credentials file not found: {file_path}")
60
76
 
61
- if file_path.suffix == ".env":
77
+ # Get file extension, default to .env if none provided
78
+ suffix = file_path.suffix
79
+ if not suffix:
80
+ # Try to find a file with the same name but different extension
81
+ for ext in [".env", ".json", ".yaml", ".yml"]:
82
+ potential_path = file_path.with_suffix(ext)
83
+ if potential_path.exists():
84
+ file_path = potential_path
85
+ suffix = ext
86
+ break
87
+ # If no file was found, default to .env
88
+ if not suffix:
89
+ file_path = file_path.with_suffix(".env")
90
+ suffix = ".env"
91
+
92
+ if suffix == ".env":
62
93
  dotenv.load_dotenv(file_path)
63
94
  return cls.from_env()
64
95
 
65
- elif file_path.suffix == ".json":
96
+ elif suffix == ".json":
66
97
  with open(file_path) as f:
67
98
  data = json.load(f)
68
99
  return cls(**data)
69
100
 
70
- elif file_path.suffix in {".yaml", ".yml"}:
101
+ elif suffix in {".yaml", ".yml"}:
71
102
  with open(file_path) as f:
72
103
  data = yaml.safe_load(f)
73
104
  return cls(**data)
74
105
 
75
106
  else:
76
- raise ValueError(f"Unsupported file format: {file_path.suffix}")
107
+ raise ValueError(f"Unsupported file format: {suffix}")
108
+
109
+ def save_to_file(self, file_path: str | Path) -> None:
110
+ """Save credentials to a file.
111
+
112
+ Args:
113
+ file_path: Path to save credentials to. Can be a string or Path object.
114
+ Supported formats: .env, .json, .yaml/.yml
115
+ """
116
+ # Convert to Path object if string
117
+ if isinstance(file_path, str):
118
+ file_path = Path(file_path)
77
119
 
78
- def save_to_file(self, file_path: Path) -> None:
79
- """Save credentials to a file"""
80
120
  data = self.model_dump(exclude={"_loaded"})
81
121
 
82
122
  # Convert SecretStr to plain strings for saving
@@ -84,21 +124,27 @@ class BaseCredentials(BaseModel):
84
124
  if isinstance(value, SecretStr):
85
125
  data[key] = value.get_secret_value()
86
126
 
87
- if file_path.suffix == ".env":
127
+ # Get file extension, default to .env if none provided
128
+ suffix = file_path.suffix
129
+ if not suffix:
130
+ file_path = file_path.with_suffix(".env")
131
+ suffix = ".env"
132
+
133
+ if suffix == ".env":
88
134
  with open(file_path, "w") as f:
89
135
  for key, value in data.items():
90
136
  f.write(f"{key.upper()}={value}\n")
91
137
 
92
- elif file_path.suffix == ".json":
138
+ elif suffix == ".json":
93
139
  with open(file_path, "w") as f:
94
140
  json.dump(data, f, indent=2)
95
141
 
96
- elif file_path.suffix in {".yaml", ".yml"}:
142
+ elif suffix in {".yaml", ".yml"}:
97
143
  with open(file_path, "w") as f:
98
144
  yaml.dump(data, f)
99
145
 
100
146
  else:
101
- raise ValueError(f"Unsupported file format: {file_path.suffix}")
147
+ raise ValueError(f"Unsupported file format: {suffix}")
102
148
 
103
149
  async def validate_credentials(self) -> bool:
104
150
  """Validate that all required credentials are present"""
@@ -10,6 +10,7 @@ from .together.credentials import TogetherAICredentials
10
10
  from .ollama.credentials import OllamaCredentials
11
11
  from .sambanova.credentials import SambanovaCredentials
12
12
  from .cerebras.credentials import CerebrasCredentials
13
+ from .perplexity.credentials import PerplexityCredentials
13
14
 
14
15
  # Skills imports
15
16
  from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
@@ -21,16 +22,18 @@ from .together.skills import TogetherAIChatSkill
21
22
  from .ollama.skills import OllamaChatSkill
22
23
  from .sambanova.skills import SambanovaChatSkill
23
24
  from .cerebras.skills import CerebrasChatSkill
25
+ from .perplexity.skills import PerplexityChatSkill, PerplexityStreamingChatSkill
24
26
 
25
27
  # Model configurations
26
28
  from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
27
29
  from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
30
+ from .perplexity.models_config import PERPLEXITY_MODELS_CONFIG
28
31
 
29
32
  # Combined modules
30
33
  from .combined.list_models_factory import (
31
34
  ListModelsSkillFactory,
32
35
  GenericListModelsInput,
33
- GenericListModelsOutput
36
+ GenericListModelsOutput,
34
37
  )
35
38
 
36
39
  __all__ = [
@@ -44,6 +47,7 @@ __all__ = [
44
47
  "OllamaCredentials",
45
48
  "SambanovaCredentials",
46
49
  "CerebrasCredentials",
50
+ "PerplexityCredentials",
47
51
  # Skills
48
52
  "OpenAIChatSkill",
49
53
  "OpenAIParserSkill",
@@ -55,11 +59,14 @@ __all__ = [
55
59
  "OllamaChatSkill",
56
60
  "SambanovaChatSkill",
57
61
  "CerebrasChatSkill",
62
+ "PerplexityChatSkill",
63
+ "PerplexityStreamingChatSkill",
58
64
  # Model configurations
59
65
  "OPENAI_MODELS",
60
66
  "OpenAIModelConfig",
61
67
  "ANTHROPIC_MODELS",
62
68
  "AnthropicModelConfig",
69
+ "PERPLEXITY_MODELS_CONFIG",
63
70
  # Combined modules
64
71
  "ListModelsSkillFactory",
65
72
  "GenericListModelsInput",
@@ -15,16 +15,21 @@ from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
15
15
  from airtrain.integrations.groq.credentials import GroqCredentials
16
16
  from airtrain.integrations.cerebras.credentials import CerebrasCredentials
17
17
  from airtrain.integrations.sambanova.credentials import SambanovaCredentials
18
+ from airtrain.integrations.perplexity.credentials import PerplexityCredentials
19
+
20
+ # Import Perplexity list models
21
+ from airtrain.integrations.perplexity.list_models import PerplexityListModelsSkill
22
+
18
23
 
19
24
  # Generic list models input schema
20
25
  class GenericListModelsInput(InputSchema):
21
26
  """Generic schema for listing models from any provider"""
22
-
27
+
23
28
  api_models_only: bool = Field(
24
29
  default=False,
25
30
  description=(
26
31
  "If True, fetch models from the API only. If False, use local config."
27
- )
32
+ ),
28
33
  )
29
34
 
30
35
  class Config:
@@ -32,38 +37,33 @@ class GenericListModelsInput(InputSchema):
32
37
  extra = "allow"
33
38
 
34
39
 
35
-
36
40
  # Generic list models output schema
37
41
  class GenericListModelsOutput(OutputSchema):
38
42
  """Generic schema for list models output from any provider"""
39
-
43
+
40
44
  models: List[Dict[str, Any]] = Field(
41
- default_factory=list,
42
- description="List of models"
43
- )
44
- provider: str = Field(
45
- ...,
46
- description="Provider name"
45
+ default_factory=list, description="List of models"
47
46
  )
47
+ provider: str = Field(..., description="Provider name")
48
48
 
49
49
 
50
50
  # Base class for stub implementations
51
51
  class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
52
52
  """Base skill for listing models"""
53
-
53
+
54
54
  input_schema = GenericListModelsInput
55
55
  output_schema = GenericListModelsOutput
56
-
56
+
57
57
  def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
58
58
  """Initialize the skill with provider name and optional credentials"""
59
59
  super().__init__()
60
60
  self.provider = provider
61
61
  self.credentials = credentials
62
-
62
+
63
63
  def get_models(self) -> List[Dict[str, Any]]:
64
64
  """Return list of models. To be implemented by subclasses."""
65
65
  raise NotImplementedError("Subclasses must implement get_models()")
66
-
66
+
67
67
  def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
68
68
  """Process the input and return a list of models."""
69
69
  try:
@@ -76,24 +76,42 @@ class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]
76
76
  # Groq implementation
77
77
  class GroqListModelsSkill(BaseListModelsSkill):
78
78
  """Skill for listing Groq models"""
79
-
79
+
80
80
  def __init__(self, credentials: Optional[GroqCredentials] = None):
81
81
  """Initialize the skill with optional credentials"""
82
82
  super().__init__(provider="groq", credentials=credentials)
83
-
83
+
84
84
  def get_models(self) -> List[Dict[str, Any]]:
85
85
  """Return list of Groq models."""
86
86
  # Default Groq models from trmx_agent config
87
87
  models = [
88
- {"id": "llama-3.3-70b-versatile", "display_name": "Llama 3.3 70B Versatile (Tool Use)"},
89
- {"id": "llama-3.1-8b-instant", "display_name": "Llama 3.1 8B Instant (Tool Use)"},
90
- {"id": "mixtral-8x7b-32768", "display_name": "Mixtral 8x7B (32K) (Tool Use)"},
88
+ {
89
+ "id": "llama-3.3-70b-versatile",
90
+ "display_name": "Llama 3.3 70B Versatile (Tool Use)",
91
+ },
92
+ {
93
+ "id": "llama-3.1-8b-instant",
94
+ "display_name": "Llama 3.1 8B Instant (Tool Use)",
95
+ },
96
+ {
97
+ "id": "mixtral-8x7b-32768",
98
+ "display_name": "Mixtral 8x7B (32K) (Tool Use)",
99
+ },
91
100
  {"id": "gemma2-9b-it", "display_name": "Gemma 2 9B IT (Tool Use)"},
92
101
  {"id": "qwen-qwq-32b", "display_name": "Qwen QWQ 32B (Tool Use)"},
93
- {"id": "qwen-2.5-coder-32b", "display_name": "Qwen 2.5 Coder 32B (Tool Use)"},
102
+ {
103
+ "id": "qwen-2.5-coder-32b",
104
+ "display_name": "Qwen 2.5 Coder 32B (Tool Use)",
105
+ },
94
106
  {"id": "qwen-2.5-32b", "display_name": "Qwen 2.5 32B (Tool Use)"},
95
- {"id": "deepseek-r1-distill-qwen-32b", "display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)"},
96
- {"id": "deepseek-r1-distill-llama-70b", "display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)"},
107
+ {
108
+ "id": "deepseek-r1-distill-qwen-32b",
109
+ "display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)",
110
+ },
111
+ {
112
+ "id": "deepseek-r1-distill-llama-70b",
113
+ "display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)",
114
+ },
97
115
  ]
98
116
  return models
99
117
 
@@ -101,18 +119,27 @@ class GroqListModelsSkill(BaseListModelsSkill):
101
119
  # Cerebras implementation
102
120
  class CerebrasListModelsSkill(BaseListModelsSkill):
103
121
  """Skill for listing Cerebras models"""
104
-
122
+
105
123
  def __init__(self, credentials: Optional[CerebrasCredentials] = None):
106
124
  """Initialize the skill with optional credentials"""
107
125
  super().__init__(provider="cerebras", credentials=credentials)
108
-
126
+
109
127
  def get_models(self) -> List[Dict[str, Any]]:
110
128
  """Return list of Cerebras models."""
111
129
  # Default Cerebras models from trmx_agent config
112
130
  models = [
113
- {"id": "cerebras/Cerebras-GPT-13B-v0.1", "display_name": "Cerebras GPT 13B v0.1"},
114
- {"id": "cerebras/Cerebras-GPT-111M-v0.9", "display_name": "Cerebras GPT 111M v0.9"},
115
- {"id": "cerebras/Cerebras-GPT-590M-v0.7", "display_name": "Cerebras GPT 590M v0.7"}
131
+ {
132
+ "id": "cerebras/Cerebras-GPT-13B-v0.1",
133
+ "display_name": "Cerebras GPT 13B v0.1",
134
+ },
135
+ {
136
+ "id": "cerebras/Cerebras-GPT-111M-v0.9",
137
+ "display_name": "Cerebras GPT 111M v0.9",
138
+ },
139
+ {
140
+ "id": "cerebras/Cerebras-GPT-590M-v0.7",
141
+ "display_name": "Cerebras GPT 590M v0.7",
142
+ },
116
143
  ]
117
144
  return models
118
145
 
@@ -120,17 +147,17 @@ class CerebrasListModelsSkill(BaseListModelsSkill):
120
147
  # Sambanova implementation
121
148
  class SambanovaListModelsSkill(BaseListModelsSkill):
122
149
  """Skill for listing Sambanova models"""
123
-
150
+
124
151
  def __init__(self, credentials: Optional[SambanovaCredentials] = None):
125
152
  """Initialize the skill with optional credentials"""
126
153
  super().__init__(provider="sambanova", credentials=credentials)
127
-
154
+
128
155
  def get_models(self) -> List[Dict[str, Any]]:
129
156
  """Return list of Sambanova models."""
130
157
  # Limited Sambanova model information
131
158
  models = [
132
159
  {"id": "sambanova/samba-1", "display_name": "Samba-1"},
133
- {"id": "sambanova/samba-2", "display_name": "Samba-2"}
160
+ {"id": "sambanova/samba-2", "display_name": "Samba-2"},
134
161
  ]
135
162
  return models
136
163
 
@@ -138,7 +165,7 @@ class SambanovaListModelsSkill(BaseListModelsSkill):
138
165
  # Factory class
139
166
  class ListModelsSkillFactory:
140
167
  """Factory for creating list models skills for different providers"""
141
-
168
+
142
169
  # Map provider names to their corresponding list models skills
143
170
  _PROVIDER_MAP = {
144
171
  "openai": OpenAIListModelsSkill,
@@ -147,36 +174,37 @@ class ListModelsSkillFactory:
147
174
  "fireworks": FireworksListModelsSkill,
148
175
  "groq": GroqListModelsSkill,
149
176
  "cerebras": CerebrasListModelsSkill,
150
- "sambanova": SambanovaListModelsSkill
177
+ "sambanova": SambanovaListModelsSkill,
178
+ "perplexity": PerplexityListModelsSkill,
151
179
  }
152
-
180
+
153
181
  @classmethod
154
182
  def get_skill(cls, provider: str, credentials=None):
155
183
  """Return a list models skill for the specified provider
156
-
184
+
157
185
  Args:
158
186
  provider (str): The provider name (case-insensitive)
159
187
  credentials: Optional credentials for the provider
160
-
188
+
161
189
  Returns:
162
190
  A ListModelsSkill instance for the specified provider
163
-
191
+
164
192
  Raises:
165
193
  ValueError: If the provider is not supported
166
194
  """
167
195
  provider = provider.lower()
168
-
196
+
169
197
  if provider not in cls._PROVIDER_MAP:
170
198
  supported = ", ".join(cls.get_supported_providers())
171
199
  raise ValueError(
172
200
  f"Unsupported provider: {provider}. "
173
201
  f"Supported providers are: {supported}"
174
202
  )
175
-
203
+
176
204
  skill_class = cls._PROVIDER_MAP[provider]
177
205
  return skill_class(credentials=credentials)
178
-
206
+
179
207
  @classmethod
180
208
  def get_supported_providers(cls):
181
209
  """Return a list of supported provider names"""
182
- return list(cls._PROVIDER_MAP.keys())
210
+ return list(cls._PROVIDER_MAP.keys())
@@ -0,0 +1,49 @@
1
+ """Perplexity AI integration module"""
2
+
3
+ from .credentials import PerplexityCredentials
4
+ from .skills import (
5
+ PerplexityInput,
6
+ PerplexityOutput,
7
+ PerplexityChatSkill,
8
+ PerplexityCitation,
9
+ PerplexityStreamingChatSkill,
10
+ PerplexityStreamOutput,
11
+ )
12
+ from .list_models import (
13
+ PerplexityListModelsSkill,
14
+ StandalonePerplexityListModelsSkill,
15
+ PerplexityListModelsInput,
16
+ PerplexityListModelsOutput,
17
+ )
18
+ from .models_config import (
19
+ get_model_config,
20
+ get_default_model,
21
+ supports_citations,
22
+ supports_search,
23
+ get_models_by_category,
24
+ PERPLEXITY_MODELS_CONFIG,
25
+ )
26
+
27
+ __all__ = [
28
+ # Credentials
29
+ "PerplexityCredentials",
30
+ # Skills
31
+ "PerplexityInput",
32
+ "PerplexityOutput",
33
+ "PerplexityChatSkill",
34
+ "PerplexityCitation",
35
+ "PerplexityStreamingChatSkill",
36
+ "PerplexityStreamOutput",
37
+ # List Models
38
+ "PerplexityListModelsSkill",
39
+ "StandalonePerplexityListModelsSkill",
40
+ "PerplexityListModelsInput",
41
+ "PerplexityListModelsOutput",
42
+ # Model Config
43
+ "get_model_config",
44
+ "get_default_model",
45
+ "supports_citations",
46
+ "supports_search",
47
+ "get_models_by_category",
48
+ "PERPLEXITY_MODELS_CONFIG",
49
+ ]
@@ -0,0 +1,43 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import requests
4
+
5
+
6
+ class PerplexityCredentials(BaseCredentials):
7
+ """Perplexity AI API credentials"""
8
+
9
+ perplexity_api_key: SecretStr = Field(..., description="Perplexity AI API key")
10
+
11
+ _required_credentials = {"perplexity_api_key"}
12
+
13
+ async def validate_credentials(self) -> bool:
14
+ """Validate Perplexity AI credentials by making a test API call"""
15
+ try:
16
+ headers = {
17
+ "Authorization": f"Bearer {self.perplexity_api_key.get_secret_value()}",
18
+ "Content-Type": "application/json",
19
+ }
20
+
21
+ # Small API call to check if credentials are valid
22
+ data = {
23
+ "model": "sonar-pro",
24
+ "messages": [{"role": "user", "content": "Test"}],
25
+ "max_tokens": 1,
26
+ }
27
+
28
+ # Make a synchronous request for validation
29
+ response = requests.post(
30
+ "https://api.perplexity.ai/chat/completions", headers=headers, json=data
31
+ )
32
+
33
+ if response.status_code == 200:
34
+ return True
35
+ else:
36
+ raise CredentialValidationError(
37
+ f"Invalid Perplexity AI credentials: {response.status_code} - {response.text}"
38
+ )
39
+
40
+ except Exception as e:
41
+ raise CredentialValidationError(
42
+ f"Invalid Perplexity AI credentials: {str(e)}"
43
+ )