airtrain 0.1.39__py3-none-any.whl → 0.1.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/integrations/__init__.py +10 -1
- airtrain/integrations/anthropic/__init__.py +12 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/fireworks/__init__.py +10 -0
- airtrain/integrations/fireworks/credentials.py +10 -2
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +112 -0
- airtrain/integrations/fireworks/skills.py +62 -11
- airtrain/integrations/fireworks/structured_completion_skills.py +10 -4
- airtrain/integrations/fireworks/structured_requests_skills.py +108 -31
- airtrain/integrations/openai/__init__.py +6 -0
- airtrain/integrations/openai/models_config.py +118 -13
- airtrain/integrations/openai/skills.py +109 -1
- airtrain/integrations/together/__init__.py +14 -1
- airtrain/integrations/together/list_models.py +77 -0
- airtrain/integrations/together/models.py +42 -3
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/METADATA +1 -1
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/RECORD +22 -19
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/WHEEL +1 -1
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.39.dist-info → airtrain-0.1.41.dist-info}/top_level.txt +0 -0
airtrain/__init__.py CHANGED
@@ -22,6 +22,10 @@ from .ollama.skills import OllamaChatSkill
 from .sambanova.skills import SambanovaChatSkill
 from .cerebras.skills import CerebrasChatSkill
 
+# Model configurations
+from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
+from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
+
 __all__ = [
     # Credentials
     "OpenAICredentials",
@@ -38,10 +42,15 @@ __all__ = [
     "OpenAIParserSkill",
     "AnthropicChatSkill",
     "AWSBedrockSkill",
-    "
+    "GoogleChatSkill",
     "GroqChatSkill",
     "TogetherAIChatSkill",
     "OllamaChatSkill",
     "SambanovaChatSkill",
     "CerebrasChatSkill",
+    # Model configurations
+    "OPENAI_MODELS",
+    "OpenAIModelConfig",
+    "ANTHROPIC_MODELS",
+    "AnthropicModelConfig",
 ]
airtrain/integrations/anthropic/__init__.py CHANGED
@@ -2,10 +2,22 @@
 
 from .credentials import AnthropicCredentials
 from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
+from .models_config import (
+    ANTHROPIC_MODELS,
+    AnthropicModelConfig,
+    get_model_config,
+    get_default_model,
+    calculate_cost,
+)
 
 __all__ = [
     "AnthropicCredentials",
     "AnthropicChatSkill",
     "AnthropicInput",
     "AnthropicOutput",
+    "ANTHROPIC_MODELS",
+    "AnthropicModelConfig",
+    "get_model_config",
+    "get_default_model",
+    "calculate_cost",
 ]
airtrain/integrations/anthropic/models_config.py ADDED
@@ -0,0 +1,100 @@
+from typing import Dict, NamedTuple, Optional
+from decimal import Decimal
+
+
+class AnthropicModelConfig(NamedTuple):
+    display_name: str
+    base_model: str
+    input_price: Decimal
+    cached_write_price: Optional[Decimal]
+    cached_read_price: Optional[Decimal]
+    output_price: Decimal
+
+
+ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
+    "claude-3-7-sonnet": AnthropicModelConfig(
+        display_name="Claude 3.7 Sonnet",
+        base_model="claude-3-7-sonnet",
+        input_price=Decimal("3.00"),
+        cached_write_price=Decimal("3.75"),
+        cached_read_price=Decimal("0.30"),
+        output_price=Decimal("15.00"),
+    ),
+    "claude-3-5-haiku": AnthropicModelConfig(
+        display_name="Claude 3.5 Haiku",
+        base_model="claude-3-5-haiku",
+        input_price=Decimal("0.80"),
+        cached_write_price=Decimal("1.00"),
+        cached_read_price=Decimal("0.08"),
+        output_price=Decimal("4.00"),
+    ),
+    "claude-3-opus": AnthropicModelConfig(
+        display_name="Claude 3 Opus",
+        base_model="claude-3-opus",
+        input_price=Decimal("15.00"),
+        cached_write_price=Decimal("18.75"),
+        cached_read_price=Decimal("1.50"),
+        output_price=Decimal("75.00"),
+    ),
+    "claude-3-sonnet": AnthropicModelConfig(
+        display_name="Claude 3 Sonnet",
+        base_model="claude-3-sonnet",
+        input_price=Decimal("3.00"),
+        cached_write_price=Decimal("3.75"),
+        cached_read_price=Decimal("0.30"),
+        output_price=Decimal("15.00"),
+    ),
+    "claude-3-haiku": AnthropicModelConfig(
+        display_name="Claude 3 Haiku",
+        base_model="claude-3-haiku",
+        input_price=Decimal("0.25"),
+        cached_write_price=Decimal("0.31"),
+        cached_read_price=Decimal("0.025"),
+        output_price=Decimal("1.25"),
+    ),
+}
+
+
+def get_model_config(model_id: str) -> AnthropicModelConfig:
+    """Get model configuration by model ID"""
+    if model_id not in ANTHROPIC_MODELS:
+        raise ValueError(f"Model {model_id} not found in Anthropic models")
+    return ANTHROPIC_MODELS[model_id]
+
+
+def get_default_model() -> str:
+    """Get the default model ID"""
+    return "claude-3-sonnet"
+
+
+def calculate_cost(
+    model_id: str,
+    input_tokens: int,
+    output_tokens: int,
+    use_cached: bool = False,
+    cache_type: str = "read"
+) -> Decimal:
+    """Calculate cost for token usage
+
+    Args:
+        model_id: The model ID to calculate costs for
+        input_tokens: Number of input tokens
+        output_tokens: Number of output tokens
+        use_cached: Whether to use cached pricing
+        cache_type: Either "read" or "write" for cached pricing type
+    """
+    config = get_model_config(model_id)
+
+    if not use_cached:
+        input_cost = config.input_price * Decimal(str(input_tokens))
+    else:
+        if cache_type == "read" and config.cached_read_price is not None:
+            input_cost = config.cached_read_price * Decimal(str(input_tokens))
+        elif cache_type == "write" and config.cached_write_price is not None:
+            input_cost = config.cached_write_price * Decimal(str(input_tokens))
+        else:
+            input_cost = config.input_price * Decimal(str(input_tokens))
+
+    output_cost = config.output_price * Decimal(str(output_tokens))
+
+    return (input_cost + output_cost) / Decimal("1000")
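For orientation, a minimal usage sketch of the new pricing helpers (not part of the diff; the import path simply mirrors the file location above, and note that calculate_cost divides the token-weighted prices by 1,000):

# Hypothetical usage sketch of the new Anthropic pricing helpers.
from decimal import Decimal

from airtrain.integrations.anthropic.models_config import (
    calculate_cost,
    get_default_model,
    get_model_config,
)

config = get_model_config("claude-3-5-haiku")
print(config.display_name, config.input_price, config.output_price)

# Cost for 1,200 input tokens and 300 output tokens, using cached-read pricing
# where the model defines it; falls back to the regular input price otherwise.
cost = calculate_cost(
    model_id=get_default_model(),  # "claude-3-sonnet"
    input_tokens=1200,
    output_tokens=300,
    use_cached=True,
    cache_type="read",
)
assert isinstance(cost, Decimal)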
airtrain/integrations/fireworks/__init__.py CHANGED
@@ -2,10 +2,20 @@
 
 from .credentials import FireworksCredentials
 from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
+from .list_models import (
+    FireworksListModelsSkill,
+    FireworksListModelsInput,
+    FireworksListModelsOutput,
+)
+from .models import FireworksModel
 
 __all__ = [
     "FireworksCredentials",
     "FireworksChatSkill",
     "FireworksInput",
     "FireworksOutput",
+    "FireworksListModelsSkill",
+    "FireworksListModelsInput",
+    "FireworksListModelsOutput",
+    "FireworksModel",
 ]
airtrain/integrations/fireworks/credentials.py CHANGED
@@ -1,4 +1,4 @@
-from pydantic import SecretStr, BaseModel
+from pydantic import SecretStr, BaseModel, Field
 from typing import Optional
 import os
 
@@ -6,7 +6,15 @@ import os
 class FireworksCredentials(BaseModel):
     """Credentials for Fireworks AI API"""
 
-    fireworks_api_key: SecretStr
+    fireworks_api_key: SecretStr = Field(..., min_length=1)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the credentials."""
+        return f"FireworksCredentials(fireworks_api_key=SecretStr('**********'))"
+
+    def __str__(self) -> str:
+        """Return a string representation of the credentials."""
+        return self.__repr__()
 
     @classmethod
     def from_env(cls) -> "FireworksCredentials":
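A short sketch of how the hardened credentials behave (not part of the diff; the literal key below is a dummy value):

# Hypothetical usage sketch; "fw-demo-key" is a dummy value.
from airtrain.integrations.fireworks.credentials import FireworksCredentials

creds = FireworksCredentials(fireworks_api_key="fw-demo-key")

# __repr__/__str__ now mask the key instead of leaking it.
print(creds)  # FireworksCredentials(fireworks_api_key=SecretStr('**********'))

# The raw value is still retrievable when the API client needs it.
assert creds.fireworks_api_key.get_secret_value() == "fw-demo-key"

# FireworksCredentials.from_env() (whose body is outside this hunk) builds the
# same object from the environment.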
airtrain/integrations/fireworks/list_models.py ADDED
@@ -0,0 +1,128 @@
+from typing import Optional, List
+import requests
+from pydantic import Field
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import FireworksCredentials
+from .models import FireworksModel
+
+
+class FireworksListModelsInput(InputSchema):
+    """Schema for Fireworks AI list models input"""
+
+    account_id: str = Field(..., description="The Account Id")
+    page_size: Optional[int] = Field(
+        default=50,
+        description=(
+            "The maximum number of models to return. The maximum page_size is 200, "
+            "values above 200 will be coerced to 200."
+        ),
+        le=200
+    )
+    page_token: Optional[str] = Field(
+        default=None,
+        description=(
+            "A page token, received from a previous ListModels call. Provide this "
+            "to retrieve the subsequent page. When paginating, all other parameters "
+            "provided to ListModels must match the call that provided the page token."
+        )
+    )
+    filter: Optional[str] = Field(
+        default=None,
+        description=(
+            "Only model satisfying the provided filter (if specified) will be "
+            "returned. See https://google.aip.dev/160 for the filter grammar."
+        )
+    )
+    order_by: Optional[str] = Field(
+        default=None,
+        description=(
+            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
+            "The default sort order is ascending. To specify a descending order for a "
+            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
+            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
+            "If not specified, the default order is by \"name\"."
+        )
+    )
+
+
+class FireworksListModelsOutput(OutputSchema):
+    """Schema for Fireworks AI list models output"""
+
+    models: List[FireworksModel] = Field(
+        default_factory=list,
+        description="List of Fireworks models"
+    )
+    next_page_token: Optional[str] = Field(
+        default=None,
+        description="Token for retrieving the next page of results"
+    )
+    total_size: Optional[int] = Field(
+        default=None,
+        description="Total number of models available"
+    )
+
+
+class FireworksListModelsSkill(
+    Skill[FireworksListModelsInput, FireworksListModelsOutput]
+):
+    """Skill for listing Fireworks AI models"""
+
+    input_schema = FireworksListModelsInput
+    output_schema = FireworksListModelsOutput
+
+    def __init__(self, credentials: Optional[FireworksCredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials or FireworksCredentials.from_env()
+        self.base_url = "https://api.fireworks.ai/v1"
+
+    def process(
+        self, input_data: FireworksListModelsInput
+    ) -> FireworksListModelsOutput:
+        """Process the input and return a list of models."""
+        try:
+            # Build the URL
+            url = f"{self.base_url}/accounts/{input_data.account_id}/models"
+
+            # Prepare query parameters
+            params = {}
+            if input_data.page_size:
+                params["pageSize"] = input_data.page_size
+            if input_data.page_token:
+                params["pageToken"] = input_data.page_token
+            if input_data.filter:
+                params["filter"] = input_data.filter
+            if input_data.order_by:
+                params["orderBy"] = input_data.order_by
+
+            # Make the request
+            headers = {
+                "Authorization": (
+                    f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
+                )
+            }
+
+            response = requests.get(url, headers=headers, params=params)
+            response.raise_for_status()
+
+            # Parse the response
+            result = response.json()
+
+            # Convert the models to FireworksModel objects
+            models = []
+            for model_data in result.get("models", []):
+                models.append(FireworksModel(**model_data))
+
+            # Return the output
+            return FireworksListModelsOutput(
+                models=models,
+                next_page_token=result.get("nextPageToken"),
+                total_size=result.get("totalSize")
+            )
+
+        except requests.RequestException as e:
+            raise ProcessingError(f"Failed to list Fireworks models: {str(e)}")
+        except Exception as e:
+            raise ProcessingError(f"Error listing Fireworks models: {str(e)}")
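A minimal usage sketch of the new skill (not part of the diff; the account id "fireworks" is a placeholder, and the call performs a live HTTP request):

# Hypothetical usage sketch of FireworksListModelsSkill.
from airtrain.integrations.fireworks.list_models import (
    FireworksListModelsInput,
    FireworksListModelsSkill,
)

skill = FireworksListModelsSkill()  # falls back to FireworksCredentials.from_env()

page = skill.process(FireworksListModelsInput(account_id="fireworks", page_size=20))
for model in page.models:
    print(model.name, model.context_length)

# Pagination: feed next_page_token back in with otherwise identical parameters.
if page.next_page_token:
    skill.process(
        FireworksListModelsInput(
            account_id="fireworks",
            page_size=20,
            page_token=page.next_page_token,
        )
    )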
airtrain/integrations/fireworks/models.py CHANGED
@@ -25,3 +25,115 @@ class FireworksResponse(BaseModel):
     created: int
     model: str
     usage: FireworksUsage
+
+
+class FireworksModelStatus(BaseModel):
+    """Schema for Fireworks model status"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksModelBaseDetails(BaseModel):
+    """Schema for Fireworks base model details"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksPeftDetails(BaseModel):
+    """Schema for Fireworks PEFT details"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksConversationConfig(BaseModel):
+    """Schema for Fireworks conversation configuration"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksModelDeployedRef(BaseModel):
+    """Schema for Fireworks deployed model reference"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksDeprecationDate(BaseModel):
+    """Schema for Fireworks deprecation date"""
+    # This would be filled with actual fields from the API response
+
+
+class FireworksModel(BaseModel):
+    """Schema for a Fireworks model"""
+
+    name: str
+    display_name: Optional[str] = None
+    description: Optional[str] = None
+    create_time: Optional[str] = None
+    created_by: Optional[str] = None
+    state: Optional[str] = None
+    status: Optional[Dict[str, Any]] = None
+    kind: Optional[str] = None
+    github_url: Optional[str] = None
+    hugging_face_url: Optional[str] = None
+    base_model_details: Optional[Dict[str, Any]] = None
+    peft_details: Optional[Dict[str, Any]] = None
+    teft_details: Optional[Dict[str, Any]] = None
+    public: Optional[bool] = None
+    conversation_config: Optional[Dict[str, Any]] = None
+    context_length: Optional[int] = None
+    supports_image_input: Optional[bool] = None
+    supports_tools: Optional[bool] = None
+    imported_from: Optional[str] = None
+    fine_tuning_job: Optional[str] = None
+    default_draft_model: Optional[str] = None
+    default_draft_token_count: Optional[int] = None
+    precisions: Optional[List[str]] = None
+    deployed_model_refs: Optional[List[Dict[str, Any]]] = None
+    cluster: Optional[str] = None
+    deprecation_date: Optional[Dict[str, Any]] = None
+    calibrated: Optional[bool] = None
+    tunable: Optional[bool] = None
+    supports_lora: Optional[bool] = None
+    use_hf_apply_chat_template: Optional[bool] = None
+
+
+class ListModelsInput(BaseModel):
+    """Schema for listing Fireworks models input"""
+
+    account_id: str = Field(..., description="The Account Id")
+    page_size: Optional[int] = Field(
+        default=50,
+        description=(
+            "The maximum number of models to return. The maximum page_size is 200, "
+            "values above 200 will be coerced to 200."
+        ),
+        le=200
+    )
+    page_token: Optional[str] = Field(
+        default=None,
+        description=(
+            "A page token, received from a previous ListModels call. Provide this "
+            "to retrieve the subsequent page. When paginating, all other parameters "
+            "provided to ListModels must match the call that provided the page token."
+        )
+    )
+    filter: Optional[str] = Field(
+        default=None,
+        description=(
+            "Only model satisfying the provided filter (if specified) will be "
+            "returned. See https://google.aip.dev/160 for the filter grammar."
+        )
+    )
+    order_by: Optional[str] = Field(
+        default=None,
+        description=(
+            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
+            "The default sort order is ascending. To specify a descending order for a "
+            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
+            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
+            "If not specified, the default order is by \"name\"."
+        )
+    )
+
+
+class ListModelsOutput(BaseModel):
+    """Schema for listing Fireworks models output"""
+
+    models: List[FireworksModel]
+    next_page_token: Optional[str] = None
+    total_size: Optional[int] = None
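Since every optional field on FireworksModel defaults to None, the list-models skill can hydrate it directly from the API payload via FireworksModel(**model_data); a small sketch with made-up field values:

# Hypothetical sketch; the field values are illustrative, not taken from the API.
from airtrain.integrations.fireworks.models import FireworksModel

raw = {
    "name": "accounts/fireworks/models/deepseek-r1",
    "display_name": "DeepSeek R1",
    "context_length": 131072,
    "supports_tools": True,
    "public": True,
}

model = FireworksModel(**raw)
print(model.name, model.supports_image_input)  # fields absent from the payload stay None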
airtrain/integrations/fireworks/skills.py CHANGED
@@ -1,14 +1,10 @@
-from typing import List, Optional, Dict, Any, Generator
+from typing import List, Optional, Dict, Any, Generator, Union
 from pydantic import Field
-import requests
-from loguru import logger
 from openai import OpenAI
-from openai.types.chat import ChatCompletionChunk
 
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import FireworksCredentials
-from .models import FireworksMessage, FireworksResponse
 
 
 class FireworksInput(InputSchema):
@@ -19,9 +15,9 @@ class FireworksInput(InputSchema):
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
-    conversation_history: List[Dict[str,
+    conversation_history: List[Dict[str, Any]] = Field(
         default_factory=list,
-        description="List of previous conversation messages
+        description="List of previous conversation messages",
     )
     model: str = Field(
         default="accounts/fireworks/models/deepseek-r1",
@@ -40,6 +36,20 @@ class FireworksInput(InputSchema):
         default=False,
         description="Whether to stream the response token by token",
     )
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "A list of tools the model may use. "
+            "Currently only functions supported."
+        ),
+    )
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Controls which tool is called by the model. "
+            "'none', 'auto', or specific tool."
+        ),
+    )
 
 
 class FireworksOutput(OutputSchema):
@@ -48,6 +58,9 @@ class FireworksOutput(OutputSchema):
     response: str = Field(..., description="Model's response text")
     used_model: str = Field(..., description="Model used for generation")
     usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None, description="Tool calls generated by the model"
+    )
 
 
 class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
@@ -65,7 +78,7 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
             api_key=self.credentials.fireworks_api_key.get_secret_value(),
         )
 
-    def _build_messages(self, input_data: FireworksInput) -> List[Dict[str,
+    def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
         """Build messages list from input data including conversation history."""
         messages = [{"role": "system", "content": input_data.system_prompt}]
 
@@ -104,8 +117,8 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
                 for chunk in self.process_stream(input_data):
                     response_chunks.append(chunk)
                 response = "".join(response_chunks)
-
-                #
+
+                # Create completion object for usage stats
                 messages = self._build_messages(input_data)
                 completion = self.client.chat.completions.create(
                     model=input_data.model,
@@ -114,7 +127,44 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
                     max_tokens=input_data.max_tokens,
                     stream=False,
                 )
-
+            else:
+                # For non-streaming, use regular completion
+                messages = self._build_messages(input_data)
+
+                # Prepare API call parameters
+                api_params = {
+                    "model": input_data.model,
+                    "messages": messages,
+                    "temperature": input_data.temperature,
+                    "max_tokens": input_data.max_tokens,
+                    "stream": False,
+                }
+
+                # Add tools and tool_choice if provided
+                if input_data.tools:
+                    api_params["tools"] = input_data.tools
+
+                if input_data.tool_choice:
+                    api_params["tool_choice"] = input_data.tool_choice
+
+                completion = self.client.chat.completions.create(**api_params)
+                response = completion.choices[0].message.content or ""
+
+            # Check for tool calls in the response
+            tool_calls = None
+            if (hasattr(completion.choices[0].message, "tool_calls") and
+                    completion.choices[0].message.tool_calls):
+                tool_calls = [
+                    {
+                        "id": tool_call.id,
+                        "type": tool_call.type,
+                        "function": {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments
+                        }
+                    }
+                    for tool_call in completion.choices[0].message.tool_calls
+                ]
 
             return FireworksOutput(
                 response=response,
@@ -124,6 +174,7 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
                     "prompt_tokens": completion.usage.prompt_tokens,
                     "completion_tokens": completion.usage.completion_tokens,
                 },
+                tool_calls=tool_calls
            )
 
         except Exception as e:
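To illustrate the new tools/tool_choice path, a hypothetical sketch (the get_weather tool is invented, and user_input is assumed to be the prompt field on FireworksInput since it lies outside the hunks shown):

# Hypothetical sketch of tool calling with FireworksChatSkill.
from airtrain.integrations.fireworks.skills import FireworksChatSkill, FireworksInput

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",  # invented example tool
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

skill = FireworksChatSkill()  # assumes credentials are resolvable from the environment
result = skill.process(
    FireworksInput(
        user_input="What's the weather in Paris?",  # assumed field name
        tools=[weather_tool],
        tool_choice="auto",
        stream=False,
    )
)

# tool_calls is populated only when the model decided to call a tool.
if result.tool_calls:
    for call in result.tool_calls:
        print(call["function"]["name"], call["function"]["arguments"])
else:
    print(result.response)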
airtrain/integrations/fireworks/structured_completion_skills.py CHANGED
@@ -1,14 +1,13 @@
-from typing import
+from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
 from pydantic import BaseModel, Field
 import requests
 import json
-from loguru import logger
 
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import FireworksCredentials
 
-ResponseT = TypeVar("ResponseT"
+ResponseT = TypeVar("ResponseT")
 
 
 class FireworksStructuredCompletionInput(InputSchema):
@@ -26,7 +25,7 @@ class FireworksStructuredCompletionInput(InputSchema):
     response_model: Type[ResponseT]
     stream: bool = Field(
         default=False,
-        description="Whether to stream the response",
+        description="Whether to stream the response token by token",
     )
 
     class Config:
@@ -39,6 +38,13 @@ class FireworksStructuredCompletionOutput(OutputSchema):
     parsed_response: Any
     used_model: str
    usage: Dict[str, int]
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Tool calls are not applicable for completions, "
+            "included for compatibility"
+        )
+    )
 
 
 class FireworksStructuredCompletionSkill(
|