airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +146 -6
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +62 -44
- airtrain/core/skills.py +102 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.3.dist-info/METADATA +0 -106
- airtrain-0.1.3.dist-info/RECORD +0 -9
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,224 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
from decimal import Decimal
|
3
|
+
|
4
|
+
|
5
|
+
class OpenAIModelConfig(NamedTuple):
    """Immutable pricing record for one OpenAI model ID.

    All three price fields are ``Decimal`` values expressed in the same unit;
    the record itself does not state that unit.
    """

    # Human-readable label for the model.
    display_name: str
    # Model ID without a date suffix; dated snapshots share it with their alias.
    base_model: str
    # Price applied to regular input tokens.
    input_price: Decimal
    # Price applied to cached input tokens, or None when no such price is recorded.
    cached_input_price: Optional[Decimal]
    # Price applied to output tokens.
    output_price: Decimal
|
11
|
+
|
12
|
+
|
13
|
+
# Pricing table keyed by model ID. Each alias is immediately followed by its
# dated snapshot, which repeats the alias's base model and prices.
# Row layout: (model ID, display name, base model, input price,
#              cached input price or None, output price).
# Prices are kept as strings so the Decimal values retain their exact digits.
# NOTE(review): the figures look like USD per 1M tokens (they match OpenAI's
# published price list) — confirm against the consumer of this table.
OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
    model_id: OpenAIModelConfig(
        display_name=label,
        base_model=family,
        input_price=Decimal(price_in),
        cached_input_price=None if price_cached is None else Decimal(price_cached),
        output_price=Decimal(price_out),
    )
    for model_id, label, family, price_in, price_cached, price_out in (
        # GPT-4.5 preview
        (
            "gpt-4.5-preview", "GPT-4.5 Preview",
            "gpt-4.5-preview", "75.00", "37.50", "150.00",
        ),
        (
            "gpt-4.5-preview-2025-02-27", "GPT-4.5 Preview (2025-02-27)",
            "gpt-4.5-preview", "75.00", "37.50", "150.00",
        ),
        # GPT-4o
        (
            "gpt-4o", "GPT-4 Optimized",
            "gpt-4o", "2.50", "1.25", "10.00",
        ),
        (
            "gpt-4o-2024-08-06", "GPT-4 Optimized (2024-08-06)",
            "gpt-4o", "2.50", "1.25", "10.00",
        ),
        (
            "gpt-4o-audio-preview", "GPT-4 Optimized Audio Preview",
            "gpt-4o-audio-preview", "2.50", None, "10.00",
        ),
        (
            "gpt-4o-audio-preview-2024-12-17",
            "GPT-4 Optimized Audio Preview (2024-12-17)",
            "gpt-4o-audio-preview", "2.50", None, "10.00",
        ),
        (
            "gpt-4o-realtime-preview", "GPT-4 Optimized Realtime Preview",
            "gpt-4o-realtime-preview", "5.00", "2.50", "20.00",
        ),
        (
            "gpt-4o-realtime-preview-2024-12-17",
            "GPT-4 Optimized Realtime Preview (2024-12-17)",
            "gpt-4o-realtime-preview", "5.00", "2.50", "20.00",
        ),
        # GPT-4o mini
        (
            "gpt-4o-mini", "GPT-4 Optimized Mini",
            "gpt-4o-mini", "0.15", "0.075", "0.60",
        ),
        (
            "gpt-4o-mini-2024-07-18", "GPT-4 Optimized Mini (2024-07-18)",
            "gpt-4o-mini", "0.15", "0.075", "0.60",
        ),
        (
            "gpt-4o-mini-audio-preview", "GPT-4 Optimized Mini Audio Preview",
            "gpt-4o-mini-audio-preview", "0.15", None, "0.60",
        ),
        (
            "gpt-4o-mini-audio-preview-2024-12-17",
            "GPT-4 Optimized Mini Audio Preview (2024-12-17)",
            "gpt-4o-mini-audio-preview", "0.15", None, "0.60",
        ),
        (
            "gpt-4o-mini-realtime-preview", "GPT-4 Optimized Mini Realtime Preview",
            "gpt-4o-mini-realtime-preview", "0.60", "0.30", "2.40",
        ),
        (
            "gpt-4o-mini-realtime-preview-2024-12-17",
            "GPT-4 Optimized Mini Realtime Preview (2024-12-17)",
            "gpt-4o-mini-realtime-preview", "0.60", "0.30", "2.40",
        ),
        # o-series
        (
            "o1", "O1",
            "o1", "15.00", "7.50", "60.00",
        ),
        (
            "o1-2024-12-17", "O1 (2024-12-17)",
            "o1", "15.00", "7.50", "60.00",
        ),
        (
            "o3-mini", "O3 Mini",
            "o3-mini", "1.10", "0.55", "4.40",
        ),
        (
            "o3-mini-2025-01-31", "O3 Mini (2025-01-31)",
            "o3-mini", "1.10", "0.55", "4.40",
        ),
        (
            "o1-mini", "O1 Mini",
            "o1-mini", "1.10", "0.55", "4.40",
        ),
        (
            "o1-mini-2024-09-12", "O1 Mini (2024-09-12)",
            "o1-mini", "1.10", "0.55", "4.40",
        ),
        # Search previews
        (
            "gpt-4o-mini-search-preview", "GPT-4 Optimized Mini Search Preview",
            "gpt-4o-mini-search-preview", "0.15", None, "0.60",
        ),
        (
            "gpt-4o-mini-search-preview-2025-03-11",
            "GPT-4 Optimized Mini Search Preview (2025-03-11)",
            "gpt-4o-mini-search-preview", "0.15", None, "0.60",
        ),
        (
            "gpt-4o-search-preview", "GPT-4 Optimized Search Preview",
            "gpt-4o-search-preview", "2.50", None, "10.00",
        ),
        (
            "gpt-4o-search-preview-2025-03-11",
            "GPT-4 Optimized Search Preview (2025-03-11)",
            "gpt-4o-search-preview", "2.50", None, "10.00",
        ),
        # Computer use preview
        (
            "computer-use-preview", "Computer Use Preview",
            "computer-use-preview", "3.00", None, "12.00",
        ),
        (
            "computer-use-preview-2025-03-11", "Computer Use Preview (2025-03-11)",
            "computer-use-preview", "3.00", None, "12.00",
        ),
    )
}
|
197
|
+
|
198
|
+
|
199
|
+
def get_model_config(model_id: str) -> OpenAIModelConfig:
    """Return the configuration registered for ``model_id``.

    Args:
        model_id: Key to look up in ``OPENAI_MODELS``.

    Returns:
        The matching ``OpenAIModelConfig`` entry.

    Raises:
        ValueError: If ``model_id`` is not a key of ``OPENAI_MODELS``.
    """
    config = OPENAI_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in OpenAI models")
    return config
