airtrain 0.1.53__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +61 -2
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/skills.py +102 -0
- airtrain/integrations/combined/list_models_factory.py +9 -3
- airtrain/integrations/groq/__init__.py +18 -1
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +93 -17
- airtrain/integrations/together/__init__.py +15 -1
- airtrain/integrations/together/models_config.py +123 -1
- airtrain/integrations/together/skills.py +117 -20
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +41 -0
- airtrain/tools/command.py +211 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/METADATA +37 -1
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/RECORD +31 -13
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/WHEEL +1 -1
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,162 @@
+"""Configuration of Groq model capabilities."""
+
+from typing import Dict, Any
+
+
+# Model configuration with capabilities for each model
+GROQ_MODELS_CONFIG = {
+    "llama-3.3-70b-versatile": {
+        "name": "Llama 3.3 70B Versatile",
+        "context_window": 128000,
+        "max_completion_tokens": 32768,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "llama-3.1-8b-instant": {
+        "name": "Llama 3.1 8B Instant",
+        "context_window": 128000,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "mixtral-8x7b-32768": {
+        "name": "Mixtral 8x7B (32K)",
+        "context_window": 32768,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "parallel_tool_use": False,
+        "json_mode": True,
+    },
+    "gemma2-9b-it": {
+        "name": "Gemma 2 9B IT",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "parallel_tool_use": False,
+        "json_mode": True,
+    },
+    "qwen-qwq-32b": {
+        "name": "Qwen QWQ 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "qwen-2.5-coder-32b": {
+        "name": "Qwen 2.5 Coder 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "qwen-2.5-32b": {
+        "name": "Qwen 2.5 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-qwen-32b": {
+        "name": "DeepSeek R1 Distill Qwen 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-llama-70b": {
+        "name": "DeepSeek R1 Distill Llama 70B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-llama-70b-specdec": {
+        "name": "DeepSeek R1 Distill Llama 70B SpecDec",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+    "llama3-70b-8192": {
+        "name": "Llama 3 70B (8K)",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+    "llama3-8b-8192": {
+        "name": "Llama 3 8B (8K)",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+}
+
+
+def get_model_config(model_id: str) -> Dict[str, Any]:
+    """
+    Get the configuration for a specific model.
+
+    Args:
+        model_id: The model ID to get configuration for
+
+    Returns:
+        Dict with model configuration
+
+    Raises:
+        ValueError: If model_id is not found in configuration
+    """
+    if model_id in GROQ_MODELS_CONFIG:
+        return GROQ_MODELS_CONFIG[model_id]
+
+    # Try to find a match with different format or case
+    normalized_id = model_id.lower().replace("-", "").replace("_", "")
+    for config_id, config in GROQ_MODELS_CONFIG.items():
+        if normalized_id == config_id.lower().replace("-", "").replace("_", ""):
+            return config
+
+    # Default configuration for unknown models
+    return {
+        "name": model_id,
+        "context_window": 4096,  # Conservative default
+        "max_completion_tokens": 1024,  # Conservative default
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    }
+
+
+def get_default_model() -> str:
+    """Get the default model ID for Groq."""
+    return "llama-3.3-70b-versatile"
+
+
+def supports_tool_use(model_id: str) -> bool:
+    """Check if a model supports tool use."""
+    return get_model_config(model_id).get("tool_use", False)
+
+
+def supports_parallel_tool_use(model_id: str) -> bool:
+    """Check if a model supports parallel tool use."""
+    return get_model_config(model_id).get("parallel_tool_use", False)
+
+
+def supports_json_mode(model_id: str) -> bool:
+    """Check if a model supports JSON mode."""
+    return get_model_config(model_id).get("json_mode", False)
+
+
+def get_max_completion_tokens(model_id: str) -> int:
+    """Get the maximum number of completion tokens for a model."""
+    return get_model_config(model_id).get("max_completion_tokens", 1024)
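Per the file summary above, this new 162-line module appears to be airtrain/integrations/groq/models_config.py. A minimal usage sketch of the helpers it defines follows; the calls are illustrative only and the import path is taken from the file list:

```python
# Illustrative only -- assumes the module lands at this import path.
from airtrain.integrations.groq.models_config import (
    get_model_config,
    get_max_completion_tokens,
    supports_tool_use,
)

config = get_model_config("llama-3.3-70b-versatile")
print(config["context_window"])                               # 128000
print(get_max_completion_tokens("llama-3.3-70b-versatile"))   # 32768
print(supports_tool_use("llama3-8b-8192"))                    # False

# Unknown IDs fall back to a conservative default instead of raising.
print(get_model_config("some-new-model")["max_completion_tokens"])  # 1024
```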
@@ -1,8 +1,9 @@
-from typing import Generator, Optional, Dict, Any, List
-from pydantic import Field
+from typing import Generator, Optional, Dict, Any, List, Union
+from pydantic import Field, validator
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import GroqCredentials
+from .models_config import get_max_completion_tokens, get_model_config
 from groq import Groq
 
 
@@ -12,22 +13,59 @@ class GroqInput(InputSchema):
     user_input: str = Field(..., description="User's input text")
     system_prompt: str = Field(
         default="You are a helpful assistant.",
-        description=
+        description=(
+            "System prompt to guide the model's behavior"
+        ),
     )
     conversation_history: List[Dict[str, str]] = Field(
         default_factory=list,
-        description=
+        description=(
+            "List of previous conversation messages in "
+            "[{'role': 'user|assistant', 'content': 'message'}] format"
+        ),
     )
     model: str = Field(
-        default="
+        default="llama-3.3-70b-versatile",
+        description="Groq model to use"
+    )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum tokens in response"
     )
-    max_tokens: int = Field(default=131072, description="Maximum tokens in response")
     temperature: float = Field(
-        default=0.7,
+        default=0.7,
+        description="Temperature for response generation",
+        ge=0,
+        le=1
     )
     stream: bool = Field(
-        default=False,
+        default=False,
+        description="Whether to stream the response progressively"
+    )
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "A list of tools the model may use. "
+            "Currently only functions supported."
+        ),
+    )
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Controls which tool is called by the model. "
+            "'none', 'auto', or specific tool."
+        ),
     )
+
+    @validator('max_tokens')
+    def validate_max_tokens(cls, v, values):
+        """Validate that max_tokens doesn't exceed the model's limit."""
+        if 'model' in values:
+            model_id = values['model']
+            max_limit = get_max_completion_tokens(model_id)
+            if v > max_limit:
+                return max_limit
+        return v
 
 
 class GroqOutput(OutputSchema):
@@ -38,6 +76,9 @@ class GroqOutput(OutputSchema):
     usage: Dict[str, Any] = Field(
         default_factory=dict, description="Usage statistics from the API"
     )
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None, description="Tool calls generated by the model"
+    )
 
 
 class GroqChatSkill(Skill[GroqInput, GroqOutput]):
@@ -101,24 +142,59 @@ class GroqChatSkill(Skill[GroqInput, GroqOutput]):
                     response_chunks.append(chunk)
                 response = "".join(response_chunks)
                 usage = {}  # Usage stats not available in streaming
+                tool_calls = None  # Tool calls not available in streaming
             else:
                 messages = self._build_messages(input_data)
-
-
-
-
-
-
-
-
+
+                # Prepare API call parameters
+                api_params = {
+                    "model": input_data.model,
+                    "messages": messages,
+                    "temperature": input_data.temperature,
+                    "max_tokens": input_data.max_tokens,
+                    "stream": False,
+                }
+
+                # Add tools and tool_choice if provided
+                if input_data.tools:
+                    api_params["tools"] = input_data.tools
+
+                if input_data.tool_choice:
+                    api_params["tool_choice"] = input_data.tool_choice
+
+                completion = self.client.chat.completions.create(**api_params)
+                response = completion.choices[0].message.content or ""
+
+                # Extract usage information
                 usage = {
                     "total_tokens": completion.usage.total_tokens,
                     "prompt_tokens": completion.usage.prompt_tokens,
                     "completion_tokens": completion.usage.completion_tokens,
                 }
+
+                # Check for tool calls in the response
+                tool_calls = None
+                if (
+                    hasattr(completion.choices[0].message, "tool_calls")
+                    and completion.choices[0].message.tool_calls
+                ):
+                    tool_calls = [
+                        {
+                            "id": tool_call.id,
+                            "type": tool_call.type,
+                            "function": {
+                                "name": tool_call.function.name,
+                                "arguments": tool_call.function.arguments
+                            }
+                        }
+                        for tool_call in completion.choices[0].message.tool_calls
+                    ]
 
             return GroqOutput(
-                response=response,
+                response=response,
+                used_model=input_data.model,
+                usage=usage,
+                tool_calls=tool_calls
             )
 
         except Exception as e:
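These skills.py hunks add tool calling to the Groq chat skill and clamp max_tokens to the model's limit via the validator. A sketch of how the extended interface might be driven follows; it assumes the skill exposes a process() entry point from airtrain.core.skills (not shown in this diff) and that GROQ_API_KEY is picked up by GroqCredentials, and the weather tool is a made-up example:

```python
# Illustrative sketch; the weather tool below is hypothetical.
from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput

skill = GroqChatSkill()  # assumes GROQ_API_KEY is available to GroqCredentials

input_data = GroqInput(
    user_input="What's the weather in Paris?",
    model="llama-3.3-70b-versatile",
    max_tokens=200000,  # validator clamps this to the model's 32768 limit
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice="auto",
)

output = skill.process(input_data)  # process() name assumed from the Skill base class
if output.tool_calls:
    for call in output.tool_calls:
        print(call["function"]["name"], call["function"]["arguments"])
else:
    print(output.response)
```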
@@ -1,7 +1,14 @@
 """Together AI integration module"""
 
 from .credentials import TogetherAICredentials
-from .skills import TogetherAIChatSkill
+from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
+from .models_config import (
+    get_model_config_with_capabilities,
+    get_max_completion_tokens,
+    supports_tool_use,
+    supports_json_mode,
+    TOGETHER_MODELS_CONFIG,
+)
 from .list_models import (
     TogetherListModelsSkill,
     TogetherListModelsInput,
@@ -12,8 +19,15 @@ from .models import TogetherModel
 __all__ = [
     "TogetherAICredentials",
     "TogetherAIChatSkill",
+    "TogetherAIInput",
+    "TogetherAIOutput",
     "TogetherListModelsSkill",
     "TogetherListModelsInput",
     "TogetherListModelsOutput",
     "TogetherModel",
+    "get_model_config_with_capabilities",
+    "get_max_completion_tokens",
+    "supports_tool_use",
+    "supports_json_mode",
+    "TOGETHER_MODELS_CONFIG",
 ]
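With these re-exports, the capability helpers should be importable directly from the integration package. A small sketch, assuming the 0.1.58 layout shown above:

```python
# Illustrative; relies on the re-exports added in the hunk above.
from airtrain.integrations.together import (
    TOGETHER_MODELS_CONFIG,
    supports_tool_use,
    get_max_completion_tokens,
)

for model_id in TOGETHER_MODELS_CONFIG:
    print(model_id, supports_tool_use(model_id), get_max_completion_tokens(model_id))
```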
@@ -1,4 +1,4 @@
-from typing import Dict, NamedTuple
+from typing import Dict, NamedTuple, Any
 
 
 class ModelConfig(NamedTuple):
@@ -272,6 +272,128 @@ def get_default_model() -> str:
     return "meta-llama/Llama-3.3-70B-Instruct-Turbo"
 
 
+# Model configuration with capabilities for each model
+TOGETHER_MODELS_CONFIG = {
+    "meta-llama/Llama-3.1-8B-Instruct": {
+        "name": "Llama 3.1 8B Instruct",
+        "context_window": 128000,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "meta-llama/Llama-3.1-70B-Instruct": {
+        "name": "Llama 3.1 70B Instruct",
+        "context_window": 128000,
+        "max_completion_tokens": 32768,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": {
+        "name": "Mixtral 8x7B Instruct v0.1",
+        "context_window": 32768,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "meta-llama/Meta-Llama-3-8B-Instruct": {
+        "name": "Meta Llama 3 8B Instruct",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "meta-llama/Meta-Llama-3-70B-Instruct": {
+        "name": "Meta Llama 3 70B Instruct",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-ai/DeepSeek-Coder-V2": {
+        "name": "DeepSeek Coder V2",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-ai/DeepSeek-V2": {
+        "name": "DeepSeek V2",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-ai/DeepSeek-R1": {
+        "name": "DeepSeek R1",
+        "context_window": 32768,
+        "max_completion_tokens": 8192,
+        "tool_use": False,
+        "json_mode": False,
+    },
+    # Qwen models
+    "Qwen/Qwen2.5-72B-Instruct-Turbo": {
+        "context_window": 128000,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "json_mode": True,
+    },
+    "Qwen/Qwen2.5-7B-Instruct": {
+        "context_window": 32768,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "json_mode": True,
+    },
+}
+
+
+def get_model_config_with_capabilities(model_id: str) -> Dict[str, Any]:
+    """
+    Get the configuration for a specific model.
+
+    Args:
+        model_id: The model ID to get configuration for
+
+    Returns:
+        Dict with model configuration
+
+    Raises:
+        ValueError: If model_id is not found in configuration
+    """
+    if model_id in TOGETHER_MODELS_CONFIG:
+        return TOGETHER_MODELS_CONFIG[model_id]
+
+    # Try to find a match with different format or case
+    normalized_id = model_id.lower().replace("-", "").replace("_", "").replace("/", "")
+    for config_id, config in TOGETHER_MODELS_CONFIG.items():
+        norm_config_id = config_id.lower().replace("-", "").replace("_", "").replace("/", "")
+        if normalized_id == norm_config_id:
+            return config
+
+    # Default configuration for unknown models
+    return {
+        "name": model_id,
+        "context_window": 4096,  # Conservative default
+        "max_completion_tokens": 1024,  # Conservative default
+        "tool_use": False,
+        "json_mode": False,
+    }
+
+
+def supports_tool_use(model_id: str) -> bool:
+    """Check if a model supports tool use."""
+    return get_model_config_with_capabilities(model_id).get("tool_use", False)
+
+
+def supports_json_mode(model_id: str) -> bool:
+    """Check if a model supports JSON mode."""
+    return get_model_config_with_capabilities(model_id).get("json_mode", False)
+
+
+def get_max_completion_tokens(model_id: str) -> int:
+    """Get the maximum number of completion tokens for a model."""
+    return get_model_config_with_capabilities(model_id).get("max_completion_tokens", 1024)
+
+
 if __name__ == "__main__":
     print(len(TOGETHER_MODELS))
     print(get_model_config("meta-llama/Llama-3.3-70B-Instruct-Turbo"))
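A brief sketch of the lookup behavior added above: exact IDs hit the table directly, while the normalized comparison lets loosely formatted IDs resolve to the same entry. The calls are illustrative only; the import path follows the file list:

```python
# Illustrative; assumes airtrain/integrations/together/models_config.py as listed above.
from airtrain.integrations.together.models_config import (
    get_model_config_with_capabilities,
    supports_json_mode,
)

# Exact ID
cfg = get_model_config_with_capabilities("deepseek-ai/DeepSeek-R1")
print(cfg["context_window"])  # 32768

# Case/separator-insensitive lookup resolves to the same entry
cfg2 = get_model_config_with_capabilities("deepseek-ai/deepseek_r1")
print(cfg2 is cfg)  # True

print(supports_json_mode("deepseek-ai/DeepSeek-R1"))  # False
```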
@@ -1,9 +1,10 @@
-from typing import Optional, Dict, Any, List, Generator
-from pydantic import Field
+from typing import Optional, Dict, Any, List, Generator, Union
+from pydantic import Field, validator
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import TogetherAICredentials
 from .models import TogetherAIImageInput, TogetherAIImageOutput, GeneratedImage
+from .models_config import get_max_completion_tokens
 from pathlib import Path
 import base64
 import time
@@ -20,16 +21,53 @@ class TogetherAIInput(InputSchema):
     )
     conversation_history: List[Dict[str, str]] = Field(
         default_factory=list,
-        description=
+        description=(
+            "List of previous conversation messages in "
+            "[{'role': 'user|assistant', 'content': 'message'}] format"
+        ),
     )
     model: str = Field(
-        default="
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        description="Together AI model to use"
+    )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum tokens in response"
     )
-    max_tokens: int = Field(default=1024, description="Maximum tokens in response")
     temperature: float = Field(
-        default=0.7,
+        default=0.7,
+        description="Temperature for response generation",
+        ge=0,
+        le=1
+    )
+    stream: bool = Field(
+        default=False,
+        description="Whether to stream the response"
+    )
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "A list of tools the model may use. "
+            "Currently only functions supported."
+        ),
     )
-
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Controls which tool is called by the model. "
+            "'none', 'auto', or specific tool."
+        ),
+    )
+
+    @validator('max_tokens')
+    def validate_max_tokens(cls, v, values):
+        """Validate that max_tokens doesn't exceed the model's limit."""
+        if 'model' in values:
+            model_id = values['model']
+            max_limit = get_max_completion_tokens(model_id)
+            if v > max_limit:
+                return max_limit
+        return v
 
 
 class TogetherAIOutput(OutputSchema):
@@ -38,6 +76,9 @@ class TogetherAIOutput(OutputSchema):
     response: str = Field(..., description="Model's response text")
     used_model: str = Field(..., description="Model used for generation")
     usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None, description="Tool calls generated by the model"
+    )
 
 
 class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
@@ -102,27 +143,83 @@ class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
                 for chunk in self.process_stream(input_data):
                     response_chunks.append(chunk)
                 response = "".join(response_chunks)
+                usage = {}  # Usage stats not available in streaming
+                tool_calls = None  # Tool calls not available in streaming
             else:
                 messages = self._build_messages(input_data)
-
-
-
-
-
-
-
-
+
+                # Prepare API call parameters
+                api_params = {
+                    "model": input_data.model,
+                    "messages": messages,
+                    "temperature": input_data.temperature,
+                    "max_tokens": input_data.max_tokens,
+                    "stream": False,
+                }
+
+                # Add tools and tool_choice if provided
+                if input_data.tools:
+                    api_params["tools"] = input_data.tools
+
+                if input_data.tool_choice:
+                    api_params["tool_choice"] = input_data.tool_choice
+
+                try:
+                    completion = self.client.chat.completions.create(**api_params)
+                    response = completion.choices[0].message.content or ""
+
+                    # Extract usage information
+                    usage = (
+                        completion.usage.model_dump()
+                        if hasattr(completion, "usage")
+                        else {}
+                    )
+
+                    # Check for tool calls in the response
+                    tool_calls = None
+                    if (
+                        hasattr(completion.choices[0].message, "tool_calls")
+                        and completion.choices[0].message.tool_calls
+                    ):
+                        tool_calls = [
+                            {
+                                "id": tool_call.id,
+                                "type": tool_call.type,
+                                "function": {
+                                    "name": tool_call.function.name,
+                                    "arguments": tool_call.function.arguments
+                                }
+                            }
+                            for tool_call in completion.choices[0].message.tool_calls
+                        ]
+                except Exception as api_error:
+                    # Provide more specific error messages for common tool call issues
+                    error_message = str(api_error)
+                    if "tool_call_id" in error_message:
+                        raise ProcessingError(
+                            "Tool call error: Missing tool_call_id in conversation history. "
+                            "When adding tool responses to conversation history, "
+                            "make sure each message with role='tool' includes a 'tool_call_id' field."
+                        )
+                    elif "messages" in error_message and "tool" in error_message:
+                        raise ProcessingError(
+                            "Tool call error: Invalid message format in conversation history. "
+                            f"Original error: {error_message}"
+                        )
+                    else:
+                        # Re-raise with original error message
+                        raise ProcessingError(f"Together AI API error: {error_message}")
 
             return TogetherAIOutput(
                 response=response,
                 used_model=input_data.model,
-                usage=
-
-                if hasattr(completion, "usage")
-                else {}
-                ),
+                usage=usage,
+                tool_calls=tool_calls
             )
 
+        except ProcessingError:
+            # Re-raise ProcessingError without modification
+            raise
         except Exception as e:
            raise ProcessingError(f"Together AI processing failed: {str(e)}")
 
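The Together hunk above surfaces tool-call failures as specific ProcessingError messages. A minimal sketch of how those errors would reach a caller; the skill construction and process() entry point are assumed from the existing airtrain API, and whether the provider's error text actually contains "tool_call_id" is an assumption:

```python
# Minimal sketch, assuming TOGETHER_API_KEY is available to the credentials
# and that the skill exposes a process() entry point.
from airtrain.core.skills import ProcessingError
from airtrain.integrations.together.skills import TogetherAIChatSkill, TogetherAIInput

skill = TogetherAIChatSkill()

# A role='tool' message without 'tool_call_id' should trip the specific
# "Missing tool_call_id" ProcessingError raised in the hunk above,
# provided the provider's error message mentions tool_call_id.
bad_history = [
    {"role": "user", "content": "What is 21 * 2?"},
    {"role": "tool", "content": "42"},  # missing tool_call_id
]

try:
    skill.process(TogetherAIInput(
        user_input="So what is the answer?",
        conversation_history=bad_history,
        model="meta-llama/Llama-3.1-8B-Instruct",
    ))
except ProcessingError as err:
    print(err)  # e.g. "Tool call error: Missing tool_call_id in conversation history. ..."
```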