airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,399 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Any
|
2
|
+
|
3
|
+
|
4
|
+
class ModelConfig(NamedTuple):
    """Static metadata describing a single Together AI chat model."""

    organization: str  # vendor/organization that publishes the model
    display_name: str  # human-readable model name for UIs/logs
    context_length: int  # maximum context window, in tokens
    quantization: str  # serving precision, e.g. "FP8", "FP16", "INT4"
|
9
|
+
|
10
|
+
|
11
|
+
# Registry of Together AI chat models, keyed by the provider's model ID.
# Each entry records the publishing organization, a display name, the
# context window (in tokens) and the serving precision.
TOGETHER_MODELS: Dict[str, ModelConfig] = {
    # DeepSeek Models
    "deepseek-ai/DeepSeek-R1": ModelConfig(
        organization="DeepSeek",
        display_name="DeepSeek-R1",
        context_length=131072,
        quantization="FP8",
    ),
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": ModelConfig(
        organization="DeepSeek",
        display_name="DeepSeek R1 Distill Llama 70B",
        context_length=131072,
        quantization="FP16",
    ),
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": ModelConfig(
        organization="DeepSeek",
        display_name="DeepSeek R1 Distill Qwen 1.5B",
        context_length=131072,
        quantization="FP16",
    ),
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": ModelConfig(
        organization="DeepSeek",
        display_name="DeepSeek R1 Distill Qwen 14B",
        context_length=131072,
        quantization="FP16",
    ),
    "deepseek-ai/DeepSeek-V3": ModelConfig(
        organization="DeepSeek",
        display_name="DeepSeek-V3",
        context_length=131072,
        quantization="FP8",
    ),
    # Meta Models
    "meta-llama/Llama-3.3-70B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3.3 70B Instruct Turbo",
        context_length=131072,
        quantization="FP8",
    ),
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3.1 8B Instruct Turbo",
        context_length=131072,
        quantization="FP8",
    ),
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3.1 70B Instruct Turbo",
        context_length=131072,
        quantization="FP8",
    ),
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3.1 405B Instruct Turbo",
        # NOTE(review): 130815 (not the usual 131072) is copied from the
        # upstream data — presumably the provider reserves some tokens.
        context_length=130815,
        quantization="FP8",
    ),
    "meta-llama/Meta-Llama-3-8B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3 8B Instruct Turbo",
        context_length=8192,
        quantization="FP8",
    ),
    "meta-llama/Meta-Llama-3-70B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3 70B Instruct Turbo",
        context_length=8192,
        quantization="FP8",
    ),
    "meta-llama/Llama-3.2-3B-Instruct-Turbo": ModelConfig(
        organization="Meta",
        display_name="Llama 3.2 3B Instruct Turbo",
        context_length=131072,
        quantization="FP16",
    ),
    "meta-llama/Meta-Llama-3-8B-Instruct-Lite": ModelConfig(
        organization="Meta",
        display_name="Llama 3 8B Instruct Lite",
        context_length=8192,
        quantization="INT4",
    ),
    "meta-llama/Meta-Llama-3-70B-Instruct-Lite": ModelConfig(
        organization="Meta",
        display_name="Llama 3 70B Instruct Lite",
        context_length=8192,
        quantization="INT4",
    ),
    "meta-llama/Llama-3-8b-chat-hf": ModelConfig(
        organization="Meta",
        display_name="Llama 3 8B Instruct Reference",
        context_length=8192,
        quantization="FP16",
    ),
    "meta-llama/Llama-3-70b-chat-hf": ModelConfig(
        organization="Meta",
        display_name="Llama 3 70B Instruct Reference",
        context_length=8192,
        quantization="FP16",
    ),
    "meta-llama/Llama-2-13b-chat-hf": ModelConfig(
        organization="Meta",
        display_name="LLaMA-2 Chat (13B)",
        context_length=4096,
        quantization="FP16",
    ),
    # Nvidia Models
    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": ModelConfig(
        organization="Nvidia",
        display_name="Llama 3.1 Nemotron 70B",
        context_length=32768,
        quantization="FP16",
    ),
    # Qwen Models
    "Qwen/Qwen2.5-Coder-32B-Instruct": ModelConfig(
        organization="Qwen",
        display_name="Qwen 2.5 Coder 32B Instruct",
        context_length=32768,
        quantization="FP16",
    ),
    "Qwen/QwQ-32B-Preview": ModelConfig(
        organization="Qwen",
        display_name="QwQ-32B-Preview",
        context_length=32768,
        quantization="FP16",
    ),
    "Qwen/Qwen2.5-7B-Instruct-Turbo": ModelConfig(
        organization="Qwen",
        display_name="Qwen 2.5 7B Instruct Turbo",
        context_length=32768,
        quantization="FP8",
    ),
    "Qwen/Qwen2.5-72B-Instruct-Turbo": ModelConfig(
        organization="Qwen",
        display_name="Qwen 2.5 72B Instruct Turbo",
        context_length=32768,
        quantization="FP8",
    ),
    "Qwen/Qwen2-72B-Instruct": ModelConfig(
        organization="Qwen",
        display_name="Qwen 2 Instruct (72B)",
        context_length=32768,
        quantization="FP16",
    ),
    "Qwen/Qwen2-VL-72B-Instruct": ModelConfig(
        organization="Qwen",
        display_name="Qwen2 VL 72B Instruct",
        context_length=32768,
        quantization="FP16",
    ),
    # Microsoft Models
    "microsoft/WizardLM-2-8x22B": ModelConfig(
        organization="Microsoft",
        display_name="WizardLM-2 8x22B",
        context_length=65536,
        quantization="FP16",
    ),
    # Google Models
    "google/gemma-2-27b-it": ModelConfig(
        organization="Google",
        display_name="Gemma 2 27B",
        context_length=8192,
        quantization="FP16",
    ),
    "google/gemma-2-9b-it": ModelConfig(
        organization="Google",
        display_name="Gemma 2 9B",
        context_length=8192,
        quantization="FP16",
    ),
    "google/gemma-2b-it": ModelConfig(
        organization="Google",
        display_name="Gemma Instruct (2B)",
        context_length=8192,
        quantization="FP16",
    ),
    # Databricks Models
    "databricks/dbrx-instruct": ModelConfig(
        organization="databricks",
        display_name="DBRX Instruct",
        context_length=32768,
        quantization="FP16",
    ),
    # Gryphe Models
    "Gryphe/MythoMax-L2-13b": ModelConfig(
        organization="Gryphe",
        display_name="MythoMax-L2 (13B)",
        context_length=4096,
        quantization="FP16",
    ),
    # Mistral AI Models
    "mistralai/Mistral-Small-24B-Instruct-2501": ModelConfig(
        organization="mistralai",
        display_name="Mistral Small 3 Instruct (24B)",
        context_length=32768,
        quantization="FP16",
    ),
    "mistralai/Mistral-7B-Instruct-v0.1": ModelConfig(
        organization="mistralai",
        display_name="Mistral (7B) Instruct",
        context_length=8192,
        quantization="FP16",
    ),
    "mistralai/Mistral-7B-Instruct-v0.2": ModelConfig(
        organization="mistralai",
        display_name="Mistral (7B) Instruct v0.2",
        context_length=32768,
        quantization="FP16",
    ),
    "mistralai/Mistral-7B-Instruct-v0.3": ModelConfig(
        organization="mistralai",
        display_name="Mistral (7B) Instruct v0.3",
        context_length=32768,
        quantization="FP16",
    ),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": ModelConfig(
        organization="mistralai",
        display_name="Mixtral-8x7B Instruct (46.7B)",
        context_length=32768,
        quantization="FP16",
    ),
    "mistralai/Mixtral-8x22B-Instruct-v0.1": ModelConfig(
        organization="mistralai",
        display_name="Mixtral-8x22B Instruct (141B)",
        context_length=65536,
        quantization="FP16",
    ),
    # NousResearch Models
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": ModelConfig(
        organization="NousResearch",
        display_name="Nous Hermes 2 - Mixtral 8x7B-DPO (46.7B)",
        context_length=32768,
        quantization="FP16",
    ),
    # Upstage Models
    "upstage/SOLAR-10.7B-Instruct-v1.0": ModelConfig(
        organization="upstage",
        display_name="Upstage SOLAR Instruct v1 (11B)",
        context_length=4096,
        quantization="FP16",
    ),
}
|
252
|
+
|
253
|
+
|
254
|
+
def get_model_config(model_id: str) -> ModelConfig:
    """Look up the :class:`ModelConfig` registered for ``model_id``.

    Raises:
        ValueError: If the ID is not a known Together AI model.
    """
    config = TOGETHER_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI models")
    return config
