airtrain 0.1.53__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (31)
  1. airtrain/__init__.py +61 -2
  2. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  3. airtrain/agents/__init__.py +45 -0
  4. airtrain/agents/example_agent.py +348 -0
  5. airtrain/agents/groq_agent.py +289 -0
  6. airtrain/agents/memory.py +663 -0
  7. airtrain/agents/registry.py +465 -0
  8. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  9. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  10. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  11. airtrain/core/skills.py +102 -0
  12. airtrain/integrations/combined/list_models_factory.py +9 -3
  13. airtrain/integrations/groq/__init__.py +18 -1
  14. airtrain/integrations/groq/models_config.py +162 -0
  15. airtrain/integrations/groq/skills.py +93 -17
  16. airtrain/integrations/together/__init__.py +15 -1
  17. airtrain/integrations/together/models_config.py +123 -1
  18. airtrain/integrations/together/skills.py +117 -20
  19. airtrain/telemetry/__init__.py +38 -0
  20. airtrain/telemetry/service.py +167 -0
  21. airtrain/telemetry/views.py +237 -0
  22. airtrain/tools/__init__.py +41 -0
  23. airtrain/tools/command.py +211 -0
  24. airtrain/tools/filesystem.py +166 -0
  25. airtrain/tools/network.py +111 -0
  26. airtrain/tools/registry.py +320 -0
  27. {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/METADATA +37 -1
  28. {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/RECORD +31 -13
  29. {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/WHEEL +1 -1
  30. {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/entry_points.txt +0 -0
  31. {airtrain-0.1.53.dist-info → airtrain-0.1.58.dist-info}/top_level.txt +0 -0
airtrain/integrations/groq/models_config.py
@@ -0,0 +1,162 @@
+ """Configuration of Groq model capabilities."""
+
+ from typing import Dict, Any
+
+
+ # Model configuration with capabilities for each model
+ GROQ_MODELS_CONFIG = {
+     "llama-3.3-70b-versatile": {
+         "name": "Llama 3.3 70B Versatile",
+         "context_window": 128000,
+         "max_completion_tokens": 32768,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "llama-3.1-8b-instant": {
+         "name": "Llama 3.1 8B Instant",
+         "context_window": 128000,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "mixtral-8x7b-32768": {
+         "name": "Mixtral 8x7B (32K)",
+         "context_window": 32768,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "parallel_tool_use": False,
+         "json_mode": True,
+     },
+     "gemma2-9b-it": {
+         "name": "Gemma 2 9B IT",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "parallel_tool_use": False,
+         "json_mode": True,
+     },
+     "qwen-qwq-32b": {
+         "name": "Qwen QWQ 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "qwen-2.5-coder-32b": {
+         "name": "Qwen 2.5 Coder 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "qwen-2.5-32b": {
+         "name": "Qwen 2.5 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-qwen-32b": {
+         "name": "DeepSeek R1 Distill Qwen 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-llama-70b": {
+         "name": "DeepSeek R1 Distill Llama 70B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-llama-70b-specdec": {
+         "name": "DeepSeek R1 Distill Llama 70B SpecDec",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+     "llama3-70b-8192": {
+         "name": "Llama 3 70B (8K)",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+     "llama3-8b-8192": {
+         "name": "Llama 3 8B (8K)",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+ }
+
+
+ def get_model_config(model_id: str) -> Dict[str, Any]:
+     """
+     Get the configuration for a specific model.
+
+     Args:
+         model_id: The model ID to get configuration for
+
+     Returns:
+         Dict with model configuration
+
+     Note:
+         Unknown model IDs fall back to a conservative default configuration
+     """
+     if model_id in GROQ_MODELS_CONFIG:
+         return GROQ_MODELS_CONFIG[model_id]
+
+     # Try to find a match with a different format or case
+     normalized_id = model_id.lower().replace("-", "").replace("_", "")
+     for config_id, config in GROQ_MODELS_CONFIG.items():
+         if normalized_id == config_id.lower().replace("-", "").replace("_", ""):
+             return config
+
+     # Default configuration for unknown models
+     return {
+         "name": model_id,
+         "context_window": 4096,  # Conservative default
+         "max_completion_tokens": 1024,  # Conservative default
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     }
+
+
+ def get_default_model() -> str:
+     """Get the default model ID for Groq."""
+     return "llama-3.3-70b-versatile"
+
+
+ def supports_tool_use(model_id: str) -> bool:
+     """Check if a model supports tool use."""
+     return get_model_config(model_id).get("tool_use", False)
+
+
+ def supports_parallel_tool_use(model_id: str) -> bool:
+     """Check if a model supports parallel tool use."""
+     return get_model_config(model_id).get("parallel_tool_use", False)
+
+
+ def supports_json_mode(model_id: str) -> bool:
+     """Check if a model supports JSON mode."""
+     return get_model_config(model_id).get("json_mode", False)
+
+
+ def get_max_completion_tokens(model_id: str) -> int:
+     """Get the maximum number of completion tokens for a model."""
+     return get_model_config(model_id).get("max_completion_tokens", 1024)
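To make the new helpers concrete, here is a minimal usage sketch (not part of the diff). It assumes only the functions added above and the import path implied by the file list.

```python
# Usage sketch for the new Groq capability helpers (illustrative only).
from airtrain.integrations.groq.models_config import (
    get_default_model,
    get_max_completion_tokens,
    get_model_config,
    supports_tool_use,
)

model_id = get_default_model()                       # "llama-3.3-70b-versatile"
print(get_model_config(model_id)["context_window"])  # 128000
print(get_max_completion_tokens(model_id))           # 32768

# Lookup normalizes case and separators, so variant spellings still match:
print(supports_tool_use("Llama_3.3_70B_Versatile"))  # True

# Unknown IDs fall back to the conservative defaults:
print(get_max_completion_tokens("some-unknown-model"))  # 1024
```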
airtrain/integrations/groq/skills.py
@@ -1,8 +1,9 @@
- from typing import Generator, Optional, Dict, Any, List
- from pydantic import Field
+ from typing import Generator, Optional, Dict, Any, List, Union
+ from pydantic import Field, validator
  from airtrain.core.skills import Skill, ProcessingError
  from airtrain.core.schemas import InputSchema, OutputSchema
  from .credentials import GroqCredentials
+ from .models_config import get_max_completion_tokens, get_model_config
  from groq import Groq
@@ -12,22 +13,59 @@ class GroqInput(InputSchema):
      user_input: str = Field(..., description="User's input text")
      system_prompt: str = Field(
          default="You are a helpful assistant.",
-         description="System prompt to guide the model's behavior",
+         description=(
+             "System prompt to guide the model's behavior"
+         ),
      )
      conversation_history: List[Dict[str, str]] = Field(
          default_factory=list,
-         description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+         description=(
+             "List of previous conversation messages in "
+             "[{'role': 'user|assistant', 'content': 'message'}] format"
+         ),
      )
      model: str = Field(
-         default="deepseek-r1-distill-llama-70b-specdec", description="Groq model to use"
+         default="llama-3.3-70b-versatile",
+         description="Groq model to use"
+     )
+     max_tokens: int = Field(
+         default=4096,
+         description="Maximum tokens in response"
      )
-     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
      temperature: float = Field(
-         default=0.7, description="Temperature for response generation", ge=0, le=1
+         default=0.7,
+         description="Temperature for response generation",
+         ge=0,
+         le=1
      )
      stream: bool = Field(
-         default=False, description="Whether to stream the response progressively"
+         default=False,
+         description="Whether to stream the response progressively"
+     )
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "A list of tools the model may use. "
+             "Currently only functions supported."
+         ),
+     )
+     tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Controls which tool is called by the model. "
+             "'none', 'auto', or specific tool."
+         ),
      )
+
+     @validator('max_tokens')
+     def validate_max_tokens(cls, v, values):
+         """Validate that max_tokens doesn't exceed the model's limit."""
+         if 'model' in values:
+             model_id = values['model']
+             max_limit = get_max_completion_tokens(model_id)
+             if v > max_limit:
+                 return max_limit
+         return v


  class GroqOutput(OutputSchema):
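The net effect of the new validator is that an over-large `max_tokens` is silently clamped to the model's published completion limit rather than rejected. A small sketch of the intended behavior (assuming pydantic v1-style validators, which the `validator` import implies):

```python
from airtrain.integrations.groq.skills import GroqInput

# 50,000 exceeds llama-3.1-8b-instant's 8192-token completion limit,
# so the validator clamps it instead of raising.
inp = GroqInput(user_input="Hello", model="llama-3.1-8b-instant", max_tokens=50000)
assert inp.max_tokens == 8192

# Within the limit, the value passes through unchanged.
inp = GroqInput(user_input="Hello", model="llama-3.1-8b-instant", max_tokens=2048)
assert inp.max_tokens == 2048
```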
@@ -38,6 +76,9 @@ class GroqOutput(OutputSchema):
      usage: Dict[str, Any] = Field(
          default_factory=dict, description="Usage statistics from the API"
      )
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None, description="Tool calls generated by the model"
+     )


  class GroqChatSkill(Skill[GroqInput, GroqOutput]):
@@ -101,24 +142,59 @@ class GroqChatSkill(Skill[GroqInput, GroqOutput]):
                      response_chunks.append(chunk)
                  response = "".join(response_chunks)
                  usage = {}  # Usage stats not available in streaming
+                 tool_calls = None  # Tool calls not available in streaming
              else:
                  messages = self._build_messages(input_data)
-                 completion = self.client.chat.completions.create(
-                     model=input_data.model,
-                     messages=messages,
-                     temperature=input_data.temperature,
-                     max_tokens=input_data.max_tokens,
-                     stream=False,
-                 )
-                 response = completion.choices[0].message.content
+
+                 # Prepare API call parameters
+                 api_params = {
+                     "model": input_data.model,
+                     "messages": messages,
+                     "temperature": input_data.temperature,
+                     "max_tokens": input_data.max_tokens,
+                     "stream": False,
+                 }
+
+                 # Add tools and tool_choice if provided
+                 if input_data.tools:
+                     api_params["tools"] = input_data.tools
+
+                 if input_data.tool_choice:
+                     api_params["tool_choice"] = input_data.tool_choice
+
+                 completion = self.client.chat.completions.create(**api_params)
+                 response = completion.choices[0].message.content or ""
+
+                 # Extract usage information
                  usage = {
                      "total_tokens": completion.usage.total_tokens,
                      "prompt_tokens": completion.usage.prompt_tokens,
                      "completion_tokens": completion.usage.completion_tokens,
                  }
+
+                 # Check for tool calls in the response
+                 tool_calls = None
+                 if (
+                     hasattr(completion.choices[0].message, "tool_calls")
+                     and completion.choices[0].message.tool_calls
+                 ):
+                     tool_calls = [
+                         {
+                             "id": tool_call.id,
+                             "type": tool_call.type,
+                             "function": {
+                                 "name": tool_call.function.name,
+                                 "arguments": tool_call.function.arguments
+                             }
+                         }
+                         for tool_call in completion.choices[0].message.tool_calls
+                     ]

              return GroqOutput(
-                 response=response, used_model=input_data.model, usage=usage
+                 response=response,
+                 used_model=input_data.model,
+                 usage=usage,
+                 tool_calls=tool_calls
              )

          except Exception as e:
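Putting the Groq changes together, a tool-calling round trip might look like the sketch below. The tool definition and the no-argument `GroqChatSkill()` construction are assumptions (credentials are presumably resolved via `GroqCredentials` from the environment); the field names come from the schemas above.

```python
from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput

# Hypothetical function tool in the OpenAI-style format the `tools` field expects.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

skill = GroqChatSkill()  # assumes credentials are picked up from the environment
result = skill.process(
    GroqInput(
        user_input="What's the weather in Paris?",
        tools=[weather_tool],
        tool_choice="auto",
    )
)

if result.tool_calls:  # the model chose to call a tool
    for call in result.tool_calls:
        print(call["function"]["name"], call["function"]["arguments"])
else:  # plain text answer
    print(result.response)
```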
airtrain/integrations/together/__init__.py
@@ -1,7 +1,14 @@
  """Together AI integration module"""

  from .credentials import TogetherAICredentials
- from .skills import TogetherAIChatSkill
+ from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
+ from .models_config import (
+     get_model_config_with_capabilities,
+     get_max_completion_tokens,
+     supports_tool_use,
+     supports_json_mode,
+     TOGETHER_MODELS_CONFIG,
+ )
  from .list_models import (
      TogetherListModelsSkill,
      TogetherListModelsInput,
@@ -12,8 +19,15 @@ from .models import TogetherModel
  __all__ = [
      "TogetherAICredentials",
      "TogetherAIChatSkill",
+     "TogetherAIInput",
+     "TogetherAIOutput",
      "TogetherListModelsSkill",
      "TogetherListModelsInput",
      "TogetherListModelsOutput",
      "TogetherModel",
+     "get_model_config_with_capabilities",
+     "get_max_completion_tokens",
+     "supports_tool_use",
+     "supports_json_mode",
+     "TOGETHER_MODELS_CONFIG",
  ]
airtrain/integrations/together/models_config.py
@@ -1,4 +1,4 @@
- from typing import Dict, NamedTuple
+ from typing import Dict, NamedTuple, Any


  class ModelConfig(NamedTuple):
@@ -272,6 +272,128 @@ def get_default_model() -> str:
      return "meta-llama/Llama-3.3-70B-Instruct-Turbo"


+ # Model configuration with capabilities for each model
+ TOGETHER_MODELS_CONFIG = {
+     "meta-llama/Llama-3.1-8B-Instruct": {
+         "name": "Llama 3.1 8B Instruct",
+         "context_window": 128000,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "meta-llama/Llama-3.1-70B-Instruct": {
+         "name": "Llama 3.1 70B Instruct",
+         "context_window": 128000,
+         "max_completion_tokens": 32768,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "mistralai/Mixtral-8x7B-Instruct-v0.1": {
+         "name": "Mixtral 8x7B Instruct v0.1",
+         "context_window": 32768,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "meta-llama/Meta-Llama-3-8B-Instruct": {
+         "name": "Meta Llama 3 8B Instruct",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "meta-llama/Meta-Llama-3-70B-Instruct": {
+         "name": "Meta Llama 3 70B Instruct",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-ai/DeepSeek-Coder-V2": {
+         "name": "DeepSeek Coder V2",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-ai/DeepSeek-V2": {
+         "name": "DeepSeek V2",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-ai/DeepSeek-R1": {
+         "name": "DeepSeek R1",
+         "context_window": 32768,
+         "max_completion_tokens": 8192,
+         "tool_use": False,
+         "json_mode": False,
+     },
+     # Qwen models
+     "Qwen/Qwen2.5-72B-Instruct-Turbo": {
+         "context_window": 128000,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "json_mode": True,
+     },
+     "Qwen/Qwen2.5-7B-Instruct": {
+         "context_window": 32768,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "json_mode": True,
+     },
+ }
+
+
+ def get_model_config_with_capabilities(model_id: str) -> Dict[str, Any]:
+     """
+     Get the configuration for a specific model.
+
+     Args:
+         model_id: The model ID to get configuration for
+
+     Returns:
+         Dict with model configuration
+
+     Note:
+         Unknown model IDs fall back to a conservative default configuration
+     """
+     if model_id in TOGETHER_MODELS_CONFIG:
+         return TOGETHER_MODELS_CONFIG[model_id]
+
+     # Try to find a match with a different format or case
+     normalized_id = model_id.lower().replace("-", "").replace("_", "").replace("/", "")
+     for config_id, config in TOGETHER_MODELS_CONFIG.items():
+         norm_config_id = config_id.lower().replace("-", "").replace("_", "").replace("/", "")
+         if normalized_id == norm_config_id:
+             return config
+
+     # Default configuration for unknown models
+     return {
+         "name": model_id,
+         "context_window": 4096,  # Conservative default
+         "max_completion_tokens": 1024,  # Conservative default
+         "tool_use": False,
+         "json_mode": False,
+     }
+
+
+ def supports_tool_use(model_id: str) -> bool:
+     """Check if a model supports tool use."""
+     return get_model_config_with_capabilities(model_id).get("tool_use", False)
+
+
+ def supports_json_mode(model_id: str) -> bool:
+     """Check if a model supports JSON mode."""
+     return get_model_config_with_capabilities(model_id).get("json_mode", False)
+
+
+ def get_max_completion_tokens(model_id: str) -> int:
+     """Get the maximum number of completion tokens for a model."""
+     return get_model_config_with_capabilities(model_id).get("max_completion_tokens", 1024)
+
+
  if __name__ == "__main__":
      print(len(TOGETHER_MODELS))
      print(get_model_config("meta-llama/Llama-3.3-70B-Instruct-Turbo"))
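A usage sketch for the Together variant, not part of the diff. Note the extra `/` stripping in the normalization, so provider-prefixed IDs match even when flattened.

```python
from airtrain.integrations.together.models_config import (
    get_max_completion_tokens,
    get_model_config_with_capabilities,
    supports_tool_use,
)

cfg = get_model_config_with_capabilities("meta-llama/Llama-3.1-8B-Instruct")
print(cfg["max_completion_tokens"])  # 8192

# Normalization strips "/", "-", "_" and lowercases, so a flattened ID resolves:
print(supports_tool_use("meta-llama-llama-3.1-8b-instruct"))  # True

# Unknown IDs get the conservative fallback config:
print(get_max_completion_tokens("someorg/unknown-model"))  # 1024
```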
airtrain/integrations/together/skills.py
@@ -1,9 +1,10 @@
- from typing import Optional, Dict, Any, List, Generator
- from pydantic import Field
+ from typing import Optional, Dict, Any, List, Generator, Union
+ from pydantic import Field, validator
  from airtrain.core.skills import Skill, ProcessingError
  from airtrain.core.schemas import InputSchema, OutputSchema
  from .credentials import TogetherAICredentials
  from .models import TogetherAIImageInput, TogetherAIImageOutput, GeneratedImage
+ from .models_config import get_max_completion_tokens
  from pathlib import Path
  import base64
  import time
@@ -20,16 +21,53 @@ class TogetherAIInput(InputSchema):
      )
      conversation_history: List[Dict[str, str]] = Field(
          default_factory=list,
-         description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+         description=(
+             "List of previous conversation messages in "
+             "[{'role': 'user|assistant', 'content': 'message'}] format"
+         ),
      )
      model: str = Field(
-         default="deepseek-ai/DeepSeek-R1", description="Together AI model to use"
+         default="meta-llama/Llama-3.1-8B-Instruct",
+         description="Together AI model to use"
+     )
+     max_tokens: int = Field(
+         default=4096,
+         description="Maximum tokens in response"
      )
-     max_tokens: int = Field(default=1024, description="Maximum tokens in response")
      temperature: float = Field(
-         default=0.7, description="Temperature for response generation", ge=0, le=1
+         default=0.7,
+         description="Temperature for response generation",
+         ge=0,
+         le=1
+     )
+     stream: bool = Field(
+         default=False,
+         description="Whether to stream the response"
+     )
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "A list of tools the model may use. "
+             "Currently only functions supported."
+         ),
      )
-     stream: bool = Field(default=False, description="Whether to stream the response")
+     tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Controls which tool is called by the model. "
+             "'none', 'auto', or specific tool."
+         ),
+     )
+
+     @validator('max_tokens')
+     def validate_max_tokens(cls, v, values):
+         """Validate that max_tokens doesn't exceed the model's limit."""
+         if 'model' in values:
+             model_id = values['model']
+             max_limit = get_max_completion_tokens(model_id)
+             if v > max_limit:
+                 return max_limit
+         return v


  class TogetherAIOutput(OutputSchema):
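The Together schema clamps `max_tokens` the same way as the Groq one, and `tool_choice` additionally accepts a dict form. A short sketch (again assuming pydantic v1 validator semantics; the `get_weather` name is illustrative):

```python
from airtrain.integrations.together.skills import TogetherAIInput

# Clamped: Qwen/Qwen2.5-7B-Instruct advertises a 4096-token completion limit.
inp = TogetherAIInput(
    user_input="Hi", model="Qwen/Qwen2.5-7B-Instruct", max_tokens=99999
)
assert inp.max_tokens == 4096

# tool_choice also accepts a dict to force one specific (hypothetical) function:
inp = TogetherAIInput(
    user_input="Hi",
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
```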
@@ -38,6 +76,9 @@ class TogetherAIOutput(OutputSchema):
      response: str = Field(..., description="Model's response text")
      used_model: str = Field(..., description="Model used for generation")
      usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None, description="Tool calls generated by the model"
+     )


  class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
@@ -102,27 +143,83 @@ class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
                  for chunk in self.process_stream(input_data):
                      response_chunks.append(chunk)
                  response = "".join(response_chunks)
+                 usage = {}  # Usage stats not available in streaming
+                 tool_calls = None  # Tool calls not available in streaming
              else:
                  messages = self._build_messages(input_data)
-                 completion = self.client.chat.completions.create(
-                     model=input_data.model,
-                     messages=messages,
-                     temperature=input_data.temperature,
-                     max_tokens=input_data.max_tokens,
-                     stream=False,
-                 )
-                 response = completion.choices[0].message.content
+
+                 # Prepare API call parameters
+                 api_params = {
+                     "model": input_data.model,
+                     "messages": messages,
+                     "temperature": input_data.temperature,
+                     "max_tokens": input_data.max_tokens,
+                     "stream": False,
+                 }
+
+                 # Add tools and tool_choice if provided
+                 if input_data.tools:
+                     api_params["tools"] = input_data.tools
+
+                 if input_data.tool_choice:
+                     api_params["tool_choice"] = input_data.tool_choice
+
+                 try:
+                     completion = self.client.chat.completions.create(**api_params)
+                     response = completion.choices[0].message.content or ""
+
+                     # Extract usage information
+                     usage = (
+                         completion.usage.model_dump()
+                         if hasattr(completion, "usage")
+                         else {}
+                     )
+
+                     # Check for tool calls in the response
+                     tool_calls = None
+                     if (
+                         hasattr(completion.choices[0].message, "tool_calls")
+                         and completion.choices[0].message.tool_calls
+                     ):
+                         tool_calls = [
+                             {
+                                 "id": tool_call.id,
+                                 "type": tool_call.type,
+                                 "function": {
+                                     "name": tool_call.function.name,
+                                     "arguments": tool_call.function.arguments
+                                 }
+                             }
+                             for tool_call in completion.choices[0].message.tool_calls
+                         ]
+                 except Exception as api_error:
+                     # Provide more specific error messages for common tool call issues
+                     error_message = str(api_error)
+                     if "tool_call_id" in error_message:
+                         raise ProcessingError(
+                             "Tool call error: Missing tool_call_id in conversation history. "
+                             "When adding tool responses to conversation history, "
+                             "make sure each message with role='tool' includes a 'tool_call_id' field."
+                         )
+                     elif "messages" in error_message and "tool" in error_message:
+                         raise ProcessingError(
+                             "Tool call error: Invalid message format in conversation history. "
+                             f"Original error: {error_message}"
+                         )
+                     else:
+                         # Re-raise with the original error message
+                         raise ProcessingError(f"Together AI API error: {error_message}")

              return TogetherAIOutput(
                  response=response,
                  used_model=input_data.model,
-                 usage=(
-                     completion.usage.model_dump()
-                     if hasattr(completion, "usage")
-                     else {}
-                 ),
+                 usage=usage,
+                 tool_calls=tool_calls
              )

+         except ProcessingError:
+             # Re-raise ProcessingError without modification
+             raise
          except Exception as e:
              raise ProcessingError(f"Together AI processing failed: {str(e)}")
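The new error branches fire when tool-result messages in `conversation_history` are malformed. A sketch of a well-formed follow-up turn; the message shapes are assumptions inferred from the error text, and the ID is illustrative:

```python
# Shape of a conversation history that feeds a tool result back to the model.
# The "tool" message must carry the tool_call_id from the assistant's tool call;
# otherwise TogetherAIChatSkill now raises the specific ProcessingError above.
history = [
    {"role": "user", "content": "What's the weather in Paris?"},
    # assistant turn that requested the tool (as surfaced in tool_calls)
    {"role": "assistant", "content": "Calling get_weather for Paris."},
    {
        "role": "tool",
        "tool_call_id": "call_abc123",  # required; illustrative ID
        "content": '{"temp_c": 18, "conditions": "cloudy"}',
    },
]
```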