|
204
|
+
|
205
|
+
|
206
|
+
def get_default_model() -> str:
    """Return the ID of the default OpenAI model (``"gpt-4o"``)."""
    default_model_id = "gpt-4o"
    return default_model_id
|
209
|
+
|
210
|
+
|
211
|
+
def calculate_cost(
    model_id: str, input_tokens: int, output_tokens: int, use_cached: bool = False
) -> Decimal:
    """Calculate the cost of a request from its token counts.

    Args:
        model_id: Model ID to price; must be known to ``get_model_config``.
        input_tokens: Number of input (prompt) tokens.
        output_tokens: Number of output (completion) tokens.
        use_cached: When True and the model has a cached-input price, that
            price is applied to all input tokens; otherwise the regular
            input price is used.

    Returns:
        The total cost, in the currency of the price table.

    Raises:
        ValueError: If ``model_id`` is unknown (raised by ``get_model_config``).
    """
    config = get_model_config(model_id)

    cached_price = config.cached_input_price
    if use_cached and cached_price is not None:
        input_price = cached_price
    else:
        input_price = config.input_price

    # The table stores OpenAI's published rates, which are quoted per
    # 1,000,000 tokens (e.g. gpt-4o: 2.50 input / 10.00 output), so the
    # token-weighted sum must be scaled by one million, not one thousand.
    tokens_per_price_unit = Decimal("1000000")

    # Token counts go through str() so a non-int count never picks up binary
    # floating-point noise on its way into Decimal.
    weighted_total = input_price * Decimal(str(input_tokens)) + (
        config.output_price * Decimal(str(output_tokens))
    )
    return weighted_total / tokens_per_price_unit