|
259
|
+
|
260
|
+
|
261
|
+
def list_models_by_organization(organization: str) -> Dict[str, ModelConfig]:
    """Return every registered model whose organization matches, case-insensitively."""
    wanted = organization.lower()
    matches: Dict[str, ModelConfig] = {}
    for model_id, config in TOGETHER_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
268
|
+
|
269
|
+
|
270
|
+
def get_default_model() -> str:
    """Return the model ID used when a caller does not pick one explicitly."""
    default_model_id = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    return default_model_id
|
273
|
+
|
274
|
+
|
275
|
+
# Model configuration with capabilities for each model
|
276
|
+
# Capability table used by the supports_* / get_max_completion_tokens helpers.
# Keys are model IDs; values describe context window, completion limit and
# whether the model advertises tool use and JSON mode.
TOGETHER_MODELS_CONFIG = {
    "meta-llama/Llama-3.1-8B-Instruct": {
        "name": "Llama 3.1 8B Instruct",
        "context_window": 128000,
        "max_completion_tokens": 8192,
        "tool_use": True,
        "json_mode": True,
    },
    "meta-llama/Llama-3.1-70B-Instruct": {
        "name": "Llama 3.1 70B Instruct",
        "context_window": 128000,
        "max_completion_tokens": 32768,
        "tool_use": True,
        "json_mode": True,
    },
    "mistralai/Mixtral-8x7B-Instruct-v0.1": {
        "name": "Mixtral 8x7B Instruct v0.1",
        "context_window": 32768,
        "max_completion_tokens": 8192,
        "tool_use": True,
        "json_mode": True,
    },
    "meta-llama/Meta-Llama-3-8B-Instruct": {
        "name": "Meta Llama 3 8B Instruct",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "tool_use": True,
        "json_mode": True,
    },
    "meta-llama/Meta-Llama-3-70B-Instruct": {
        "name": "Meta Llama 3 70B Instruct",
        "context_window": 8192,
        "max_completion_tokens": 4096,
        "tool_use": True,
        "json_mode": True,
    },
    "deepseek-ai/DeepSeek-Coder-V2": {
        "name": "DeepSeek Coder V2",
        "context_window": 128000,
        "max_completion_tokens": 16384,
        "tool_use": True,
        "json_mode": True,
    },
    "deepseek-ai/DeepSeek-V2": {
        "name": "DeepSeek V2",
        "context_window": 128000,
        "max_completion_tokens": 16384,
        "tool_use": True,
        "json_mode": True,
    },
    "deepseek-ai/DeepSeek-R1": {
        "name": "DeepSeek R1",
        "context_window": 32768,
        "max_completion_tokens": 8192,
        "tool_use": False,
        "json_mode": False,
    },
    # Qwen models
    # NOTE(review): these two entries omit the "name" key the other entries
    # carry — callers reading config["name"] would KeyError; confirm intent.
    "Qwen/Qwen2.5-72B-Instruct-Turbo": {
        "context_window": 128000,
        "max_completion_tokens": 4096,
        "tool_use": True,
        "json_mode": True,
    },
    "Qwen/Qwen2.5-7B-Instruct": {
        "context_window": 32768,
        "max_completion_tokens": 4096,
        "tool_use": True,
        "json_mode": True,
    },
}
|
347
|
+
|
348
|
+
|
349
|
+
def _normalize_model_id(model_id: str) -> str:
    """Lowercase *model_id* and strip '-', '_' and '/' for fuzzy comparison."""
    return model_id.lower().replace("-", "").replace("_", "").replace("/", "")


