airtrain 0.1.58__py3-none-any.whl → 0.1.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +72 -44
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +59 -13
- airtrain/integrations/__init__.py +21 -2
- airtrain/integrations/combined/list_models_factory.py +80 -41
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +114 -0
- airtrain/tools/__init__.py +9 -5
- airtrain/tools/command.py +248 -61
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/METADATA +1 -1
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/RECORD +27 -15
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/WHEEL +1 -1
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.58.dist-info → airtrain-0.1.62.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -1,43 +1,64 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Airtrain: AI Agent Framework
|
2
3
|
|
3
|
-
|
4
|
+
This library provides a flexible framework for building AI agents
|
5
|
+
that can complete complex tasks using AI models, skills, and tools.
|
6
|
+
"""
|
7
|
+
|
8
|
+
__version__ = "0.1.62"
|
4
9
|
|
5
10
|
import sys
|
6
11
|
|
7
12
|
# Core imports
|
8
|
-
from .core
|
9
|
-
from .core.schemas import InputSchema, OutputSchema
|
10
|
-
from .core.credentials import BaseCredentials
|
13
|
+
from .core import Skill, ProcessingError, InputSchema, OutputSchema, BaseCredentials
|
11
14
|
|
12
15
|
# Integration imports - Credentials
|
13
|
-
from .integrations
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
16
|
+
from .integrations import (
|
17
|
+
# OpenAI
|
18
|
+
OpenAICredentials,
|
19
|
+
OpenAIChatSkill,
|
20
|
+
# Anthropic
|
21
|
+
AnthropicCredentials,
|
22
|
+
AnthropicChatSkill,
|
23
|
+
# Together.ai
|
24
|
+
TogetherAICredentials,
|
25
|
+
TogetherAIChatSkill,
|
26
|
+
# Fireworks
|
27
|
+
FireworksCredentials,
|
28
|
+
FireworksChatSkill,
|
29
|
+
# Google
|
30
|
+
GeminiCredentials,
|
31
|
+
GoogleChatSkill,
|
32
|
+
# Search
|
33
|
+
ExaCredentials,
|
34
|
+
ExaSearchSkill,
|
35
|
+
ExaSearchInputSchema,
|
36
|
+
ExaSearchOutputSchema,
|
37
|
+
)
|
22
38
|
|
23
39
|
# Integration imports - Skills
|
24
|
-
from .integrations.openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
25
|
-
from .integrations.anthropic.skills import AnthropicChatSkill
|
26
40
|
from .integrations.aws.skills import AWSBedrockSkill
|
27
41
|
from .integrations.google.skills import GoogleChatSkill
|
28
42
|
from .integrations.groq.skills import GroqChatSkill
|
29
|
-
from .integrations.together.skills import TogetherAIChatSkill
|
30
43
|
from .integrations.ollama.skills import OllamaChatSkill
|
31
44
|
from .integrations.sambanova.skills import SambanovaChatSkill
|
32
45
|
from .integrations.cerebras.skills import CerebrasChatSkill
|
33
46
|
|
34
47
|
# Tool imports
|
35
48
|
from .tools import (
|
36
|
-
StatefulTool,
|
37
|
-
StatelessTool,
|
38
|
-
register_tool,
|
39
49
|
ToolFactory,
|
40
|
-
|
50
|
+
register_tool,
|
51
|
+
StatelessTool,
|
52
|
+
StatefulTool,
|
53
|
+
BaseTool,
|
54
|
+
ListDirectoryTool,
|
55
|
+
DirectoryTreeTool,
|
56
|
+
ApiCallTool,
|
57
|
+
ExecuteCommandTool,
|
58
|
+
FindFilesTool,
|
59
|
+
TerminalNavigationTool,
|
60
|
+
SearchTermTool,
|
61
|
+
RunPytestTool,
|
41
62
|
)
|
42
63
|
|
43
64
|
# Agent imports
|
@@ -48,7 +69,7 @@ from .agents import (
|
|
48
69
|
BaseMemory,
|
49
70
|
ShortTermMemory,
|
50
71
|
LongTermMemory,
|
51
|
-
SharedMemory
|
72
|
+
SharedMemory,
|
52
73
|
)
|
53
74
|
|
54
75
|
# Telemetry import - must be imported after version is defined
|
@@ -63,33 +84,40 @@ __all__ = [
|
|
63
84
|
"InputSchema",
|
64
85
|
"OutputSchema",
|
65
86
|
"BaseCredentials",
|
66
|
-
#
|
87
|
+
# OpenAI Integration
|
67
88
|
"OpenAICredentials",
|
68
|
-
"AWSCredentials",
|
69
|
-
"GoogleCloudCredentials",
|
70
|
-
"AnthropicCredentials",
|
71
|
-
"GroqCredentials",
|
72
|
-
"TogetherAICredentials",
|
73
|
-
"OllamaCredentials",
|
74
|
-
"SambanovaCredentials",
|
75
|
-
"CerebrasCredentials",
|
76
|
-
# Skills
|
77
89
|
"OpenAIChatSkill",
|
78
|
-
|
90
|
+
# Anthropic Integration
|
91
|
+
"AnthropicCredentials",
|
79
92
|
"AnthropicChatSkill",
|
80
|
-
|
81
|
-
"
|
82
|
-
"GroqChatSkill",
|
93
|
+
# Together Integration
|
94
|
+
"TogetherAICredentials",
|
83
95
|
"TogetherAIChatSkill",
|
84
|
-
|
85
|
-
"
|
86
|
-
"
|
96
|
+
# Fireworks Integration
|
97
|
+
"FireworksCredentials",
|
98
|
+
"FireworksChatSkill",
|
99
|
+
# Google Integration
|
100
|
+
"GeminiCredentials",
|
101
|
+
"GoogleChatSkill",
|
102
|
+
# Search Integration
|
103
|
+
"ExaCredentials",
|
104
|
+
"ExaSearchSkill",
|
105
|
+
"ExaSearchInputSchema",
|
106
|
+
"ExaSearchOutputSchema",
|
87
107
|
# Tools
|
88
|
-
"StatefulTool",
|
89
|
-
"StatelessTool",
|
90
|
-
"register_tool",
|
91
108
|
"ToolFactory",
|
92
|
-
"
|
109
|
+
"register_tool",
|
110
|
+
"StatelessTool",
|
111
|
+
"StatefulTool",
|
112
|
+
"BaseTool",
|
113
|
+
"ListDirectoryTool",
|
114
|
+
"DirectoryTreeTool",
|
115
|
+
"ApiCallTool",
|
116
|
+
"ExecuteCommandTool",
|
117
|
+
"FindFilesTool",
|
118
|
+
"TerminalNavigationTool",
|
119
|
+
"SearchTermTool",
|
120
|
+
"RunPytestTool",
|
93
121
|
# Agents
|
94
122
|
"BaseAgent",
|
95
123
|
"AgentFactory",
|
@@ -111,7 +139,7 @@ try:
|
|
111
139
|
f"{sys.version_info.major}."
|
112
140
|
f"{sys.version_info.minor}."
|
113
141
|
f"{sys.version_info.micro}"
|
114
|
-
)
|
142
|
+
),
|
115
143
|
)
|
116
144
|
)
|
117
145
|
except Exception:
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
airtrain/core/credentials.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Dict, List, Optional, Set
|
1
|
+
from typing import Dict, List, Optional, Set, Union
|
2
2
|
import os
|
3
3
|
import json
|
4
4
|
from pathlib import Path
|
@@ -53,30 +53,70 @@ class BaseCredentials(BaseModel):
|
|
53
53
|
return cls(**field_values)
|
54
54
|
|
55
55
|
@classmethod
|
56
|
-
def from_file(cls, file_path: Path) -> "BaseCredentials":
|
57
|
-
"""Load credentials from a file (supports .env, .json, .yaml)
|
56
|
+
def from_file(cls, file_path: str | Path) -> "BaseCredentials":
|
57
|
+
"""Load credentials from a file (supports .env, .json, .yaml).
|
58
|
+
|
59
|
+
Args:
|
60
|
+
file_path: Path to load credentials from. Can be a string or Path object.
|
61
|
+
Supported formats: .env, .json, .yaml/.yml
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
BaseCredentials: Initialized credentials object
|
65
|
+
|
66
|
+
Raises:
|
67
|
+
FileNotFoundError: If the credentials file does not exist
|
68
|
+
ValueError: If the file format is not supported
|
69
|
+
"""
|
70
|
+
# Convert to Path object if string
|
71
|
+
if isinstance(file_path, str):
|
72
|
+
file_path = Path(file_path)
|
73
|
+
|
58
74
|
if not file_path.exists():
|
59
75
|
raise FileNotFoundError(f"Credentials file not found: {file_path}")
|
60
76
|
|
61
|
-
|
77
|
+
# Get file extension, default to .env if none provided
|
78
|
+
suffix = file_path.suffix
|
79
|
+
if not suffix:
|
80
|
+
# Try to find a file with the same name but different extension
|
81
|
+
for ext in [".env", ".json", ".yaml", ".yml"]:
|
82
|
+
potential_path = file_path.with_suffix(ext)
|
83
|
+
if potential_path.exists():
|
84
|
+
file_path = potential_path
|
85
|
+
suffix = ext
|
86
|
+
break
|
87
|
+
# If no file was found, default to .env
|
88
|
+
if not suffix:
|
89
|
+
file_path = file_path.with_suffix(".env")
|
90
|
+
suffix = ".env"
|
91
|
+
|
92
|
+
if suffix == ".env":
|
62
93
|
dotenv.load_dotenv(file_path)
|
63
94
|
return cls.from_env()
|
64
95
|
|
65
|
-
elif
|
96
|
+
elif suffix == ".json":
|
66
97
|
with open(file_path) as f:
|
67
98
|
data = json.load(f)
|
68
99
|
return cls(**data)
|
69
100
|
|
70
|
-
elif
|
101
|
+
elif suffix in {".yaml", ".yml"}:
|
71
102
|
with open(file_path) as f:
|
72
103
|
data = yaml.safe_load(f)
|
73
104
|
return cls(**data)
|
74
105
|
|
75
106
|
else:
|
76
|
-
raise ValueError(f"Unsupported file format: {
|
107
|
+
raise ValueError(f"Unsupported file format: {suffix}")
|
108
|
+
|
109
|
+
def save_to_file(self, file_path: str | Path) -> None:
|
110
|
+
"""Save credentials to a file.
|
111
|
+
|
112
|
+
Args:
|
113
|
+
file_path: Path to save credentials to. Can be a string or Path object.
|
114
|
+
Supported formats: .env, .json, .yaml/.yml
|
115
|
+
"""
|
116
|
+
# Convert to Path object if string
|
117
|
+
if isinstance(file_path, str):
|
118
|
+
file_path = Path(file_path)
|
77
119
|
|
78
|
-
def save_to_file(self, file_path: Path) -> None:
|
79
|
-
"""Save credentials to a file"""
|
80
120
|
data = self.model_dump(exclude={"_loaded"})
|
81
121
|
|
82
122
|
# Convert SecretStr to plain strings for saving
|
@@ -84,21 +124,27 @@ class BaseCredentials(BaseModel):
|
|
84
124
|
if isinstance(value, SecretStr):
|
85
125
|
data[key] = value.get_secret_value()
|
86
126
|
|
87
|
-
|
127
|
+
# Get file extension, default to .env if none provided
|
128
|
+
suffix = file_path.suffix
|
129
|
+
if not suffix:
|
130
|
+
file_path = file_path.with_suffix(".env")
|
131
|
+
suffix = ".env"
|
132
|
+
|
133
|
+
if suffix == ".env":
|
88
134
|
with open(file_path, "w") as f:
|
89
135
|
for key, value in data.items():
|
90
136
|
f.write(f"{key.upper()}={value}\n")
|
91
137
|
|
92
|
-
elif
|
138
|
+
elif suffix == ".json":
|
93
139
|
with open(file_path, "w") as f:
|
94
140
|
json.dump(data, f, indent=2)
|
95
141
|
|
96
|
-
elif
|
142
|
+
elif suffix in {".yaml", ".yml"}:
|
97
143
|
with open(file_path, "w") as f:
|
98
144
|
yaml.dump(data, f)
|
99
145
|
|
100
146
|
else:
|
101
|
-
raise ValueError(f"Unsupported file format: {
|
147
|
+
raise ValueError(f"Unsupported file format: {suffix}")
|
102
148
|
|
103
149
|
async def validate_credentials(self) -> bool:
|
104
150
|
"""Validate that all required credentials are present"""
|
@@ -3,13 +3,16 @@
|
|
3
3
|
# Credentials imports
|
4
4
|
from .openai.credentials import OpenAICredentials
|
5
5
|
from .aws.credentials import AWSCredentials
|
6
|
-
from .google.credentials import GoogleCloudCredentials
|
6
|
+
from .google.credentials import GoogleCloudCredentials, GeminiCredentials
|
7
7
|
from .anthropic.credentials import AnthropicCredentials
|
8
8
|
from .groq.credentials import GroqCredentials
|
9
9
|
from .together.credentials import TogetherAICredentials
|
10
10
|
from .ollama.credentials import OllamaCredentials
|
11
11
|
from .sambanova.credentials import SambanovaCredentials
|
12
12
|
from .cerebras.credentials import CerebrasCredentials
|
13
|
+
from .perplexity.credentials import PerplexityCredentials
|
14
|
+
from .fireworks.credentials import FireworksCredentials
|
15
|
+
from .search.exa.credentials import ExaCredentials
|
13
16
|
|
14
17
|
# Skills imports
|
15
18
|
from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
@@ -21,16 +24,21 @@ from .together.skills import TogetherAIChatSkill
|
|
21
24
|
from .ollama.skills import OllamaChatSkill
|
22
25
|
from .sambanova.skills import SambanovaChatSkill
|
23
26
|
from .cerebras.skills import CerebrasChatSkill
|
27
|
+
from .perplexity.skills import PerplexityChatSkill, PerplexityStreamingChatSkill
|
28
|
+
from .fireworks.skills import FireworksChatSkill
|
29
|
+
from .search.exa.skills import ExaSearchSkill
|
30
|
+
from .search.exa import ExaSearchInputSchema, ExaSearchOutputSchema
|
24
31
|
|
25
32
|
# Model configurations
|
26
33
|
from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
|
27
34
|
from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
|
35
|
+
from .perplexity.models_config import PERPLEXITY_MODELS_CONFIG
|
28
36
|
|
29
37
|
# Combined modules
|
30
38
|
from .combined.list_models_factory import (
|
31
39
|
ListModelsSkillFactory,
|
32
40
|
GenericListModelsInput,
|
33
|
-
GenericListModelsOutput
|
41
|
+
GenericListModelsOutput,
|
34
42
|
)
|
35
43
|
|
36
44
|
__all__ = [
|
@@ -44,6 +52,10 @@ __all__ = [
|
|
44
52
|
"OllamaCredentials",
|
45
53
|
"SambanovaCredentials",
|
46
54
|
"CerebrasCredentials",
|
55
|
+
"PerplexityCredentials",
|
56
|
+
"FireworksCredentials",
|
57
|
+
"GeminiCredentials",
|
58
|
+
"ExaCredentials",
|
47
59
|
# Skills
|
48
60
|
"OpenAIChatSkill",
|
49
61
|
"OpenAIParserSkill",
|
@@ -55,11 +67,18 @@ __all__ = [
|
|
55
67
|
"OllamaChatSkill",
|
56
68
|
"SambanovaChatSkill",
|
57
69
|
"CerebrasChatSkill",
|
70
|
+
"PerplexityChatSkill",
|
71
|
+
"PerplexityStreamingChatSkill",
|
72
|
+
"FireworksChatSkill",
|
73
|
+
"ExaSearchSkill",
|
74
|
+
"ExaSearchInputSchema",
|
75
|
+
"ExaSearchOutputSchema",
|
58
76
|
# Model configurations
|
59
77
|
"OPENAI_MODELS",
|
60
78
|
"OpenAIModelConfig",
|
61
79
|
"ANTHROPIC_MODELS",
|
62
80
|
"AnthropicModelConfig",
|
81
|
+
"PERPLEXITY_MODELS_CONFIG",
|
63
82
|
# Combined modules
|
64
83
|
"ListModelsSkillFactory",
|
65
84
|
"GenericListModelsInput",
|
@@ -15,16 +15,21 @@ from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
|
|
15
15
|
from airtrain.integrations.groq.credentials import GroqCredentials
|
16
16
|
from airtrain.integrations.cerebras.credentials import CerebrasCredentials
|
17
17
|
from airtrain.integrations.sambanova.credentials import SambanovaCredentials
|
18
|
+
from airtrain.integrations.perplexity.credentials import PerplexityCredentials
|
19
|
+
|
20
|
+
# Remove this import to avoid circular dependency
|
21
|
+
# from airtrain.integrations.perplexity.list_models import PerplexityListModelsSkill
|
22
|
+
|
18
23
|
|
19
24
|
# Generic list models input schema
|
20
25
|
class GenericListModelsInput(InputSchema):
|
21
26
|
"""Generic schema for listing models from any provider"""
|
22
|
-
|
27
|
+
|
23
28
|
api_models_only: bool = Field(
|
24
29
|
default=False,
|
25
30
|
description=(
|
26
31
|
"If True, fetch models from the API only. If False, use local config."
|
27
|
-
)
|
32
|
+
),
|
28
33
|
)
|
29
34
|
|
30
35
|
class Config:
|
@@ -32,38 +37,33 @@ class GenericListModelsInput(InputSchema):
|
|
32
37
|
extra = "allow"
|
33
38
|
|
34
39
|
|
35
|
-
|
36
40
|
# Generic list models output schema
|
37
41
|
class GenericListModelsOutput(OutputSchema):
|
38
42
|
"""Generic schema for list models output from any provider"""
|
39
|
-
|
43
|
+
|
40
44
|
models: List[Dict[str, Any]] = Field(
|
41
|
-
default_factory=list,
|
42
|
-
description="List of models"
|
43
|
-
)
|
44
|
-
provider: str = Field(
|
45
|
-
...,
|
46
|
-
description="Provider name"
|
45
|
+
default_factory=list, description="List of models"
|
47
46
|
)
|
47
|
+
provider: str = Field(..., description="Provider name")
|
48
48
|
|
49
49
|
|
50
50
|
# Base class for stub implementations
|
51
51
|
class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
|
52
52
|
"""Base skill for listing models"""
|
53
|
-
|
53
|
+
|
54
54
|
input_schema = GenericListModelsInput
|
55
55
|
output_schema = GenericListModelsOutput
|
56
|
-
|
56
|
+
|
57
57
|
def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
|
58
58
|
"""Initialize the skill with provider name and optional credentials"""
|
59
59
|
super().__init__()
|
60
60
|
self.provider = provider
|
61
61
|
self.credentials = credentials
|
62
|
-
|
62
|
+
|
63
63
|
def get_models(self) -> List[Dict[str, Any]]:
|
64
64
|
"""Return list of models. To be implemented by subclasses."""
|
65
65
|
raise NotImplementedError("Subclasses must implement get_models()")
|
66
|
-
|
66
|
+
|
67
67
|
def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
|
68
68
|
"""Process the input and return a list of models."""
|
69
69
|
try:
|
@@ -76,24 +76,42 @@ class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]
|
|
76
76
|
# Groq implementation
|
77
77
|
class GroqListModelsSkill(BaseListModelsSkill):
|
78
78
|
"""Skill for listing Groq models"""
|
79
|
-
|
79
|
+
|
80
80
|
def __init__(self, credentials: Optional[GroqCredentials] = None):
|
81
81
|
"""Initialize the skill with optional credentials"""
|
82
82
|
super().__init__(provider="groq", credentials=credentials)
|
83
|
-
|
83
|
+
|
84
84
|
def get_models(self) -> List[Dict[str, Any]]:
|
85
85
|
"""Return list of Groq models."""
|
86
86
|
# Default Groq models from trmx_agent config
|
87
87
|
models = [
|
88
|
-
{
|
89
|
-
|
90
|
-
|
88
|
+
{
|
89
|
+
"id": "llama-3.3-70b-versatile",
|
90
|
+
"display_name": "Llama 3.3 70B Versatile (Tool Use)",
|
91
|
+
},
|
92
|
+
{
|
93
|
+
"id": "llama-3.1-8b-instant",
|
94
|
+
"display_name": "Llama 3.1 8B Instant (Tool Use)",
|
95
|
+
},
|
96
|
+
{
|
97
|
+
"id": "mixtral-8x7b-32768",
|
98
|
+
"display_name": "Mixtral 8x7B (32K) (Tool Use)",
|
99
|
+
},
|
91
100
|
{"id": "gemma2-9b-it", "display_name": "Gemma 2 9B IT (Tool Use)"},
|
92
101
|
{"id": "qwen-qwq-32b", "display_name": "Qwen QWQ 32B (Tool Use)"},
|
93
|
-
{
|
102
|
+
{
|
103
|
+
"id": "qwen-2.5-coder-32b",
|
104
|
+
"display_name": "Qwen 2.5 Coder 32B (Tool Use)",
|
105
|
+
},
|
94
106
|
{"id": "qwen-2.5-32b", "display_name": "Qwen 2.5 32B (Tool Use)"},
|
95
|
-
{
|
96
|
-
|
107
|
+
{
|
108
|
+
"id": "deepseek-r1-distill-qwen-32b",
|
109
|
+
"display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)",
|
110
|
+
},
|
111
|
+
{
|
112
|
+
"id": "deepseek-r1-distill-llama-70b",
|
113
|
+
"display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)",
|
114
|
+
},
|
97
115
|
]
|
98
116
|
return models
|
99
117
|
|
@@ -101,18 +119,27 @@ class GroqListModelsSkill(BaseListModelsSkill):
|
|
101
119
|
# Cerebras implementation
|
102
120
|
class CerebrasListModelsSkill(BaseListModelsSkill):
|
103
121
|
"""Skill for listing Cerebras models"""
|
104
|
-
|
122
|
+
|
105
123
|
def __init__(self, credentials: Optional[CerebrasCredentials] = None):
|
106
124
|
"""Initialize the skill with optional credentials"""
|
107
125
|
super().__init__(provider="cerebras", credentials=credentials)
|
108
|
-
|
126
|
+
|
109
127
|
def get_models(self) -> List[Dict[str, Any]]:
|
110
128
|
"""Return list of Cerebras models."""
|
111
129
|
# Default Cerebras models from trmx_agent config
|
112
130
|
models = [
|
113
|
-
{
|
114
|
-
|
115
|
-
|
131
|
+
{
|
132
|
+
"id": "cerebras/Cerebras-GPT-13B-v0.1",
|
133
|
+
"display_name": "Cerebras GPT 13B v0.1",
|
134
|
+
},
|
135
|
+
{
|
136
|
+
"id": "cerebras/Cerebras-GPT-111M-v0.9",
|
137
|
+
"display_name": "Cerebras GPT 111M v0.9",
|
138
|
+
},
|
139
|
+
{
|
140
|
+
"id": "cerebras/Cerebras-GPT-590M-v0.7",
|
141
|
+
"display_name": "Cerebras GPT 590M v0.7",
|
142
|
+
},
|
116
143
|
]
|
117
144
|
return models
|
118
145
|
|
@@ -120,17 +147,17 @@ class CerebrasListModelsSkill(BaseListModelsSkill):
|
|
120
147
|
# Sambanova implementation
|
121
148
|
class SambanovaListModelsSkill(BaseListModelsSkill):
|
122
149
|
"""Skill for listing Sambanova models"""
|
123
|
-
|
150
|
+
|
124
151
|
def __init__(self, credentials: Optional[SambanovaCredentials] = None):
|
125
152
|
"""Initialize the skill with optional credentials"""
|
126
153
|
super().__init__(provider="sambanova", credentials=credentials)
|
127
|
-
|
154
|
+
|
128
155
|
def get_models(self) -> List[Dict[str, Any]]:
|
129
156
|
"""Return list of Sambanova models."""
|
130
157
|
# Limited Sambanova model information
|
131
158
|
models = [
|
132
159
|
{"id": "sambanova/samba-1", "display_name": "Samba-1"},
|
133
|
-
{"id": "sambanova/samba-2", "display_name": "Samba-2"}
|
160
|
+
{"id": "sambanova/samba-2", "display_name": "Samba-2"},
|
134
161
|
]
|
135
162
|
return models
|
136
163
|
|
@@ -138,7 +165,7 @@ class SambanovaListModelsSkill(BaseListModelsSkill):
|
|
138
165
|
# Factory class
|
139
166
|
class ListModelsSkillFactory:
|
140
167
|
"""Factory for creating list models skills for different providers"""
|
141
|
-
|
168
|
+
|
142
169
|
# Map provider names to their corresponding list models skills
|
143
170
|
_PROVIDER_MAP = {
|
144
171
|
"openai": OpenAIListModelsSkill,
|
@@ -147,36 +174,48 @@ class ListModelsSkillFactory:
|
|
147
174
|
"fireworks": FireworksListModelsSkill,
|
148
175
|
"groq": GroqListModelsSkill,
|
149
176
|
"cerebras": CerebrasListModelsSkill,
|
150
|
-
"sambanova": SambanovaListModelsSkill
|
177
|
+
"sambanova": SambanovaListModelsSkill,
|
178
|
+
# Remove perplexity from this map as we'll handle it separately
|
179
|
+
# "perplexity": PerplexityListModelsSkill,
|
151
180
|
}
|
152
|
-
|
181
|
+
|
153
182
|
@classmethod
|
154
183
|
def get_skill(cls, provider: str, credentials=None):
|
155
184
|
"""Return a list models skill for the specified provider
|
156
|
-
|
185
|
+
|
157
186
|
Args:
|
158
187
|
provider (str): The provider name (case-insensitive)
|
159
188
|
credentials: Optional credentials for the provider
|
160
|
-
|
189
|
+
|
161
190
|
Returns:
|
162
191
|
A ListModelsSkill instance for the specified provider
|
163
|
-
|
192
|
+
|
164
193
|
Raises:
|
165
194
|
ValueError: If the provider is not supported
|
166
195
|
"""
|
167
196
|
provider = provider.lower()
|
168
|
-
|
197
|
+
|
198
|
+
# Special case for perplexity to avoid circular import
|
199
|
+
if provider == "perplexity":
|
200
|
+
# Import here to avoid circular import
|
201
|
+
from airtrain.integrations.perplexity.list_models import (
|
202
|
+
PerplexityListModelsSkill,
|
203
|
+
)
|
204
|
+
|
205
|
+
return PerplexityListModelsSkill(credentials=credentials)
|
206
|
+
|
169
207
|
if provider not in cls._PROVIDER_MAP:
|
170
|
-
supported = ", ".join(cls.get_supported_providers())
|
208
|
+
supported = ", ".join(cls.get_supported_providers() + ["perplexity"])
|
171
209
|
raise ValueError(
|
172
210
|
f"Unsupported provider: {provider}. "
|
173
211
|
f"Supported providers are: {supported}"
|
174
212
|
)
|
175
|
-
|
213
|
+
|
176
214
|
skill_class = cls._PROVIDER_MAP[provider]
|
177
215
|
return skill_class(credentials=credentials)
|
178
|
-
|
216
|
+
|
179
217
|
@classmethod
|
180
218
|
def get_supported_providers(cls):
|
181
219
|
"""Return a list of supported provider names"""
|
182
|
-
|
220
|
+
# Add perplexity to the list of supported providers
|
221
|
+
return list(cls._PROVIDER_MAP.keys()) + ["perplexity"]
|
@@ -0,0 +1,49 @@
|
|
1
|
+
"""Perplexity AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import PerplexityCredentials
|
4
|
+
from .skills import (
|
5
|
+
PerplexityInput,
|
6
|
+
PerplexityOutput,
|
7
|
+
PerplexityChatSkill,
|
8
|
+
PerplexityCitation,
|
9
|
+
PerplexityStreamingChatSkill,
|
10
|
+
PerplexityStreamOutput,
|
11
|
+
)
|
12
|
+
from .list_models import (
|
13
|
+
PerplexityListModelsSkill,
|
14
|
+
StandalonePerplexityListModelsSkill,
|
15
|
+
PerplexityListModelsInput,
|
16
|
+
PerplexityListModelsOutput,
|
17
|
+
)
|
18
|
+
from .models_config import (
|
19
|
+
get_model_config,
|
20
|
+
get_default_model,
|
21
|
+
supports_citations,
|
22
|
+
supports_search,
|
23
|
+
get_models_by_category,
|
24
|
+
PERPLEXITY_MODELS_CONFIG,
|
25
|
+
)
|
26
|
+
|
27
|
+
__all__ = [
|
28
|
+
# Credentials
|
29
|
+
"PerplexityCredentials",
|
30
|
+
# Skills
|
31
|
+
"PerplexityInput",
|
32
|
+
"PerplexityOutput",
|
33
|
+
"PerplexityChatSkill",
|
34
|
+
"PerplexityCitation",
|
35
|
+
"PerplexityStreamingChatSkill",
|
36
|
+
"PerplexityStreamOutput",
|
37
|
+
# List Models
|
38
|
+
"PerplexityListModelsSkill",
|
39
|
+
"StandalonePerplexityListModelsSkill",
|
40
|
+
"PerplexityListModelsInput",
|
41
|
+
"PerplexityListModelsOutput",
|
42
|
+
# Model Config
|
43
|
+
"get_model_config",
|
44
|
+
"get_default_model",
|
45
|
+
"supports_citations",
|
46
|
+
"supports_search",
|
47
|
+
"get_models_by_category",
|
48
|
+
"PERPLEXITY_MODELS_CONFIG",
|
49
|
+
]
|