airtrain-0.1.39-py3-none-any.whl → airtrain-0.1.41-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/integrations/__init__.py +10 -1
- airtrain/integrations/anthropic/__init__.py +12 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/fireworks/__init__.py +10 -0
- airtrain/integrations/fireworks/credentials.py +10 -2
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +112 -0
- airtrain/integrations/fireworks/skills.py +62 -11
- airtrain/integrations/fireworks/structured_completion_skills.py +10 -4
- airtrain/integrations/fireworks/structured_requests_skills.py +108 -31
- airtrain/integrations/openai/__init__.py +6 -0
- airtrain/integrations/openai/models_config.py +118 -13
- airtrain/integrations/openai/skills.py +109 -1
- airtrain/integrations/together/__init__.py +14 -1
- airtrain/integrations/together/list_models.py +77 -0
- airtrain/integrations/together/models.py +42 -3
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/METADATA +1 -1
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/RECORD +22 -19
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/WHEEL +1 -1
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/top_level.txt +0 -0
airtrain/integrations/fireworks/structured_requests_skills.py

@@ -1,5 +1,5 @@
-from typing import Type, TypeVar, Optional, List, Dict, Any, Generator
-from pydantic import BaseModel, Field
+from typing import Type, TypeVar, Optional, List, Dict, Any, Generator, Union
+from pydantic import BaseModel, Field, create_model
 import requests
 import json
 from loguru import logger
@@ -20,7 +20,7 @@ class FireworksStructuredRequestInput(InputSchema):
         default="You are a helpful assistant that provides structured data.",
         description="System prompt to guide the model's behavior",
     )
-    conversation_history: List[Dict[str,
+    conversation_history: List[Dict[str, Any]] = Field(
         default_factory=list,
         description="List of previous conversation messages",
     )
@@ -34,8 +34,21 @@ class FireworksStructuredRequestInput(InputSchema):
     max_tokens: int = Field(default=4096, description="Maximum tokens in response")
     response_model: Type[ResponseT]
     stream: bool = Field(
-        default=False,
-
+        default=False, description="Whether to stream the response token by token"
+    )
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "A list of tools the model may use. "
+            "Currently only functions supported."
+        ),
+    )
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Controls which tool is called by the model. "
+            "'none', 'auto', or specific tool."
+        ),
     )
 
     class Config:
@@ -49,6 +62,9 @@ class FireworksStructuredRequestOutput(OutputSchema):
     used_model: str
     usage: Dict[str, int]
     reasoning: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None, description="Tool calls generated by the model"
+    )
 
 
 class FireworksStructuredRequestSkill(
@@ -72,7 +88,7 @@ class FireworksStructuredRequestSkill(
 
     def _build_messages(
         self, input_data: FireworksStructuredRequestInput
-    ) -> List[Dict[str,
+    ) -> List[Dict[str, Any]]:
         """Build messages list from input data including conversation history."""
         messages = [{"role": "system", "content": input_data.system_prompt}]
 
@@ -86,24 +102,24 @@ class FireworksStructuredRequestSkill(
         self, input_data: FireworksStructuredRequestInput
     ) -> Dict[str, Any]:
         """Build the request payload."""
-
+        payload = {
             "model": input_data.model,
             "messages": self._build_messages(input_data),
             "temperature": input_data.temperature,
             "max_tokens": input_data.max_tokens,
             "stream": input_data.stream,
-            "response_format": {
-                "type": "json_object",
-                "schema": {
-                    **input_data.response_model.model_json_schema(),
-                    "required": [
-                        field
-                        for field, _ in input_data.response_model.model_fields.items()
-                    ],
-                },
-            },
+            "response_format": {"type": "json_object"},
         }
 
+        # Add tool-related parameters if provided
+        if input_data.tools:
+            payload["tools"] = input_data.tools
+
+        if input_data.tool_choice:
+            payload["tool_choice"] = input_data.tool_choice
+
+        return payload
+
     def process_stream(
         self, input_data: FireworksStructuredRequestInput
     ) -> Generator[Dict[str, Any], None, None]:
@@ -131,6 +147,10 @@ class FireworksStructuredRequestSkill(
                         continue
 
             # Once complete, parse the full response with think tags
+            if not json_buffer:
+                # If no data was collected, raise error
+                raise ProcessingError("No data received from Fireworks API")
+
             complete_response = "".join(json_buffer)
             reasoning, json_str = self._parse_response_content(complete_response)
 
@@ -177,37 +197,94 @@ class FireworksStructuredRequestSkill(
 
                 if parsed_response is None:
                     raise ProcessingError("Failed to parse streamed response")
+
+                # Make a non-streaming call to get tool calls if tools were provided
+                tool_calls = None
+                if input_data.tools:
+                    # Create a non-streaming request to get tool calls
+                    non_stream_payload = self._build_payload(input_data)
+                    non_stream_payload["stream"] = False
+
+                    response = requests.post(
+                        self.BASE_URL,
+                        headers=self.headers,
+                        data=json.dumps(non_stream_payload),
+                    )
+                    response.raise_for_status()
+                    result = response.json()
+
+                    # Check for tool calls
+                    if (result["choices"][0]["message"].get("tool_calls")):
+                        tool_calls = [
+                            {
+                                "id": tool_call["id"],
+                                "type": tool_call["type"],
+                                "function": {
+                                    "name": tool_call["function"]["name"],
+                                    "arguments": tool_call["function"]["arguments"]
+                                }
+                            }
+                            for tool_call in result["choices"][0]["message"]["tool_calls"]
+                        ]
 
                 return FireworksStructuredRequestOutput(
                     parsed_response=parsed_response,
                     used_model=input_data.model,
-                    usage={}, #
+                    usage={"total_tokens": 0},  # Can't get usage stats from streaming
                     reasoning=reasoning,
+                    tool_calls=tool_calls,
                 )
             else:
                 # For non-streaming, use regular request
                 payload = self._build_payload(input_data)
+                payload["stream"] = False  # Ensure it's not streaming
+
                 response = requests.post(
                     self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                 )
                 response.raise_for_status()
-
-
-
-
-
-
-
-
-
-
-                )
+                result = response.json()
+
+                # Get the content from the response
+                if "choices" not in result or not result["choices"]:
+                    raise ProcessingError("Invalid response format from Fireworks API")
+
+                content = result["choices"][0]["message"].get("content", "")
+
+                # Check for tool calls
+                tool_calls = None
+                if (result["choices"][0]["message"].get("tool_calls")):
+                    tool_calls = [
+                        {
+                            "id": tool_call["id"],
+                            "type": tool_call["type"],
+                            "function": {
+                                "name": tool_call["function"]["name"],
+                                "arguments": tool_call["function"]["arguments"]
+                            }
+                        }
+                        for tool_call in result["choices"][0]["message"]["tool_calls"]
+                    ]
+
+                # Parse the response content
+                reasoning, json_str = self._parse_response_content(content)
+                try:
+                    parsed_response = input_data.response_model.model_validate_json(
+                        json_str
+                    )
+                except Exception as e:
+                    raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
 
                 return FireworksStructuredRequestOutput(
                     parsed_response=parsed_response,
                     used_model=input_data.model,
-                    usage=
-
+                    usage={
+                        "total_tokens": result["usage"]["total_tokens"],
+                        "prompt_tokens": result["usage"]["prompt_tokens"],
+                        "completion_tokens": result["usage"]["completion_tokens"],
+                    },
+                    reasoning=reasoning,
+                    tool_calls=tool_calls,
                 )
 
         except Exception as e:
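Usage sketch (illustrative, not part of the diff): the new `tools` / `tool_choice` inputs and the `tool_calls` output added above could be exercised roughly as below. The `user_input` field name and the model id are assumptions; only the fields visible in the hunks above are confirmed by this diff.

from pydantic import BaseModel

from airtrain.integrations.fireworks.structured_requests_skills import (
    FireworksStructuredRequestInput,
    FireworksStructuredRequestSkill,
)

class Weather(BaseModel):
    city: str
    temperature_c: float

skill = FireworksStructuredRequestSkill()
result = skill.process(
    FireworksStructuredRequestInput(
        user_input="What is the weather in Paris?",  # assumed field name
        model="accounts/fireworks/models/llama-v3p1-70b-instruct",  # assumed model id
        response_model=Weather,
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Look up current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
        tool_choice="auto",
    )
)
# tool_calls is None unless the model chose to call a tool
if result.tool_calls:
    print(result.tool_calls[0]["function"]["name"])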
airtrain/integrations/openai/__init__.py

@@ -5,6 +5,9 @@ from .skills import (
     OpenAIOutput,
     OpenAIParserInput,
     OpenAIParserOutput,
+    OpenAIEmbeddingsSkill,
+    OpenAIEmbeddingsInput,
+    OpenAIEmbeddingsOutput,
 )
 from .credentials import OpenAICredentials
 
@@ -16,4 +19,7 @@ __all__ = [
     "OpenAIParserOutput",
     "OpenAICredentials",
     "OpenAIOutput",
+    "OpenAIEmbeddingsSkill",
+    "OpenAIEmbeddingsInput",
+    "OpenAIEmbeddingsOutput",
 ]
airtrain/integrations/openai/models_config.py

@@ -11,6 +11,20 @@ class OpenAIModelConfig(NamedTuple):
 
 
 OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
+    "gpt-4.5-preview": OpenAIModelConfig(
+        display_name="GPT-4.5 Preview",
+        base_model="gpt-4.5-preview",
+        input_price=Decimal("75.00"),
+        cached_input_price=Decimal("37.50"),
+        output_price=Decimal("150.00"),
+    ),
+    "gpt-4.5-preview-2025-02-27": OpenAIModelConfig(
+        display_name="GPT-4.5 Preview (2025-02-27)",
+        base_model="gpt-4.5-preview",
+        input_price=Decimal("75.00"),
+        cached_input_price=Decimal("37.50"),
+        output_price=Decimal("150.00"),
+    ),
     "gpt-4o": OpenAIModelConfig(
         display_name="GPT-4 Optimized",
         base_model="gpt-4o",
@@ -25,69 +39,160 @@ OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
         cached_input_price=Decimal("1.25"),
         output_price=Decimal("10.00"),
     ),
-    "gpt-4o-
-        display_name="GPT-4 Optimized
-        base_model="gpt-4o",
-        input_price=Decimal("
+    "gpt-4o-audio-preview": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Audio Preview",
+        base_model="gpt-4o-audio-preview",
+        input_price=Decimal("2.50"),
         cached_input_price=None,
-        output_price=Decimal("
+        output_price=Decimal("10.00"),
     ),
     "gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
-        display_name="GPT-4 Optimized Audio Preview",
+        display_name="GPT-4 Optimized Audio Preview (2024-12-17)",
         base_model="gpt-4o-audio-preview",
         input_price=Decimal("2.50"),
         cached_input_price=None,
         output_price=Decimal("10.00"),
     ),
-    "gpt-4o-realtime-preview
+    "gpt-4o-realtime-preview": OpenAIModelConfig(
         display_name="GPT-4 Optimized Realtime Preview",
         base_model="gpt-4o-realtime-preview",
         input_price=Decimal("5.00"),
         cached_input_price=Decimal("2.50"),
         output_price=Decimal("20.00"),
     ),
-    "gpt-4o-
+    "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Realtime Preview (2024-12-17)",
+        base_model="gpt-4o-realtime-preview",
+        input_price=Decimal("5.00"),
+        cached_input_price=Decimal("2.50"),
+        output_price=Decimal("20.00"),
+    ),
+    "gpt-4o-mini": OpenAIModelConfig(
         display_name="GPT-4 Optimized Mini",
         base_model="gpt-4o-mini",
         input_price=Decimal("0.15"),
         cached_input_price=Decimal("0.075"),
         output_price=Decimal("0.60"),
     ),
-    "gpt-4o-mini-
+    "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Mini (2024-07-18)",
+        base_model="gpt-4o-mini",
+        input_price=Decimal("0.15"),
+        cached_input_price=Decimal("0.075"),
+        output_price=Decimal("0.60"),
+    ),
+    "gpt-4o-mini-audio-preview": OpenAIModelConfig(
         display_name="GPT-4 Optimized Mini Audio Preview",
         base_model="gpt-4o-mini-audio-preview",
         input_price=Decimal("0.15"),
         cached_input_price=None,
         output_price=Decimal("0.60"),
     ),
-    "gpt-4o-mini-
+    "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Mini Audio Preview (2024-12-17)",
+        base_model="gpt-4o-mini-audio-preview",
+        input_price=Decimal("0.15"),
+        cached_input_price=None,
+        output_price=Decimal("0.60"),
+    ),
+    "gpt-4o-mini-realtime-preview": OpenAIModelConfig(
         display_name="GPT-4 Optimized Mini Realtime Preview",
         base_model="gpt-4o-mini-realtime-preview",
         input_price=Decimal("0.60"),
         cached_input_price=Decimal("0.30"),
         output_price=Decimal("2.40"),
     ),
-    "
+    "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Mini Realtime Preview (2024-12-17)",
+        base_model="gpt-4o-mini-realtime-preview",
+        input_price=Decimal("0.60"),
+        cached_input_price=Decimal("0.30"),
+        output_price=Decimal("2.40"),
+    ),
+    "o1": OpenAIModelConfig(
         display_name="O1",
         base_model="o1",
         input_price=Decimal("15.00"),
         cached_input_price=Decimal("7.50"),
         output_price=Decimal("60.00"),
     ),
-    "
+    "o1-2024-12-17": OpenAIModelConfig(
+        display_name="O1 (2024-12-17)",
+        base_model="o1",
+        input_price=Decimal("15.00"),
+        cached_input_price=Decimal("7.50"),
+        output_price=Decimal("60.00"),
+    ),
+    "o3-mini": OpenAIModelConfig(
         display_name="O3 Mini",
         base_model="o3-mini",
         input_price=Decimal("1.10"),
         cached_input_price=Decimal("0.55"),
         output_price=Decimal("4.40"),
     ),
-    "
+    "o3-mini-2025-01-31": OpenAIModelConfig(
+        display_name="O3 Mini (2025-01-31)",
+        base_model="o3-mini",
+        input_price=Decimal("1.10"),
+        cached_input_price=Decimal("0.55"),
+        output_price=Decimal("4.40"),
+    ),
+    "o1-mini": OpenAIModelConfig(
         display_name="O1 Mini",
         base_model="o1-mini",
         input_price=Decimal("1.10"),
         cached_input_price=Decimal("0.55"),
         output_price=Decimal("4.40"),
     ),
+    "o1-mini-2024-09-12": OpenAIModelConfig(
+        display_name="O1 Mini (2024-09-12)",
+        base_model="o1-mini",
+        input_price=Decimal("1.10"),
+        cached_input_price=Decimal("0.55"),
+        output_price=Decimal("4.40"),
+    ),
+    "gpt-4o-mini-search-preview": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Mini Search Preview",
+        base_model="gpt-4o-mini-search-preview",
+        input_price=Decimal("0.15"),
+        cached_input_price=None,
+        output_price=Decimal("0.60"),
+    ),
+    "gpt-4o-mini-search-preview-2025-03-11": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Mini Search Preview (2025-03-11)",
+        base_model="gpt-4o-mini-search-preview",
+        input_price=Decimal("0.15"),
+        cached_input_price=None,
+        output_price=Decimal("0.60"),
+    ),
+    "gpt-4o-search-preview": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Search Preview",
+        base_model="gpt-4o-search-preview",
+        input_price=Decimal("2.50"),
+        cached_input_price=None,
+        output_price=Decimal("10.00"),
+    ),
+    "gpt-4o-search-preview-2025-03-11": OpenAIModelConfig(
+        display_name="GPT-4 Optimized Search Preview (2025-03-11)",
+        base_model="gpt-4o-search-preview",
+        input_price=Decimal("2.50"),
+        cached_input_price=None,
+        output_price=Decimal("10.00"),
+    ),
+    "computer-use-preview": OpenAIModelConfig(
+        display_name="Computer Use Preview",
+        base_model="computer-use-preview",
+        input_price=Decimal("3.00"),
+        cached_input_price=None,
+        output_price=Decimal("12.00"),
+    ),
+    "computer-use-preview-2025-03-11": OpenAIModelConfig(
+        display_name="Computer Use Preview (2025-03-11)",
+        base_model="computer-use-preview",
+        input_price=Decimal("3.00"),
+        cached_input_price=None,
+        output_price=Decimal("12.00"),
+    ),
 }
 
 
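For context (not part of the diff): the Decimal prices above match OpenAI's published USD-per-1M-token rates, so a cost estimate over this table might look like the sketch below. The per-1M-token unit is an assumption; the docstrings of models_config.py are not shown in this diff.

from decimal import Decimal

from airtrain.integrations.openai.models_config import OPENAI_MODELS

def estimate_cost(model_id: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    # Assumes input_price/output_price are USD per 1,000,000 tokens.
    config = OPENAI_MODELS[model_id]
    return (
        config.input_price * prompt_tokens
        + config.output_price * completion_tokens
    ) / Decimal(1_000_000)

# gpt-4o: 1200 * 2.50/1e6 + 300 * 10.00/1e6 = 0.003 + 0.003 = 0.006 USD
print(estimate_cost("gpt-4o", prompt_tokens=1200, completion_tokens=300))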
airtrain/integrations/openai/skills.py

@@ -1,7 +1,8 @@
-from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator
+from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
 from pydantic import Field, BaseModel
 from openai import OpenAI, AsyncOpenAI
 from openai.types.chat import ChatCompletionChunk
+import numpy as np
 
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
@@ -232,3 +233,110 @@ class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
 
         except Exception as e:
             raise ProcessingError(f"OpenAI parsing failed: {str(e)}")
+
+
+class OpenAIEmbeddingsInput(InputSchema):
+    """Schema for OpenAI embeddings input"""
+
+    texts: Union[str, List[str]] = Field(
+        ..., description="Text or list of texts to generate embeddings for"
+    )
+    model: str = Field(
+        default="text-embedding-3-large", description="OpenAI embeddings model to use"
+    )
+    encoding_format: str = Field(
+        default="float", description="The format of the embeddings: 'float' or 'base64'"
+    )
+    dimensions: Optional[int] = Field(
+        default=None, description="Optional number of dimensions for the embeddings"
+    )
+
+
+class OpenAIEmbeddingsOutput(OutputSchema):
+    """Schema for OpenAI embeddings output"""
+
+    embeddings: List[List[float]] = Field(..., description="List of embeddings vectors")
+    used_model: str = Field(..., description="Model used for generating embeddings")
+    tokens_used: int = Field(..., description="Number of tokens used")
+
+
+class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
+    """Skill for generating embeddings using OpenAI models"""
+
+    input_schema = OpenAIEmbeddingsInput
+    output_schema = OpenAIEmbeddingsOutput
+
+    def __init__(self, credentials: Optional[OpenAICredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials or OpenAICredentials.from_env()
+        self.client = OpenAI(
+            api_key=self.credentials.openai_api_key.get_secret_value(),
+            organization=self.credentials.openai_organization_id,
+        )
+        self.async_client = AsyncOpenAI(
+            api_key=self.credentials.openai_api_key.get_secret_value(),
+            organization=self.credentials.openai_organization_id,
+        )
+
+    def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
+        """Generate embeddings for the input text(s)"""
+        try:
+            # Handle single text input
+            texts = (
+                [input_data.texts]
+                if isinstance(input_data.texts, str)
+                else input_data.texts
+            )
+
+            # Create embeddings
+            response = self.client.embeddings.create(
+                model=input_data.model,
+                input=texts,
+                encoding_format=input_data.encoding_format,
+                dimensions=input_data.dimensions,
+            )
+
+            # Extract embeddings
+            embeddings = [data.embedding for data in response.data]
+
+            return OpenAIEmbeddingsOutput(
+                embeddings=embeddings,
+                used_model=response.model,
+                tokens_used=response.usage.total_tokens,
+            )
+        except Exception as e:
+            raise ProcessingError(f"OpenAI embeddings generation failed: {str(e)}")
+
+    async def process_async(
+        self, input_data: OpenAIEmbeddingsInput
+    ) -> OpenAIEmbeddingsOutput:
+        """Async version of the embeddings generation"""
+        try:
+            # Handle single text input
+            texts = (
+                [input_data.texts]
+                if isinstance(input_data.texts, str)
+                else input_data.texts
+            )
+
+            # Create embeddings
+            response = await self.async_client.embeddings.create(
+                model=input_data.model,
+                input=texts,
+                encoding_format=input_data.encoding_format,
+                dimensions=input_data.dimensions,
+            )
+
+            # Extract embeddings
+            embeddings = [data.embedding for data in response.data]
+
+            return OpenAIEmbeddingsOutput(
+                embeddings=embeddings,
+                used_model=response.model,
+                tokens_used=response.usage.total_tokens,
+            )
+        except Exception as e:
+            raise ProcessingError(
+                f"OpenAI async embeddings generation failed: {str(e)}"
+            )
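Usage sketch (illustrative, not part of the diff): the new embeddings skill, re-exported from airtrain.integrations.openai per the __init__.py hunks above, could be called like this. It assumes OpenAI credentials are available in the environment for OpenAICredentials.from_env() to pick up.

from airtrain.integrations.openai import OpenAIEmbeddingsInput, OpenAIEmbeddingsSkill

skill = OpenAIEmbeddingsSkill()
result = skill.process(
    OpenAIEmbeddingsInput(
        texts=["first document", "second document"],
        model="text-embedding-3-large",
        dimensions=256,  # the text-embedding-3 models support dimension truncation
    )
)
# Two vectors of 256 floats each, plus the token count reported by the API
print(len(result.embeddings), len(result.embeddings[0]), result.tokens_used)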
airtrain/integrations/together/__init__.py

@@ -2,5 +2,18 @@
 
 from .credentials import TogetherAICredentials
 from .skills import TogetherAIChatSkill
+from .list_models import (
+    TogetherListModelsSkill,
+    TogetherListModelsInput,
+    TogetherListModelsOutput,
+)
+from .models import TogetherModel
 
-__all__ = [
+__all__ = [
+    "TogetherAICredentials",
+    "TogetherAIChatSkill",
+    "TogetherListModelsSkill",
+    "TogetherListModelsInput",
+    "TogetherListModelsOutput",
+    "TogetherModel",
+]
airtrain/integrations/together/list_models.py (new file)

@@ -0,0 +1,77 @@
+from typing import Optional
+import requests
+from pydantic import Field
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import TogetherAICredentials
+from .models import TogetherModel
+
+
+class TogetherListModelsInput(InputSchema):
+    """Schema for Together AI list models input"""
+    pass
+
+
+class TogetherListModelsOutput(OutputSchema):
+    """Schema for Together AI list models output"""
+
+    data: list[TogetherModel] = Field(
+        default_factory=list,
+        description="List of Together AI models"
+    )
+    object: Optional[str] = Field(
+        default=None,
+        description="Object type"
+    )
+
+
+class TogetherListModelsSkill(Skill[TogetherListModelsInput, TogetherListModelsOutput]):
+    """Skill for listing Together AI models"""
+
+    input_schema = TogetherListModelsInput
+    output_schema = TogetherListModelsOutput
+
+    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials or TogetherAICredentials.from_env()
+        self.base_url = "https://api.together.xyz/v1"
+
+    def process(
+        self, input_data: TogetherListModelsInput
+    ) -> TogetherListModelsOutput:
+        """Process the input and return a list of models."""
+        try:
+            # Build the URL
+            url = f"{self.base_url}/models"
+
+            # Make the request
+            headers = {
+                "Authorization": (
+                    f"Bearer {self.credentials.together_api_key.get_secret_value()}"
+                ),
+                "accept": "application/json"
+            }
+
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+
+            # Parse the response
+            result = response.json()
+
+            # Convert the models to TogetherModel objects
+            models = []
+            for model_data in result.get("data", []):
+                models.append(TogetherModel(**model_data))
+
+            # Return the output
+            return TogetherListModelsOutput(
+                data=models,
+                object=result.get("object")
+            )
+
+        except requests.RequestException as e:
+            raise ProcessingError(f"Failed to list Together AI models: {str(e)}")
+        except Exception as e:
+            raise ProcessingError(f"Error listing Together AI models: {str(e)}")
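Usage sketch (illustrative, not part of the diff): the listing skill takes an empty input schema and returns TogetherModel objects. TogetherModel's fields live in together/models.py, whose hunks are not shown here, so the sketch prints the models wholesale rather than assuming field names.

from airtrain.integrations.together import (
    TogetherListModelsInput,
    TogetherListModelsSkill,
)

# Reads credentials via TogetherAICredentials.from_env() when none are passed
skill = TogetherListModelsSkill()
output = skill.process(TogetherListModelsInput())
print(f"{len(output.data)} models available")
for model in output.data[:5]:
    print(model)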