def get_model_config_with_capabilities(model_id: str) -> Dict[str, Any]:
    """
    Get the capability configuration for a specific model.

    Resolution order: exact key match in ``TOGETHER_MODELS_CONFIG``, then a
    normalized (case- and separator-insensitive) match, and finally a
    conservative default for unknown models.  This function never raises —
    unknown IDs yield the default dict with ``tool_use``/``json_mode`` False.

    Args:
        model_id: The model ID to get configuration for

    Returns:
        Dict with model configuration (context_window,
        max_completion_tokens, tool_use, json_mode, and usually name)
    """
    if model_id in TOGETHER_MODELS_CONFIG:
        return TOGETHER_MODELS_CONFIG[model_id]

    # Fall back to a match that ignores case and '-'/'_'/'/' differences.
    normalized_id = _normalize_model_id(model_id)
    for config_id, config in TOGETHER_MODELS_CONFIG.items():
        if _normalize_model_id(config_id) == normalized_id:
            return config

    # Default configuration for unknown models.
    return {
        "name": model_id,
        "context_window": 4096,  # conservative default
        "max_completion_tokens": 1024,  # conservative default
        "tool_use": False,
        "json_mode": False,
    }
|
380
|
+
|
381
|
+
|
382
|
+
def supports_tool_use(model_id: str) -> bool:
    """Return True when the model's configuration advertises tool use."""
    capabilities = get_model_config_with_capabilities(model_id)
    return capabilities.get("tool_use", False)