|
@@ -0,0 +1,342 @@
|
|
1
|
+
from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
from openai import OpenAI, AsyncOpenAI
|
4
|
+
from openai.types.chat import ChatCompletionChunk
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import OpenAICredentials
|
10
|
+
|
11
|
+
|
12
|
+
class OpenAIInput(InputSchema):
    """Schema for OpenAI chat input.

    One instance describes a single chat request: the new user message, the
    system prompt, optional prior turns, and the sampling/streaming options.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior turns are inserted between the system prompt and the new user message.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="gpt-4o",
        description="OpenAI model to use",
    )
    # The Chat Completions API accepts sampling temperatures from 0 to 2, so
    # the upper bound is 2; capping it at 1 rejected valid requests.
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=2
    )
    # NOTE(review): 131072 is larger than the completion-token limit OpenAI
    # documents for some models (e.g. gpt-4o) — confirm whether the default
    # should be None so the API applies its own limit.
    max_tokens: Optional[int] = Field(
        default=131072, description="Maximum tokens in response"
    )
    stream: bool = Field(
        default=False,
        description="Whether to stream the response token by token",
    )
|
38
|
+
|
39
|
+
|
40
|
+
class OpenAIOutput(OutputSchema):
    """Result of a chat completion as returned by ``OpenAIChatSkill``."""

    # Text of the model's reply.
    response: str
    # Name of the model recorded for this result.
    used_model: str
    # Token counters; OpenAIChatSkill fills "total_tokens", "prompt_tokens"
    # and "completion_tokens".
    usage: Dict[str, int]
|
46
|
+
|
47
|
+
|
48
|
+
class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
    """Skill for interacting with OpenAI models with async support.

    Offers four entry points over the Chat Completions API: ``process`` and
    ``process_async`` return a complete ``OpenAIOutput``; ``process_stream``
    and ``process_stream_async`` yield the reply text piece by piece.
    """

    input_schema = OpenAIInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Create the sync and async OpenAI clients.

        Args:
            credentials: Credentials to use; when omitted they are loaded
                with ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        client_options = {
            "api_key": self.credentials.openai_api_key.get_secret_value(),
            "organization": self.credentials.openai_organization_id,
        }
        self.client = OpenAI(**client_options)
        self.async_client = AsyncOpenAI(**client_options)

    def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
        """Build messages list from input data including conversation history.

        Order: system prompt, then any previous turns, then the new user input.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    @staticmethod
    def _usage_to_dict(usage) -> Dict[str, int]:
        """Flatten an SDK usage object into the dict stored on ``OpenAIOutput``.

        Falls back to zero counters when the API reported no usage at all.
        """
        if usage is None:
            return {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
        return {
            "total_tokens": usage.total_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
        }

    def _collect_stream(self, input_data: OpenAIInput):
        """Run a streamed completion to the end and return ``(text, usage_dict)``."""
        stream = self.client.chat.completions.create(
            model=input_data.model,
            messages=self._build_messages(input_data),
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
            # Ask the API for a final usage-only chunk; a plain stream
            # carries no token counts.
            stream_options={"include_usage": True},
        )

        pieces = []
        usage = None
        for chunk in stream:
            if getattr(chunk, "usage", None) is not None:
                usage = chunk.usage
            # The usage chunk has an empty ``choices`` list, so guard the index.
            if chunk.choices and chunk.choices[0].delta.content is not None:
                pieces.append(chunk.choices[0].delta.content)

        return "".join(pieces), self._usage_to_dict(usage)

    def process_stream(self, input_data: OpenAIInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        Yields:
            Each non-empty content delta, in arrival order.

        Raises:
            ProcessingError: If the request or the iteration fails.
        """
        try:
            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # A chunk may arrive without any choices; skip it instead of
                # failing on ``choices[0]``.
                if not chunk.choices:
                    continue
                text = chunk.choices[0].delta.content
                if text is not None:
                    yield text

        except Exception as e:
            raise ProcessingError(f"OpenAI streaming failed: {str(e)}") from e

    def process(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set the reply is streamed and joined
        before returning; otherwise a single completion is requested. Token
        usage is reported in both modes.

        Raises:
            ProcessingError: If the request fails or the output is invalid.
        """
        try:
            if input_data.stream:
                # Streaming has no ``completion`` object to read usage from,
                # so text and usage are both gathered from the stream itself.
                response, usage = self._collect_stream(input_data)
            else:
                completion = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=self._build_messages(input_data),
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                    stream=False,
                )
                response = completion.choices[0].message.content
                usage = self._usage_to_dict(completion.usage)

            return OpenAIOutput(
                response=response,
                used_model=input_data.model,
                usage=usage,
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI chat failed: {str(e)}") from e

    async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Async version of process method (always a single, non-streamed call).

        Raises:
            ProcessingError: If the request fails or the output is invalid.
        """
        try:
            completion = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            return OpenAIOutput(
                response=completion.choices[0].message.content,
                used_model=completion.model,
                usage=self._usage_to_dict(completion.usage),
            )
        except Exception as e:
            raise ProcessingError(f"OpenAI async chat failed: {str(e)}") from e

    async def process_stream_async(
        self, input_data: OpenAIInput
    ) -> AsyncGenerator[str, None]:
        """Async version of stream processor.

        Yields:
            Each non-empty content delta, in arrival order.

        Raises:
            ProcessingError: If the request or the iteration fails.
        """
        try:
            stream = await self.async_client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )
            async for chunk in stream:
                # Same guard as the sync stream: chunks may have no choices.
                if not chunk.choices:
                    continue
                text = chunk.choices[0].delta.content
                if text is not None:
                    yield text
        except Exception as e:
            raise ProcessingError(f"OpenAI async streaming failed: {str(e)}") from e
|
171
|
+
|
172
|
+
|
173
|
+
# Type variable for the pydantic model class a reply should be parsed into.
ResponseT = TypeVar("ResponseT", bound=BaseModel)


class OpenAIParserInput(InputSchema):
    """Schema for OpenAI structured output input."""

    # Text of the user's request.
    user_input: str
    # Instruction sent as the system message.
    system_prompt: str = "You are a helpful assistant that provides structured data."
    # Model ID to query.
    model: str = "gpt-4o"
    # Sampling temperature.
    temperature: float = 0.7
    # Upper bound on generated tokens; None means no explicit bound is set here.
    max_tokens: Optional[int] = None
    # Pydantic model class (not an instance) describing the expected reply.
    response_model: Type[ResponseT]

    class Config:
        # Lets pydantic accept field types it has no built-in validator for.
        arbitrary_types_allowed = True
|
188
|
+
|
189
|
+
|
190
|
+
class OpenAIParserOutput(OutputSchema):
    """Result of a structured-output request made by ``OpenAIParserSkill``."""

    # The reply, already parsed into a pydantic model instance.
    parsed_response: BaseModel
    # Model name reported by the API for this completion.
    used_model: str
    # Total token count reported by the API for this completion.
    tokens_used: int
|
196
|
+
|
197
|
+
|
198
|
+
class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
    """Skill for getting structured responses from OpenAI"""

    input_schema = OpenAIParserInput
    output_schema = OpenAIParserOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Create the OpenAI client.

        Args:
            credentials: Credentials to use; when omitted they are loaded
                with ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def process(self, input_data: OpenAIParserInput) -> OpenAIParserOutput:
        """Request a completion parsed into ``input_data.response_model``.

        Returns:
            The parsed model instance plus the model name and total token
            count reported by the API.

        Raises:
            ProcessingError: If the request fails or the reply could not be
                parsed into the response model.
        """
        try:
            # Forward the sampling options declared on the input schema so
            # they actually take effect. ``max_tokens`` is only sent when the
            # caller set it, leaving the API default in place otherwise.
            request_options = {"temperature": input_data.temperature}
            if input_data.max_tokens is not None:
                request_options["max_tokens"] = input_data.max_tokens

            # Use parse method instead of create
            completion = self.client.beta.chat.completions.parse(
                model=input_data.model,
                messages=[
                    {"role": "system", "content": input_data.system_prompt},
                    {"role": "user", "content": input_data.user_input},
                ],
                response_format=input_data.response_model,
                **request_options,
            )

            parsed = completion.choices[0].message.parsed
            if parsed is None:
                raise ProcessingError("Failed to parse response")

            return OpenAIParserOutput(
                parsed_response=parsed,
                used_model=completion.model,
                tokens_used=completion.usage.total_tokens,
            )

        except Exception as e:
            raise ProcessingError(f"OpenAI parsing failed: {str(e)}") from e
|
236
|
+
|
237
|
+
|
238
|
+
class OpenAIEmbeddingsInput(InputSchema):
    """Request parameters for generating OpenAI embeddings."""

    # Either one string or a list of strings.
    texts: Union[str, List[str]] = Field(
        ...,
        description="Text or list of texts to generate embeddings for",
    )
    model: str = Field(
        default="text-embedding-3-large",
        description="OpenAI embeddings model to use",
    )
    encoding_format: str = Field(
        default="float",
        description="The format of the embeddings: 'float' or 'base64'",
    )
    # None leaves the embedding size unspecified.
    dimensions: Optional[int] = Field(
        default=None,
        description="Optional number of dimensions for the embeddings",
    )
|
253
|
+
|
254
|
+
|
255
|
+
class OpenAIEmbeddingsOutput(OutputSchema):
    """Embeddings returned by ``OpenAIEmbeddingsSkill``."""

    # One vector per item in the API response.
    embeddings: List[List[float]] = Field(
        ..., description="List of embeddings vectors"
    )
    used_model: str = Field(
        ..., description="Model used for generating embeddings"
    )
    tokens_used: int = Field(..., description="Number of tokens used")
|
261
|
+
|
262
|
+
|
263
|
+
class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
    """Skill for generating embeddings using OpenAI models"""

    input_schema = OpenAIEmbeddingsInput
    output_schema = OpenAIEmbeddingsOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Create the sync and async OpenAI clients.

        Args:
            credentials: Credentials to use; when omitted they are loaded
                with ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        client_options = {
            "api_key": self.credentials.openai_api_key.get_secret_value(),
            "organization": self.credentials.openai_organization_id,
        }
        self.client = OpenAI(**client_options)
        self.async_client = AsyncOpenAI(**client_options)

    @staticmethod
    def _build_request(input_data: OpenAIEmbeddingsInput) -> Dict[str, object]:
        """Translate the input schema into keyword arguments for ``embeddings.create``."""
        # Handle single text input: always send a list.
        texts = (
            [input_data.texts]
            if isinstance(input_data.texts, str)
            else input_data.texts
        )
        request: Dict[str, object] = {
            "model": input_data.model,
            "input": texts,
            "encoding_format": input_data.encoding_format,
        }
        # Only send ``dimensions`` when the caller chose one: the SDK types
        # the parameter as an int (or omitted), not as nullable, so an
        # explicit None must not be forwarded.
        if input_data.dimensions is not None:
            request["dimensions"] = input_data.dimensions
        return request

    @staticmethod
    def _build_output(response) -> OpenAIEmbeddingsOutput:
        """Convert an embeddings API response into ``OpenAIEmbeddingsOutput``."""
        return OpenAIEmbeddingsOutput(
            embeddings=[item.embedding for item in response.data],
            used_model=response.model,
            tokens_used=response.usage.total_tokens,
        )

    def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
        """Generate embeddings for the input text(s).

        Raises:
            ProcessingError: If the request fails or the output is invalid.
        """
        try:
            response = self.client.embeddings.create(
                **self._build_request(input_data)
            )
            return self._build_output(response)
        except Exception as e:
            raise ProcessingError(
                f"OpenAI embeddings generation failed: {str(e)}"
            ) from e

    async def process_async(
        self, input_data: OpenAIEmbeddingsInput
    ) -> OpenAIEmbeddingsOutput:
        """Async version of the embeddings generation.

        Raises:
            ProcessingError: If the request fails or the output is invalid.
        """
        try:
            response = await self.async_client.embeddings.create(
                **self._build_request(input_data)
            )
            return self._build_output(response)
        except Exception as e:
            raise ProcessingError(
                f"OpenAI async embeddings generation failed: {str(e)}"
            ) from e
|
@@ -0,0 +1,49 @@
|
|
1
|
+
"""Perplexity AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import PerplexityCredentials
|
4
|
+
from .skills import (
|
5
|
+
PerplexityInput,
|
6
|
+
PerplexityOutput,
|
7
|
+
PerplexityChatSkill,
|
8
|
+
PerplexityCitation,
|
9
|
+
PerplexityStreamingChatSkill,
|
10
|
+
PerplexityStreamOutput,
|
11
|
+
)
|
12
|
+
from .list_models import (
|
13
|
+
PerplexityListModelsSkill,
|
14
|
+
StandalonePerplexityListModelsSkill,
|
15
|
+
PerplexityListModelsInput,
|
16
|
+
PerplexityListModelsOutput,
|
17
|
+
)
|
18
|
+
from .models_config import (
|
19
|
+
get_model_config,
|
20
|
+
get_default_model,
|
21
|
+
supports_citations,
|
22
|
+
supports_search,
|
23
|
+
get_models_by_category,
|
24
|
+
PERPLEXITY_MODELS_CONFIG,
|
25
|
+
)
|
26
|
+
|
27
|
+
# Public API of the Perplexity integration package, grouped by the
# submodule each name is imported from.
__all__ = [
    # credentials
    "PerplexityCredentials",
    # skills
    "PerplexityInput",
    "PerplexityOutput",
    "PerplexityChatSkill",
    "PerplexityCitation",
    "PerplexityStreamingChatSkill",
    "PerplexityStreamOutput",
    # list_models
    "PerplexityListModelsSkill",
    "StandalonePerplexityListModelsSkill",
    "PerplexityListModelsInput",
    "PerplexityListModelsOutput",
    # models_config
    "get_model_config",
    "get_default_model",
    "supports_citations",
    "supports_search",
    "get_models_by_category",
    "PERPLEXITY_MODELS_CONFIG",
]
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import requests
|
4
|
+
|
5
|
+
|
6
|
+
class PerplexityCredentials(BaseCredentials):
    """Perplexity AI API credentials"""

    perplexity_api_key: SecretStr = Field(..., description="Perplexity AI API key")

    _required_credentials = {"perplexity_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate Perplexity AI credentials by making a test API call.

        Sends a one-token chat completion request and treats HTTP 200 as
        success.

        Returns:
            True when the API answers with status 200.

        Raises:
            CredentialValidationError: If the API answers with any other
                status, or the request itself fails (including a timeout).
        """
        try:
            headers = {
                "Authorization": f"Bearer {self.perplexity_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            }

            # Small API call to check if credentials are valid
            payload = {
                "model": "sonar-pro",
                "messages": [{"role": "user", "content": "Test"}],
                "max_tokens": 1,
            }

            # Make a synchronous request for validation. requests applies no
            # timeout by default, so one is set explicitly to keep an
            # unresponsive server from hanging the caller indefinitely.
            # NOTE(review): this call blocks the event loop while it runs.
            response = requests.post(
                "https://api.perplexity.ai/chat/completions",
                headers=headers,
                json=payload,
                timeout=30,
            )

            if response.status_code == 200:
                return True

            raise CredentialValidationError(
                f"Invalid Perplexity AI credentials: {response.status_code} - {response.text}"
            )

        except CredentialValidationError:
            # Already carries the final message; re-wrapping it below would
            # repeat the "Invalid Perplexity AI credentials" prefix.
            raise
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Perplexity AI credentials: {str(e)}"
            ) from e
|