airtrain 0.1.39__py3-none-any.whl → 0.1.41__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
- from typing import Type, TypeVar, Optional, List, Dict, Any, Generator
2
- from pydantic import BaseModel, Field
1
+ from typing import Type, TypeVar, Optional, List, Dict, Any, Generator, Union
2
+ from pydantic import BaseModel, Field, create_model
3
3
  import requests
4
4
  import json
5
5
  from loguru import logger
@@ -20,7 +20,7 @@ class FireworksStructuredRequestInput(InputSchema):
20
20
  default="You are a helpful assistant that provides structured data.",
21
21
  description="System prompt to guide the model's behavior",
22
22
  )
23
- conversation_history: List[Dict[str, str]] = Field(
23
+ conversation_history: List[Dict[str, Any]] = Field(
24
24
  default_factory=list,
25
25
  description="List of previous conversation messages",
26
26
  )
@@ -34,8 +34,21 @@ class FireworksStructuredRequestInput(InputSchema):
34
34
  max_tokens: int = Field(default=4096, description="Maximum tokens in response")
35
35
  response_model: Type[ResponseT]
36
36
  stream: bool = Field(
37
- default=False,
38
- description="Whether to stream the response",
37
+ default=False, description="Whether to stream the response token by token"
38
+ )
39
+ tools: Optional[List[Dict[str, Any]]] = Field(
40
+ default=None,
41
+ description=(
42
+ "A list of tools the model may use. "
43
+ "Currently only functions supported."
44
+ ),
45
+ )
46
+ tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
47
+ default=None,
48
+ description=(
49
+ "Controls which tool is called by the model. "
50
+ "'none', 'auto', or specific tool."
51
+ ),
39
52
  )
40
53
 
41
54
  class Config:
@@ -49,6 +62,9 @@ class FireworksStructuredRequestOutput(OutputSchema):
49
62
  used_model: str
50
63
  usage: Dict[str, int]
51
64
  reasoning: Optional[str] = None
65
+ tool_calls: Optional[List[Dict[str, Any]]] = Field(
66
+ default=None, description="Tool calls generated by the model"
67
+ )
52
68
 
53
69
 