|
385
|
+
|
386
|
+
|
387
|
+
def supports_json_mode(model_id: str) -> bool:
    """Return True when the model's configuration advertises JSON mode."""
    capabilities = get_model_config_with_capabilities(model_id)
    return capabilities.get("json_mode", False)
|
390
|
+
|
391
|
+
|
392
|
+
def get_max_completion_tokens(model_id: str) -> int:
    """Return the model's completion-token limit (1024 when not configured)."""
    capabilities = get_model_config_with_capabilities(model_id)
    return capabilities.get("max_completion_tokens", 1024)
|
395
|
+
|
396
|
+
|
397
|
+
if __name__ == "__main__":
    # Smoke check when run as a script: catalogue size plus one known entry.
    total_models = len(TOGETHER_MODELS)
    print(total_models)
    turbo_config = get_model_config("meta-llama/Llama-3.3-70B-Instruct-Turbo")
    print(turbo_config)
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
class RerankModelConfig(NamedTuple):
    """Static metadata describing a single Together AI rerank model."""

    organization: str  # vendor/organization that publishes the model
    display_name: str  # human-readable model name
    model_size: str  # parameter count label, e.g. "8B"
    max_doc_size: int  # maximum size of a single document the model accepts
    max_docs: int  # maximum number of documents per rerank request
|
10
|
+
|
11
|
+
|
12
|
+
# Registry of Together AI rerank models, keyed by the provider's model ID.
TOGETHER_RERANK_MODELS: Dict[str, RerankModelConfig] = {
    "Salesforce/Llama-Rank-v1": RerankModelConfig(
        organization="Salesforce",
        display_name="LlamaRank",
        model_size="8B",
        max_doc_size=8192,
        max_docs=1024,
    )
}
|
21
|
+
|
22
|
+
|
23
|
+
def get_rerank_model_config(model_id: str) -> RerankModelConfig:
    """Look up the :class:`RerankModelConfig` registered for ``model_id``.

    Raises:
        ValueError: If the ID is not a known Together AI rerank model.
    """
    config = TOGETHER_RERANK_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI rerank models")
    return config
|
28
|
+
|
29
|
+
|
30
|
+
def list_rerank_models_by_organization(
    organization: str,
) -> Dict[str, RerankModelConfig]:
    """Return every registered rerank model for *organization*, case-insensitively."""
    wanted = organization.lower()
    matches: Dict[str, RerankModelConfig] = {}
    for model_id, config in TOGETHER_RERANK_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
39
|
+
|
40
|
+
|
41
|
+
def get_default_rerank_model() -> str:
    """Return the rerank model ID used when a caller does not pick one."""
    default_rerank_id = "Salesforce/Llama-Rank-v1"
    return default_rerank_id
|
@@ -0,0 +1,49 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
from together import Together
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from .credentials import TogetherAICredentials
|
5
|
+
from .schemas import TogetherAIRerankInput, TogetherAIRerankOutput, RerankResult
|
6
|
+
from .rerank_models_config import get_rerank_model_config
|
7
|
+
|
8
|
+
|
9
|
+
class TogetherAIRerankSkill(Skill[TogetherAIRerankInput, TogetherAIRerankOutput]):
    """Skill for reranking documents using Together AI"""

    input_schema = TogetherAIRerankInput
    output_schema = TogetherAIRerankOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Together AI credentials; when omitted they are
                loaded from the environment via
                ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    def process(self, input_data: TogetherAIRerankInput) -> TogetherAIRerankOutput:
        """Rerank ``input_data.documents`` against ``input_data.query``.

        Returns:
            TogetherAIRerankOutput with one RerankResult per ranking returned
            by the API, plus the model that was used.

        Raises:
            ProcessingError: If model validation or the rerank API call fails.
        """
        try:
            # Fail fast if the model is unknown to our local rerank config.
            get_rerank_model_config(input_data.model)

            # Call the Together AI rerank API.
            response = self.client.rerank.create(
                model=input_data.model,
                query=input_data.query,
                documents=input_data.documents,
                top_n=input_data.top_n,
            )

            # Map API results back onto the original document texts.
            results = [
                RerankResult(
                    index=result.index,
                    relevance_score=result.relevance_score,
                    document=input_data.documents[result.index],
                )
                for result in response.results
            ]

            return TogetherAIRerankOutput(results=results, used_model=input_data.model)

        except Exception as e:
            # Chain the original exception so the root cause stays visible
            # in tracebacks (the original `raise` discarded it).
            raise ProcessingError(f"Together AI reranking failed: {str(e)}") from e
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from typing import List, Optional
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
4
|
+
|
5
|
+
|
6
|
+
class RerankResult(BaseModel):
    """Schema for individual rerank result"""

    # One (document, relevance) pair as returned by the rerank API.
    index: int = Field(..., description="Index of the document in original list")
    relevance_score: float = Field(..., description="Relevance score for the document")
    document: str = Field(..., description="The document content")
|
12
|
+
|
13
|
+
|
14
|
+
class TogetherAIRerankInput(InputSchema):
    """Schema for Together AI rerank input"""

    query: str = Field(..., description="Query to rank documents against")
    documents: List[str] = Field(..., description="List of documents to rank")
    # Defaults to the only rerank model currently registered in
    # rerank_models_config (Salesforce/Llama-Rank-v1).
    model: str = Field(
        default="Salesforce/Llama-Rank-v1",
        description="Together AI rerank model to use",
    )
    top_n: Optional[int] = Field(
        default=None,
        description="Number of top results to return. If None, returns all results",
    )
|
27
|
+
|
28
|
+
|
29
|
+
class TogetherAIRerankOutput(OutputSchema):
    """Schema for Together AI rerank output"""

    results: List[RerankResult] = Field(..., description="Ranked results")
    used_model: str = Field(..., description="Model used for ranking")
|