54
70
  class FireworksStructuredRequestSkill(
@@ -72,7 +88,7 @@ class FireworksStructuredRequestSkill(
72
88
 
73
89
  def _build_messages(
74
90
  self, input_data: FireworksStructuredRequestInput
75
- ) -> List[Dict[str, str]]:
91
+ ) -> List[Dict[str, Any]]:
76
92
  """Build messages list from input data including conversation history."""
77
93
  messages = [{"role": "system", "content": input_data.system_prompt}]
78
94
 
@@ -86,24 +102,24 @@ class FireworksStructuredRequestSkill(
86
102
  self, input_data: FireworksStructuredRequestInput
87
103
  ) -> Dict[str, Any]:
88
104
  """Build the request payload."""
89
- return {
105
+ payload = {
90
106
  "model": input_data.model,
91
107
  "messages": self._build_messages(input_data),
92
108
  "temperature": input_data.temperature,
93
109
  "max_tokens": input_data.max_tokens,
94
110
  "stream": input_data.stream,
95
- "response_format": {
96
- "type": "json_object",
97
- "schema": {
98
- **input_data.response_model.model_json_schema(),
99
- "required": [
100
- field
101
- for field, _ in input_data.response_model.model_fields.items()
102
- ],
103
- },
104
- },
111
+ "response_format": {"type": "json_object"},
105
112
  }
106
113
 
114
+ # Add tool-related parameters if provided
115
+ if input_data.tools:
116
+ payload["tools"] = input_data.tools
117
+
118
+ if input_data.tool_choice:
119
+ payload["tool_choice"] = input_data.tool_choice
120
+
121
+ return payload
122
+
107
123
  def process_stream(
108
124
  self, input_data: FireworksStructuredRequestInput
109
125
  ) -> Generator[Dict[str, Any], None, None]:
@@ -131,6 +147,10 @@ class FireworksStructuredRequestSkill(
131
147
  continue
132
148
 
133
149
  # Once complete, parse the full response with think tags
150
+ if not json_buffer:
151
+ # If no data was collected, raise error
152
+ raise ProcessingError("No data received from Fireworks API")
153
+
134
154
  complete_response = "".join(json_buffer)
135
155
  reasoning, json_str = self._parse_response_content(complete_response)
136
156
 
@@ -177,37 +197,94 @@ class FireworksStructuredRequestSkill(
177
197
 
178
198
  if parsed_response is None:
179
199
  raise ProcessingError("Failed to parse streamed response")
200
+
201
+ # Make a non-streaming call to get tool calls if tools were provided
202
+ tool_calls = None
203
+ if input_data.tools:
204
+ # Create a non-streaming request to get tool calls
205
+ non_stream_payload = self._build_payload(input_data)
206
+ non_stream_payload["stream"] = False
207
+
208
+ response = requests.post(
209
+ self.BASE_URL,
210
+ headers=self.headers,
211
+ data=json.dumps(non_stream_payload),
212
+ )
213
+ response.raise_for_status()
214
+ result = response.json()
215
+
216
+ # Check for tool calls
217
+ if (result["choices"][0]["message"].get("tool_calls")):
218
+ tool_calls = [
219
+ {
220
+ "id": tool_call["id"],
221
+ "type": tool_call["type"],
222
+ "function": {
223
+ "name": tool_call["function"]["name"],
224
+ "arguments": tool_call["function"]["arguments"]
225
+ }
226
+ }
227
+ for tool_call in result["choices"][0]["message"]["tool_calls"]
228
+ ]
180
229
 
181
230
  return FireworksStructuredRequestOutput(
182
231
  parsed_response=parsed_response,
183
232
  used_model=input_data.model,
184
- usage={}, # Usage stats not available in streaming mode
233
+ usage={"total_tokens": 0}, # Can't get usage stats from streaming
185
234
  reasoning=reasoning,
235
+ tool_calls=tool_calls,
186
236
  )
187
237
  else:
188
238
  # For non-streaming, use regular request
189
239
  payload = self._build_payload(input_data)
240
+ payload["stream"] = False # Ensure it's not streaming
241
+
190
242
  response = requests.post(
191
243
  self.BASE_URL, headers=self.headers, data=json.dumps(payload)
192
244
  )
193
245
  response.raise_for_status()
194
- data = response.json()
195
-
196
- response_content = data["choices"][0]["message"]["content"]
197
-
198
- # Parse the response content to extract reasoning and JSON
199
- reasoning, json_str = self._parse_response_content(response_content)
200
-
201
- # Parse the JSON string into the specified model
202
- parsed_response = input_data.response_model.model_validate_json(
203
- json_str
204
- )
246
+ result = response.json()
247
+
248
+ # Get the content from the response
249
+ if "choices" not in result or not result["choices"]:
250
+ raise ProcessingError("Invalid response format from Fireworks API")
251
+
252
+ content = result["choices"][0]["message"].get("content", "")
253
+
254
+ # Check for tool calls
255
+ tool_calls = None
256
+ if (result["choices"][0]["message"].get("tool_calls")):
257
+ tool_calls = [
258
+ {
259
+ "id": tool_call["id"],
260
+ "type": tool_call["type"],
261
+ "function": {
262
+ "name": tool_call["function"]["name"],
263
+ "arguments": tool_call["function"]["arguments"]
264
+ }
265
+ }
266
+ for tool_call in result["choices"][0]["message"]["tool_calls"]
267
+ ]
268
+
269
+ # Parse the response content
270
+ reasoning, json_str = self._parse_response_content(content)
271
+ try:
272
+ parsed_response = input_data.response_model.model_validate_json(
273
+ json_str
274
+ )
275
+ except Exception as e:
276
+ raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
205
277
 
206
278
  return FireworksStructuredRequestOutput(
207
279
  parsed_response=parsed_response,
208
280
  used_model=input_data.model,
209
- usage=data["usage"],
210
- reasoning=reasoning, # Add reasoning to output if present
281
+ usage={
282
+ "total_tokens": result["usage"]["total_tokens"],
283
+ "prompt_tokens": result["usage"]["prompt_tokens"],
284
+ "completion_tokens": result["usage"]["completion_tokens"],
285
+ },
286
+ reasoning=reasoning,
287
+ tool_calls=tool_calls,
211
288
  )
212
289
 
213
290
  except Exception as e:
@@ -5,6 +5,9 @@ from .skills import (
5
5
  OpenAIOutput,
6
6
  OpenAIParserInput,
7
7
  OpenAIParserOutput,
8
+ OpenAIEmbeddingsSkill,
9
+ OpenAIEmbeddingsInput,
10
+ OpenAIEmbeddingsOutput,
8
11
  )
9
12
  from .credentials import OpenAICredentials
10
13
 
@@ -16,4 +19,7 @@ __all__ = [
16
19
  "OpenAIParserOutput",
17
20
  "OpenAICredentials",
18
21
  "OpenAIOutput",
22
+ "OpenAIEmbeddingsSkill",
23
+ "OpenAIEmbeddingsInput",
24
+ "OpenAIEmbeddingsOutput",
19
25
  ]
@@ -11,6 +11,20 @@ class OpenAIModelConfig(NamedTuple):
11
11
 
12
12
 
13
13
  OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
14
+ "gpt-4.5-preview": OpenAIModelConfig(
15
+ display_name="GPT-4.5 Preview",
16
+ base_model="gpt-4.5-preview",
17
+ input_price=Decimal("75.00"),
18
+ cached_input_price=Decimal("37.50"),
19
+ output_price=Decimal("150.00"),
20
+ ),
21
+ "gpt-4.5-preview-2025-02-27": OpenAIModelConfig(
22
+ display_name="GPT-4.5 Preview (2025-02-27)",
23
+ base_model="gpt-4.5-preview",
24
+ input_price=Decimal("75.00"),
25
+ cached_input_price=Decimal("37.50"),
26
+ output_price=Decimal("150.00"),
27
+ ),
14
28
  "gpt-4o": OpenAIModelConfig(
15
29
  display_name="GPT-4 Optimized",
16
30
  base_model="gpt-4o",
@@ -25,69 +39,160 @@ OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
25
39
  cached_input_price=Decimal("1.25"),
26
40
  output_price=Decimal("10.00"),
27
41
  ),
28
- "gpt-4o-2024-05-13": OpenAIModelConfig(
29
- display_name="GPT-4 Optimized (2024-05-13)",
30
- base_model="gpt-4o",
31
- input_price=Decimal("5.00"),
42
+ "gpt-4o-audio-preview": OpenAIModelConfig(
43
+ display_name="GPT-4 Optimized Audio Preview",
44
+ base_model="gpt-4o-audio-preview",
45
+ input_price=Decimal("2.50"),
32
46
  cached_input_price=None,
33
- output_price=Decimal("15.00"),
47
+ output_price=Decimal("10.00"),
34
48
  ),
35
49
  "gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
36
- display_name="GPT-4 Optimized Audio Preview",
50
+ display_name="GPT-4 Optimized Audio Preview (2024-12-17)",
37
51
  base_model="gpt-4o-audio-preview",
38
52
  input_price=Decimal("2.50"),
39
53
  cached_input_price=None,
40
54
  output_price=Decimal("10.00"),
41
55
  ),
42
- "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
56
+ "gpt-4o-realtime-preview": OpenAIModelConfig(
43
57
  display_name="GPT-4 Optimized Realtime Preview",
44
58
  base_model="gpt-4o-realtime-preview",
45
59
  input_price=Decimal("5.00"),
46
60
  cached_input_price=Decimal("2.50"),
47
61
  output_price=Decimal("20.00"),
48
62
  ),
49
- "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
63
+ "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
64
+ display_name="GPT-4 Optimized Realtime Preview (2024-12-17)",
65
+ base_model="gpt-4o-realtime-preview",
66
+ input_price=Decimal("5.00"),
67
+ cached_input_price=Decimal("2.50"),
68
+ output_price=Decimal("20.00"),
69
+ ),
70
+ "gpt-4o-mini": OpenAIModelConfig(
50
71
  display_name="GPT-4 Optimized Mini",
51
72
  base_model="gpt-4o-mini",
52
73
  input_price=Decimal("0.15"),
53
74
  cached_input_price=Decimal("0.075"),
54
75
  output_price=Decimal("0.60"),
55
76
  ),
56
- "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
77
+ "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
78
+ display_name="GPT-4 Optimized Mini (2024-07-18)",
79
+ base_model="gpt-4o-mini",
80
+ input_price=Decimal("0.15"),
81
+ cached_input_price=Decimal("0.075"),
82
+ output_price=Decimal("0.60"),
83
+ ),
84
+ "gpt-4o-mini-audio-preview": OpenAIModelConfig(
57
85
  display_name="GPT-4 Optimized Mini Audio Preview",
58
86
  base_model="gpt-4o-mini-audio-preview",
59
87
  input_price=Decimal("0.15"),
60
88
  cached_input_price=None,
61
89
  output_price=Decimal("0.60"),
62
90
  ),
63
- "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
91
+ "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
92
+ display_name="GPT-4 Optimized Mini Audio Preview (2024-12-17)",
93
+ base_model="gpt-4o-mini-audio-preview",
94
+ input_price=Decimal("0.15"),
95
+ cached_input_price=None,
96
+ output_price=Decimal("0.60"),
97
+ ),
98
+ "gpt-4o-mini-realtime-preview": OpenAIModelConfig(
64
99
  display_name="GPT-4 Optimized Mini Realtime Preview",
65
100
  base_model="gpt-4o-mini-realtime-preview",
66
101
  input_price=Decimal("0.60"),
67
102
  cached_input_price=Decimal("0.30"),
68
103
  output_price=Decimal("2.40"),
69
104
  ),
70
- "o1-2024-12-17": OpenAIModelConfig(
105
+ "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
106
+ display_name="GPT-4 Optimized Mini Realtime Preview (2024-12-17)",
107
+ base_model="gpt-4o-mini-realtime-preview",
108
+ input_price=Decimal("0.60"),
109
+ cached_input_price=Decimal("0.30"),
110
+ output_price=Decimal("2.40"),
111
+ ),
112
+ "o1": OpenAIModelConfig(
71
113
  display_name="O1",
72
114
  base_model="o1",
73
115
  input_price=Decimal("15.00"),
74
116
  cached_input_price=Decimal("7.50"),
75
117
  output_price=Decimal("60.00"),
76
118
  ),
77
- "o3-mini-2025-01-31": OpenAIModelConfig(
119
+ "o1-2024-12-17": OpenAIModelConfig(
120
+ display_name="O1 (2024-12-17)",
121
+ base_model="o1",
122
+ input_price=Decimal("15.00"),
123
+ cached_input_price=Decimal("7.50"),
124
+ output_price=Decimal("60.00"),
125
+ ),
126
+ "o3-mini": OpenAIModelConfig(
78
127
  display_name="O3 Mini",
79
128
  base_model="o3-mini",
80
129
  input_price=Decimal("1.10"),
81
130
  cached_input_price=Decimal("0.55"),
82
131
  output_price=Decimal("4.40"),
83
132
  ),
84
- "o1-mini-2024-09-12": OpenAIModelConfig(
133
+ "o3-mini-2025-01-31": OpenAIModelConfig(
134
+ display_name="O3 Mini (2025-01-31)",
135
+ base_model="o3-mini",
136
+ input_price=Decimal("1.10"),
137
+ cached_input_price=Decimal("0.55"),
138
+ output_price=Decimal("4.40"),
139
+ ),
140
+ "o1-mini": OpenAIModelConfig(
85
141
  display_name="O1 Mini",
86
142
  base_model="o1-mini",
87
143
  input_price=Decimal("1.10"),
88
144
  cached_input_price=Decimal("0.55"),
89
145
  output_price=Decimal("4.40"),
90
146
  ),
147
+ "o1-mini-2024-09-12": OpenAIModelConfig(
148
+ display_name="O1 Mini (2024-09-12)",
149
+ base_model="o1-mini",
150
+ input_price=Decimal("1.10"),
151
+ cached_input_price=Decimal("0.55"),
152
+ output_price=Decimal("4.40"),
153
+ ),
154
+ "gpt-4o-mini-search-preview": OpenAIModelConfig(
155
+ display_name="GPT-4 Optimized Mini Search Preview",
156
+ base_model="gpt-4o-mini-search-preview",
157
+ input_price=Decimal("0.15"),
158
+ cached_input_price=None,
159
+ output_price=Decimal("0.60"),
160
+ ),
161
+ "gpt-4o-mini-search-preview-2025-03-11": OpenAIModelConfig(
162
+ display_name="GPT-4 Optimized Mini Search Preview (2025-03-11)",
163
+ base_model="gpt-4o-mini-search-preview",
164
+ input_price=Decimal("0.15"),
165
+ cached_input_price=None,
166
+ output_price=Decimal("0.60"),
167
+ ),
168
+ "gpt-4o-search-preview": OpenAIModelConfig(
169
+ display_name="GPT-4 Optimized Search Preview",
170
+ base_model="gpt-4o-search-preview",
171
+ input_price=Decimal("2.50"),
172
+ cached_input_price=None,
173
+ output_price=Decimal("10.00"),
174
+ ),
175
+ "gpt-4o-search-preview-2025-03-11": OpenAIModelConfig(
176
+ display_name="GPT-4 Optimized Search Preview (2025-03-11)",
177
+ base_model="gpt-4o-search-preview",
178
+ input_price=Decimal("2.50"),
179
+ cached_input_price=None,
180
+ output_price=Decimal("10.00"),
181
+ ),
182
+ "computer-use-preview": OpenAIModelConfig(
183
+ display_name="Computer Use Preview",
184
+ base_model="computer-use-preview",
185
+ input_price=Decimal("3.00"),
186
+ cached_input_price=None,
187
+ output_price=Decimal("12.00"),
188
+ ),
189
+ "computer-use-preview-2025-03-11": OpenAIModelConfig(
190
+ display_name="Computer Use Preview (2025-03-11)",
191
+ base_model="computer-use-preview",
192
+ input_price=Decimal("3.00"),
193
+ cached_input_price=None,
194
+ output_price=Decimal("12.00"),
195
+ ),
91
196
  }
92
197
 
93
198
 
@@ -1,7 +1,8 @@
1
- from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator
1
+ from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
2
2
  from pydantic import Field, BaseModel
3
3
  from openai import OpenAI, AsyncOpenAI
4
4
  from openai.types.chat import ChatCompletionChunk
5
+ import numpy as np
5
6
 
6
7
  from airtrain.core.skills import Skill, ProcessingError
7
8
  from airtrain.core.schemas import InputSchema, OutputSchema
@@ -232,3 +233,110 @@ class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
232
233
 
233
234
  except Exception as e:
234
235
  raise ProcessingError(f"OpenAI parsing failed: {str(e)}")
236
+
237
+
238
+ class OpenAIEmbeddingsInput(InputSchema):
239
+ """Schema for OpenAI embeddings input"""
240
+
241
+ texts: Union[str, List[str]] = Field(
242
+ ..., description="Text or list of texts to generate embeddings for"
243
+ )
244
+ model: str = Field(
245
+ default="text-embedding-3-large", description="OpenAI embeddings model to use"
246
+ )
247
+ encoding_format: str = Field(
248
+ default="float", description="The format of the embeddings: 'float' or 'base64'"
249
+ )
250
+ dimensions: Optional[int] = Field(
251
+ default=None, description="Optional number of dimensions for the embeddings"
252
+ )
253
+
254
+
255
+ class OpenAIEmbeddingsOutput(OutputSchema):
256
+ """Schema for OpenAI embeddings output"""
257
+
258
+ embeddings: List[List[float]] = Field(..., description="List of embeddings vectors")
259
+ used_model: str = Field(..., description="Model used for generating embeddings")
260
+ tokens_used: int = Field(..., description="Number of tokens used")
261
+
262
+
263
+ class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
264
+ """Skill for generating embeddings using OpenAI models"""
265
+
266
+ input_schema = OpenAIEmbeddingsInput
267
+ output_schema = OpenAIEmbeddingsOutput
268
+
269
+ def __init__(self, credentials: Optional[OpenAICredentials] = None):
270
+ """Initialize the skill with optional credentials"""
271
+ super().__init__()
272
+ self.credentials = credentials or OpenAICredentials.from_env()
273
+ self.client = OpenAI(
274
+ api_key=self.credentials.openai_api_key.get_secret_value(),
275
+ organization=self.credentials.openai_organization_id,
276
+ )
277
+ self.async_client = AsyncOpenAI(
278
+ api_key=self.credentials.openai_api_key.get_secret_value(),
279
+ organization=self.credentials.openai_organization_id,
280
+ )
281
+
282
+ def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
283
+ """Generate embeddings for the input text(s)"""
284
+ try:
285
+ # Handle single text input
286
+ texts = (
287
+ [input_data.texts]
288
+ if isinstance(input_data.texts, str)
289
+ else input_data.texts
290
+ )
291
+
292
+ # Create embeddings
293
+ response = self.client.embeddings.create(
294
+ model=input_data.model,
295
+ input=texts,
296
+ encoding_format=input_data.encoding_format,
297
+ dimensions=input_data.dimensions,
298
+ )
299
+
300
+ # Extract embeddings
301
+ embeddings = [data.embedding for data in response.data]
302
+
303
+ return OpenAIEmbeddingsOutput(
304
+ embeddings=embeddings,
305
+ used_model=response.model,
306
+ tokens_used=response.usage.total_tokens,
307
+ )
308
+ except Exception as e:
309
+ raise ProcessingError(f"OpenAI embeddings generation failed: {str(e)}")
310
+
311
+ async def process_async(
312
+ self, input_data: OpenAIEmbeddingsInput
313
+ ) -> OpenAIEmbeddingsOutput:
314
+ """Async version of the embeddings generation"""
315
+ try:
316
+ # Handle single text input
317
+ texts = (
318
+ [input_data.texts]
319
+ if isinstance(input_data.texts, str)
320
+ else input_data.texts
321
+ )
322
+
323
+ # Create embeddings
324
+ response = await self.async_client.embeddings.create(
325
+ model=input_data.model,
326
+ input=texts,
327
+ encoding_format=input_data.encoding_format,
328
+ dimensions=input_data.dimensions,
329
+ )
330
+
331
+ # Extract embeddings
332
+ embeddings = [data.embedding for data in response.data]
333
+
334
+ return OpenAIEmbeddingsOutput(
335
+ embeddings=embeddings,
336
+ used_model=response.model,
337
+ tokens_used=response.usage.total_tokens,
338
+ )
339
+ except Exception as e:
340
+ raise ProcessingError(
341
+ f"OpenAI async embeddings generation failed: {str(e)}"
342
+ )
@@ -2,5 +2,18 @@
2
2
 
3
3
  from .credentials import TogetherAICredentials
4
4
  from .skills import TogetherAIChatSkill
5
+ from .list_models import (
6
+ TogetherListModelsSkill,
7
+ TogetherListModelsInput,
8
+ TogetherListModelsOutput,
9
+ )
10
+ from .models import TogetherModel
5
11
 
6
- __all__ = ["TogetherAICredentials", "TogetherAIChatSkill"]
12
+ __all__ = [
13
+ "TogetherAICredentials",
14
+ "TogetherAIChatSkill",
15
+ "TogetherListModelsSkill",
16
+ "TogetherListModelsInput",
17
+ "TogetherListModelsOutput",
18
+ "TogetherModel",
19
+ ]
@@ -0,0 +1,77 @@
1
+ from typing import Optional
2
+ import requests
3
+ from pydantic import Field
4
+
5
+ from airtrain.core.skills import Skill, ProcessingError
6
+ from airtrain.core.schemas import InputSchema, OutputSchema
7
+ from .credentials import TogetherAICredentials
8
+ from .models import TogetherModel
9
+
10
+
11
+ class TogetherListModelsInput(InputSchema):
12
+ """Schema for Together AI list models input"""
13
+ pass
14
+
15
+
16
+ class TogetherListModelsOutput(OutputSchema):
17
+ """Schema for Together AI list models output"""
18
+
19
+ data: list[TogetherModel] = Field(
20
+ default_factory=list,
21
+ description="List of Together AI models"
22
+ )
23
+ object: Optional[str] = Field(
24
+ default=None,
25
+ description="Object type"
26
+ )
27
+
28
+
29
+ class TogetherListModelsSkill(Skill[TogetherListModelsInput, TogetherListModelsOutput]):
30
+ """Skill for listing Together AI models"""
31
+
32
+ input_schema = TogetherListModelsInput
33
+ output_schema = TogetherListModelsOutput
34
+
35
+ def __init__(self, credentials: Optional[TogetherAICredentials] = None):
36
+ """Initialize the skill with optional credentials"""
37
+ super().__init__()
38
+ self.credentials = credentials or TogetherAICredentials.from_env()
39
+ self.base_url = "https://api.together.xyz/v1"
40
+
41
+ def process(
42
+ self, input_data: TogetherListModelsInput
43
+ ) -> TogetherListModelsOutput:
44
+ """Process the input and return a list of models."""
45
+ try:
46
+ # Build the URL
47
+ url = f"{self.base_url}/models"
48
+
49
+ # Make the request
50
+ headers = {
51
+ "Authorization": (
52
+ f"Bearer {self.credentials.together_api_key.get_secret_value()}"
53
+ ),
54
+ "accept": "application/json"
55
+ }
56
+
57
+ response = requests.get(url, headers=headers)
58
+ response.raise_for_status()
59
+
60
+ # Parse the response
61
+ result = response.json()
62
+
63
+ # Convert the models to TogetherModel objects
64
+ models = []
65
+ for model_data in result.get("data", []):
66
+ models.append(TogetherModel(**model_data))
67
+
68
+ # Return the output
69
+ return TogetherListModelsOutput(
70
+ data=models,
71
+ object=result.get("object")
72
+ )
73
+
74
+ except requests.RequestException as e:
75
+ raise ProcessingError(f"Failed to list Together AI models: {str(e)}")
76
+ except Exception as e:
77
+ raise ProcessingError(f"Error listing Together AI models: {str(e)}")