airtrain 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/cli/main.py +9 -0
- airtrain/integrations/__init__.py +10 -1
- airtrain/integrations/anthropic/__init__.py +12 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/fireworks/__init__.py +10 -0
- airtrain/integrations/fireworks/credentials.py +10 -2
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +112 -0
- airtrain/integrations/fireworks/skills.py +62 -11
- airtrain/integrations/fireworks/structured_completion_skills.py +10 -4
- airtrain/integrations/fireworks/structured_requests_skills.py +108 -31
- airtrain/integrations/openai/__init__.py +6 -0
- airtrain/integrations/openai/models_config.py +118 -13
- airtrain/integrations/openai/skills.py +109 -1
- {airtrain-0.1.38.dist-info → airtrain-0.1.40.dist-info}/METADATA +1 -1
- {airtrain-0.1.38.dist-info → airtrain-0.1.40.dist-info}/RECORD +20 -18
- {airtrain-0.1.38.dist-info → airtrain-0.1.40.dist-info}/WHEEL +1 -1
- {airtrain-0.1.38.dist-info → airtrain-0.1.40.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.38.dist-info → airtrain-0.1.40.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
airtrain/cli/main.py
CHANGED
@@ -109,3 +109,12 @@ def chat(provider: str, temperature: float, system_prompt: str):
|
|
109
109
|
|
110
110
|
# Add to existing cli group
|
111
111
|
cli.add_command(build)
|
112
|
+
|
113
|
+
|
114
|
+
def main():
|
115
|
+
"""Main entry point for the CLI"""
|
116
|
+
cli()
|
117
|
+
|
118
|
+
|
119
|
+
if __name__ == "__main__":
|
120
|
+
main()
|
@@ -22,6 +22,10 @@ from .ollama.skills import OllamaChatSkill
|
|
22
22
|
from .sambanova.skills import SambanovaChatSkill
|
23
23
|
from .cerebras.skills import CerebrasChatSkill
|
24
24
|
|
25
|
+
# Model configurations
|
26
|
+
from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
|
27
|
+
from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
|
28
|
+
|
25
29
|
__all__ = [
|
26
30
|
# Credentials
|
27
31
|
"OpenAICredentials",
|
@@ -38,10 +42,15 @@ __all__ = [
|
|
38
42
|
"OpenAIParserSkill",
|
39
43
|
"AnthropicChatSkill",
|
40
44
|
"AWSBedrockSkill",
|
41
|
-
"
|
45
|
+
"GoogleChatSkill",
|
42
46
|
"GroqChatSkill",
|
43
47
|
"TogetherAIChatSkill",
|
44
48
|
"OllamaChatSkill",
|
45
49
|
"SambanovaChatSkill",
|
46
50
|
"CerebrasChatSkill",
|
51
|
+
# Model configurations
|
52
|
+
"OPENAI_MODELS",
|
53
|
+
"OpenAIModelConfig",
|
54
|
+
"ANTHROPIC_MODELS",
|
55
|
+
"AnthropicModelConfig",
|
47
56
|
]
|
@@ -2,10 +2,22 @@
|
|
2
2
|
|
3
3
|
from .credentials import AnthropicCredentials
|
4
4
|
from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
|
5
|
+
from .models_config import (
|
6
|
+
ANTHROPIC_MODELS,
|
7
|
+
AnthropicModelConfig,
|
8
|
+
get_model_config,
|
9
|
+
get_default_model,
|
10
|
+
calculate_cost,
|
11
|
+
)
|
5
12
|
|
6
13
|
__all__ = [
|
7
14
|
"AnthropicCredentials",
|
8
15
|
"AnthropicChatSkill",
|
9
16
|
"AnthropicInput",
|
10
17
|
"AnthropicOutput",
|
18
|
+
"ANTHROPIC_MODELS",
|
19
|
+
"AnthropicModelConfig",
|
20
|
+
"get_model_config",
|
21
|
+
"get_default_model",
|
22
|
+
"calculate_cost",
|
11
23
|
]
|
@@ -0,0 +1,100 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
from decimal import Decimal
|
3
|
+
|
4
|
+
|
5
|
+
class AnthropicModelConfig(NamedTuple):
|
6
|
+
display_name: str
|
7
|
+
base_model: str
|
8
|
+
input_price: Decimal
|
9
|
+
cached_write_price: Optional[Decimal]
|
10
|
+
cached_read_price: Optional[Decimal]
|
11
|
+
output_price: Decimal
|
12
|
+
|
13
|
+
|
14
|
+
ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
|
15
|
+
"claude-3-7-sonnet": AnthropicModelConfig(
|
16
|
+
display_name="Claude 3.7 Sonnet",
|
17
|
+
base_model="claude-3-7-sonnet",
|
18
|
+
input_price=Decimal("3.00"),
|
19
|
+
cached_write_price=Decimal("3.75"),
|
20
|
+
cached_read_price=Decimal("0.30"),
|
21
|
+
output_price=Decimal("15.00"),
|
22
|
+
),
|
23
|
+
"claude-3-5-haiku": AnthropicModelConfig(
|
24
|
+
display_name="Claude 3.5 Haiku",
|
25
|
+
base_model="claude-3-5-haiku",
|
26
|
+
input_price=Decimal("0.80"),
|
27
|
+
cached_write_price=Decimal("1.00"),
|
28
|
+
cached_read_price=Decimal("0.08"),
|
29
|
+
output_price=Decimal("4.00"),
|
30
|
+
),
|
31
|
+
"claude-3-opus": AnthropicModelConfig(
|
32
|
+
display_name="Claude 3 Opus",
|
33
|
+
base_model="claude-3-opus",
|
34
|
+
input_price=Decimal("15.00"),
|
35
|
+
cached_write_price=Decimal("18.75"),
|
36
|
+
cached_read_price=Decimal("1.50"),
|
37
|
+
output_price=Decimal("75.00"),
|
38
|
+
),
|
39
|
+
"claude-3-sonnet": AnthropicModelConfig(
|
40
|
+
display_name="Claude 3 Sonnet",
|
41
|
+
base_model="claude-3-sonnet",
|
42
|
+
input_price=Decimal("3.00"),
|
43
|
+
cached_write_price=Decimal("3.75"),
|
44
|
+
cached_read_price=Decimal("0.30"),
|
45
|
+
output_price=Decimal("15.00"),
|
46
|
+
),
|
47
|
+
"claude-3-haiku": AnthropicModelConfig(
|
48
|
+
display_name="Claude 3 Haiku",
|
49
|
+
base_model="claude-3-haiku",
|
50
|
+
input_price=Decimal("0.25"),
|
51
|
+
cached_write_price=Decimal("0.31"),
|
52
|
+
cached_read_price=Decimal("0.025"),
|
53
|
+
output_price=Decimal("1.25"),
|
54
|
+
),
|
55
|
+
}
|
56
|
+
|
57
|
+
|
58
|
+
def get_model_config(model_id: str) -> AnthropicModelConfig:
|
59
|
+
"""Get model configuration by model ID"""
|
60
|
+
if model_id not in ANTHROPIC_MODELS:
|
61
|
+
raise ValueError(f"Model {model_id} not found in Anthropic models")
|
62
|
+
return ANTHROPIC_MODELS[model_id]
|
63
|
+
|
64
|
+
|
65
|
+
def get_default_model() -> str:
|
66
|
+
"""Get the default model ID"""
|
67
|
+
return "claude-3-sonnet"
|
68
|
+
|
69
|
+
|
70
|
+
def calculate_cost(
|
71
|
+
model_id: str,
|
72
|
+
input_tokens: int,
|
73
|
+
output_tokens: int,
|
74
|
+
use_cached: bool = False,
|
75
|
+
cache_type: str = "read"
|
76
|
+
) -> Decimal:
|
77
|
+
"""Calculate cost for token usage
|
78
|
+
|
79
|
+
Args:
|
80
|
+
model_id: The model ID to calculate costs for
|
81
|
+
input_tokens: Number of input tokens
|
82
|
+
output_tokens: Number of output tokens
|
83
|
+
use_cached: Whether to use cached pricing
|
84
|
+
cache_type: Either "read" or "write" for cached pricing type
|
85
|
+
"""
|
86
|
+
config = get_model_config(model_id)
|
87
|
+
|
88
|
+
if not use_cached:
|
89
|
+
input_cost = config.input_price * Decimal(str(input_tokens))
|
90
|
+
else:
|
91
|
+
if cache_type == "read" and config.cached_read_price is not None:
|
92
|
+
input_cost = config.cached_read_price * Decimal(str(input_tokens))
|
93
|
+
elif cache_type == "write" and config.cached_write_price is not None:
|
94
|
+
input_cost = config.cached_write_price * Decimal(str(input_tokens))
|
95
|
+
else:
|
96
|
+
input_cost = config.input_price * Decimal(str(input_tokens))
|
97
|
+
|
98
|
+
output_cost = config.output_price * Decimal(str(output_tokens))
|
99
|
+
|
100
|
+
return (input_cost + output_cost) / Decimal("1000")
|
@@ -2,10 +2,20 @@
|
|
2
2
|
|
3
3
|
from .credentials import FireworksCredentials
|
4
4
|
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
5
|
+
from .list_models import (
|
6
|
+
FireworksListModelsSkill,
|
7
|
+
FireworksListModelsInput,
|
8
|
+
FireworksListModelsOutput,
|
9
|
+
)
|
10
|
+
from .models import FireworksModel
|
5
11
|
|
6
12
|
__all__ = [
|
7
13
|
"FireworksCredentials",
|
8
14
|
"FireworksChatSkill",
|
9
15
|
"FireworksInput",
|
10
16
|
"FireworksOutput",
|
17
|
+
"FireworksListModelsSkill",
|
18
|
+
"FireworksListModelsInput",
|
19
|
+
"FireworksListModelsOutput",
|
20
|
+
"FireworksModel",
|
11
21
|
]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from pydantic import SecretStr, BaseModel
|
1
|
+
from pydantic import SecretStr, BaseModel, Field
|
2
2
|
from typing import Optional
|
3
3
|
import os
|
4
4
|
|
@@ -6,7 +6,15 @@ import os
|
|
6
6
|
class FireworksCredentials(BaseModel):
|
7
7
|
"""Credentials for Fireworks AI API"""
|
8
8
|
|
9
|
-
fireworks_api_key: SecretStr
|
9
|
+
fireworks_api_key: SecretStr = Field(..., min_length=1)
|
10
|
+
|
11
|
+
def __repr__(self) -> str:
|
12
|
+
"""Return a string representation of the credentials."""
|
13
|
+
return f"FireworksCredentials(fireworks_api_key=SecretStr('**********'))"
|
14
|
+
|
15
|
+
def __str__(self) -> str:
|
16
|
+
"""Return a string representation of the credentials."""
|
17
|
+
return self.__repr__()
|
10
18
|
|
11
19
|
@classmethod
|
12
20
|
def from_env(cls) -> "FireworksCredentials":
|
@@ -0,0 +1,128 @@
|
|
1
|
+
from typing import Optional, List
|
2
|
+
import requests
|
3
|
+
from pydantic import Field
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import FireworksCredentials
|
8
|
+
from .models import FireworksModel
|
9
|
+
|
10
|
+
|
11
|
+
class FireworksListModelsInput(InputSchema):
|
12
|
+
"""Schema for Fireworks AI list models input"""
|
13
|
+
|
14
|
+
account_id: str = Field(..., description="The Account Id")
|
15
|
+
page_size: Optional[int] = Field(
|
16
|
+
default=50,
|
17
|
+
description=(
|
18
|
+
"The maximum number of models to return. The maximum page_size is 200, "
|
19
|
+
"values above 200 will be coerced to 200."
|
20
|
+
),
|
21
|
+
le=200
|
22
|
+
)
|
23
|
+
page_token: Optional[str] = Field(
|
24
|
+
default=None,
|
25
|
+
description=(
|
26
|
+
"A page token, received from a previous ListModels call. Provide this "
|
27
|
+
"to retrieve the subsequent page. When paginating, all other parameters "
|
28
|
+
"provided to ListModels must match the call that provided the page token."
|
29
|
+
)
|
30
|
+
)
|
31
|
+
filter: Optional[str] = Field(
|
32
|
+
default=None,
|
33
|
+
description=(
|
34
|
+
"Only model satisfying the provided filter (if specified) will be "
|
35
|
+
"returned. See https://google.aip.dev/160 for the filter grammar."
|
36
|
+
)
|
37
|
+
)
|
38
|
+
order_by: Optional[str] = Field(
|
39
|
+
default=None,
|
40
|
+
description=(
|
41
|
+
"A comma-separated list of fields to order by. e.g. \"foo,bar\" "
|
42
|
+
"The default sort order is ascending. To specify a descending order for a "
|
43
|
+
"field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
|
44
|
+
"Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
|
45
|
+
"If not specified, the default order is by \"name\"."
|
46
|
+
)
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksListModelsOutput(OutputSchema):
|
51
|
+
"""Schema for Fireworks AI list models output"""
|
52
|
+
|
53
|
+
models: List[FireworksModel] = Field(
|
54
|
+
default_factory=list,
|
55
|
+
description="List of Fireworks models"
|
56
|
+
)
|
57
|
+
next_page_token: Optional[str] = Field(
|
58
|
+
default=None,
|
59
|
+
description="Token for retrieving the next page of results"
|
60
|
+
)
|
61
|
+
total_size: Optional[int] = Field(
|
62
|
+
default=None,
|
63
|
+
description="Total number of models available"
|
64
|
+
)
|
65
|
+
|
66
|
+
|
67
|
+
class FireworksListModelsSkill(
|
68
|
+
Skill[FireworksListModelsInput, FireworksListModelsOutput]
|
69
|
+
):
|
70
|
+
"""Skill for listing Fireworks AI models"""
|
71
|
+
|
72
|
+
input_schema = FireworksListModelsInput
|
73
|
+
output_schema = FireworksListModelsOutput
|
74
|
+
|
75
|
+
def __init__(self, credentials: Optional[FireworksCredentials] = None):
|
76
|
+
"""Initialize the skill with optional credentials"""
|
77
|
+
super().__init__()
|
78
|
+
self.credentials = credentials or FireworksCredentials.from_env()
|
79
|
+
self.base_url = "https://api.fireworks.ai/v1"
|
80
|
+
|
81
|
+
def process(
|
82
|
+
self, input_data: FireworksListModelsInput
|
83
|
+
) -> FireworksListModelsOutput:
|
84
|
+
"""Process the input and return a list of models."""
|
85
|
+
try:
|
86
|
+
# Build the URL
|
87
|
+
url = f"{self.base_url}/accounts/{input_data.account_id}/models"
|
88
|
+
|
89
|
+
# Prepare query parameters
|
90
|
+
params = {}
|
91
|
+
if input_data.page_size:
|
92
|
+
params["pageSize"] = input_data.page_size
|
93
|
+
if input_data.page_token:
|
94
|
+
params["pageToken"] = input_data.page_token
|
95
|
+
if input_data.filter:
|
96
|
+
params["filter"] = input_data.filter
|
97
|
+
if input_data.order_by:
|
98
|
+
params["orderBy"] = input_data.order_by
|
99
|
+
|
100
|
+
# Make the request
|
101
|
+
headers = {
|
102
|
+
"Authorization": (
|
103
|
+
f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
|
104
|
+
)
|
105
|
+
}
|
106
|
+
|
107
|
+
response = requests.get(url, headers=headers, params=params)
|
108
|
+
response.raise_for_status()
|
109
|
+
|
110
|
+
# Parse the response
|
111
|
+
result = response.json()
|
112
|
+
|
113
|
+
# Convert the models to FireworksModel objects
|
114
|
+
models = []
|
115
|
+
for model_data in result.get("models", []):
|
116
|
+
models.append(FireworksModel(**model_data))
|
117
|
+
|
118
|
+
# Return the output
|
119
|
+
return FireworksListModelsOutput(
|
120
|
+
models=models,
|
121
|
+
next_page_token=result.get("nextPageToken"),
|
122
|
+
total_size=result.get("totalSize")
|
123
|
+
)
|
124
|
+
|
125
|
+
except requests.RequestException as e:
|
126
|
+
raise ProcessingError(f"Failed to list Fireworks models: {str(e)}")
|
127
|
+
except Exception as e:
|
128
|
+
raise ProcessingError(f"Error listing Fireworks models: {str(e)}")
|
@@ -25,3 +25,115 @@ class FireworksResponse(BaseModel):
|
|
25
25
|
created: int
|
26
26
|
model: str
|
27
27
|
usage: FireworksUsage
|
28
|
+
|
29
|
+
|
30
|
+
class FireworksModelStatus(BaseModel):
|
31
|
+
"""Schema for Fireworks model status"""
|
32
|
+
# This would be filled with actual fields from the API response
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksModelBaseDetails(BaseModel):
|
36
|
+
"""Schema for Fireworks base model details"""
|
37
|
+
# This would be filled with actual fields from the API response
|
38
|
+
|
39
|
+
|
40
|
+
class FireworksPeftDetails(BaseModel):
|
41
|
+
"""Schema for Fireworks PEFT details"""
|
42
|
+
# This would be filled with actual fields from the API response
|
43
|
+
|
44
|
+
|
45
|
+
class FireworksConversationConfig(BaseModel):
|
46
|
+
"""Schema for Fireworks conversation configuration"""
|
47
|
+
# This would be filled with actual fields from the API response
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksModelDeployedRef(BaseModel):
|
51
|
+
"""Schema for Fireworks deployed model reference"""
|
52
|
+
# This would be filled with actual fields from the API response
|
53
|
+
|
54
|
+
|
55
|
+
class FireworksDeprecationDate(BaseModel):
|
56
|
+
"""Schema for Fireworks deprecation date"""
|
57
|
+
# This would be filled with actual fields from the API response
|
58
|
+
|
59
|
+
|
60
|
+
class FireworksModel(BaseModel):
|
61
|
+
"""Schema for a Fireworks model"""
|
62
|
+
|
63
|
+
name: str
|
64
|
+
display_name: Optional[str] = None
|
65
|
+
description: Optional[str] = None
|
66
|
+
create_time: Optional[str] = None
|
67
|
+
created_by: Optional[str] = None
|
68
|
+
state: Optional[str] = None
|
69
|
+
status: Optional[Dict[str, Any]] = None
|
70
|
+
kind: Optional[str] = None
|
71
|
+
github_url: Optional[str] = None
|
72
|
+
hugging_face_url: Optional[str] = None
|
73
|
+
base_model_details: Optional[Dict[str, Any]] = None
|
74
|
+
peft_details: Optional[Dict[str, Any]] = None
|
75
|
+
teft_details: Optional[Dict[str, Any]] = None
|
76
|
+
public: Optional[bool] = None
|
77
|
+
conversation_config: Optional[Dict[str, Any]] = None
|
78
|
+
context_length: Optional[int] = None
|
79
|
+
supports_image_input: Optional[bool] = None
|
80
|
+
supports_tools: Optional[bool] = None
|
81
|
+
imported_from: Optional[str] = None
|
82
|
+
fine_tuning_job: Optional[str] = None
|
83
|
+
default_draft_model: Optional[str] = None
|
84
|
+
default_draft_token_count: Optional[int] = None
|
85
|
+
precisions: Optional[List[str]] = None
|
86
|
+
deployed_model_refs: Optional[List[Dict[str, Any]]] = None
|
87
|
+
cluster: Optional[str] = None
|
88
|
+
deprecation_date: Optional[Dict[str, Any]] = None
|
89
|
+
calibrated: Optional[bool] = None
|
90
|
+
tunable: Optional[bool] = None
|
91
|
+
supports_lora: Optional[bool] = None
|
92
|
+
use_hf_apply_chat_template: Optional[bool] = None
|
93
|
+
|
94
|
+
|
95
|
+
class ListModelsInput(BaseModel):
|
96
|
+
"""Schema for listing Fireworks models input"""
|
97
|
+
|
98
|
+
account_id: str = Field(..., description="The Account Id")
|
99
|
+
page_size: Optional[int] = Field(
|
100
|
+
default=50,
|
101
|
+
description=(
|
102
|
+
"The maximum number of models to return. The maximum page_size is 200, "
|
103
|
+
"values above 200 will be coerced to 200."
|
104
|
+
),
|
105
|
+
le=200
|
106
|
+
)
|
107
|
+
page_token: Optional[str] = Field(
|
108
|
+
default=None,
|
109
|
+
description=(
|
110
|
+
"A page token, received from a previous ListModels call. Provide this "
|
111
|
+
"to retrieve the subsequent page. When paginating, all other parameters "
|
112
|
+
"provided to ListModels must match the call that provided the page token."
|
113
|
+
)
|
114
|
+
)
|
115
|
+
filter: Optional[str] = Field(
|
116
|
+
default=None,
|
117
|
+
description=(
|
118
|
+
"Only model satisfying the provided filter (if specified) will be "
|
119
|
+
"returned. See https://google.aip.dev/160 for the filter grammar."
|
120
|
+
)
|
121
|
+
)
|
122
|
+
order_by: Optional[str] = Field(
|
123
|
+
default=None,
|
124
|
+
description=(
|
125
|
+
"A comma-separated list of fields to order by. e.g. \"foo,bar\" "
|
126
|
+
"The default sort order is ascending. To specify a descending order for a "
|
127
|
+
"field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
|
128
|
+
"Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
|
129
|
+
"If not specified, the default order is by \"name\"."
|
130
|
+
)
|
131
|
+
)
|
132
|
+
|
133
|
+
|
134
|
+
class ListModelsOutput(BaseModel):
|
135
|
+
"""Schema for listing Fireworks models output"""
|
136
|
+
|
137
|
+
models: List[FireworksModel]
|
138
|
+
next_page_token: Optional[str] = None
|
139
|
+
total_size: Optional[int] = None
|
@@ -1,14 +1,10 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any, Generator
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
2
|
from pydantic import Field
|
3
|
-
import requests
|
4
|
-
from loguru import logger
|
5
3
|
from openai import OpenAI
|
6
|
-
from openai.types.chat import ChatCompletionChunk
|
7
4
|
|
8
5
|
from airtrain.core.skills import Skill, ProcessingError
|
9
6
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
7
|
from .credentials import FireworksCredentials
|
11
|
-
from .models import FireworksMessage, FireworksResponse
|
12
8
|
|
13
9
|
|
14
10
|
class FireworksInput(InputSchema):
|
@@ -19,9 +15,9 @@ class FireworksInput(InputSchema):
|
|
19
15
|
default="You are a helpful assistant.",
|
20
16
|
description="System prompt to guide the model's behavior",
|
21
17
|
)
|
22
|
-
conversation_history: List[Dict[str,
|
18
|
+
conversation_history: List[Dict[str, Any]] = Field(
|
23
19
|
default_factory=list,
|
24
|
-
description="List of previous conversation messages
|
20
|
+
description="List of previous conversation messages",
|
25
21
|
)
|
26
22
|
model: str = Field(
|
27
23
|
default="accounts/fireworks/models/deepseek-r1",
|
@@ -40,6 +36,20 @@ class FireworksInput(InputSchema):
|
|
40
36
|
default=False,
|
41
37
|
description="Whether to stream the response token by token",
|
42
38
|
)
|
39
|
+
tools: Optional[List[Dict[str, Any]]] = Field(
|
40
|
+
default=None,
|
41
|
+
description=(
|
42
|
+
"A list of tools the model may use. "
|
43
|
+
"Currently only functions supported."
|
44
|
+
),
|
45
|
+
)
|
46
|
+
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
|
47
|
+
default=None,
|
48
|
+
description=(
|
49
|
+
"Controls which tool is called by the model. "
|
50
|
+
"'none', 'auto', or specific tool."
|
51
|
+
),
|
52
|
+
)
|
43
53
|
|
44
54
|
|
45
55
|
class FireworksOutput(OutputSchema):
|
@@ -48,6 +58,9 @@ class FireworksOutput(OutputSchema):
|
|
48
58
|
response: str = Field(..., description="Model's response text")
|
49
59
|
used_model: str = Field(..., description="Model used for generation")
|
50
60
|
usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
|
61
|
+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
|
62
|
+
default=None, description="Tool calls generated by the model"
|
63
|
+
)
|
51
64
|
|
52
65
|
|
53
66
|
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
@@ -65,7 +78,7 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
65
78
|
api_key=self.credentials.fireworks_api_key.get_secret_value(),
|
66
79
|
)
|
67
80
|
|
68
|
-
def _build_messages(self, input_data: FireworksInput) -> List[Dict[str,
|
81
|
+
def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
|
69
82
|
"""Build messages list from input data including conversation history."""
|
70
83
|
messages = [{"role": "system", "content": input_data.system_prompt}]
|
71
84
|
|
@@ -104,8 +117,8 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
104
117
|
for chunk in self.process_stream(input_data):
|
105
118
|
response_chunks.append(chunk)
|
106
119
|
response = "".join(response_chunks)
|
107
|
-
|
108
|
-
#
|
120
|
+
|
121
|
+
# Create completion object for usage stats
|
109
122
|
messages = self._build_messages(input_data)
|
110
123
|
completion = self.client.chat.completions.create(
|
111
124
|
model=input_data.model,
|
@@ -114,7 +127,44 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
114
127
|
max_tokens=input_data.max_tokens,
|
115
128
|
stream=False,
|
116
129
|
)
|
117
|
-
|
130
|
+
else:
|
131
|
+
# For non-streaming, use regular completion
|
132
|
+
messages = self._build_messages(input_data)
|
133
|
+
|
134
|
+
# Prepare API call parameters
|
135
|
+
api_params = {
|
136
|
+
"model": input_data.model,
|
137
|
+
"messages": messages,
|
138
|
+
"temperature": input_data.temperature,
|
139
|
+
"max_tokens": input_data.max_tokens,
|
140
|
+
"stream": False,
|
141
|
+
}
|
142
|
+
|
143
|
+
# Add tools and tool_choice if provided
|
144
|
+
if input_data.tools:
|
145
|
+
api_params["tools"] = input_data.tools
|
146
|
+
|
147
|
+
if input_data.tool_choice:
|
148
|
+
api_params["tool_choice"] = input_data.tool_choice
|
149
|
+
|
150
|
+
completion = self.client.chat.completions.create(**api_params)
|
151
|
+
response = completion.choices[0].message.content or ""
|
152
|
+
|
153
|
+
# Check for tool calls in the response
|
154
|
+
tool_calls = None
|
155
|
+
if (hasattr(completion.choices[0].message, "tool_calls") and
|
156
|
+
completion.choices[0].message.tool_calls):
|
157
|
+
tool_calls = [
|
158
|
+
{
|
159
|
+
"id": tool_call.id,
|
160
|
+
"type": tool_call.type,
|
161
|
+
"function": {
|
162
|
+
"name": tool_call.function.name,
|
163
|
+
"arguments": tool_call.function.arguments
|
164
|
+
}
|
165
|
+
}
|
166
|
+
for tool_call in completion.choices[0].message.tool_calls
|
167
|
+
]
|
118
168
|
|
119
169
|
return FireworksOutput(
|
120
170
|
response=response,
|
@@ -124,6 +174,7 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
124
174
|
"prompt_tokens": completion.usage.prompt_tokens,
|
125
175
|
"completion_tokens": completion.usage.completion_tokens,
|
126
176
|
},
|
177
|
+
tool_calls=tool_calls
|
127
178
|
)
|
128
179
|
|
129
180
|
except Exception as e:
|
@@ -1,14 +1,13 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
|
2
2
|
from pydantic import BaseModel, Field
|
3
3
|
import requests
|
4
4
|
import json
|
5
|
-
from loguru import logger
|
6
5
|
|
7
6
|
from airtrain.core.skills import Skill, ProcessingError
|
8
7
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
8
|
from .credentials import FireworksCredentials
|
10
9
|
|
11
|
-
ResponseT = TypeVar("ResponseT"
|
10
|
+
ResponseT = TypeVar("ResponseT")
|
12
11
|
|
13
12
|
|
14
13
|
class FireworksStructuredCompletionInput(InputSchema):
|
@@ -26,7 +25,7 @@ class FireworksStructuredCompletionInput(InputSchema):
|
|
26
25
|
response_model: Type[ResponseT]
|
27
26
|
stream: bool = Field(
|
28
27
|
default=False,
|
29
|
-
description="Whether to stream the response",
|
28
|
+
description="Whether to stream the response token by token",
|
30
29
|
)
|
31
30
|
|
32
31
|
class Config:
|
@@ -39,6 +38,13 @@ class FireworksStructuredCompletionOutput(OutputSchema):
|
|
39
38
|
parsed_response: Any
|
40
39
|
used_model: str
|
41
40
|
usage: Dict[str, int]
|
41
|
+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
|
42
|
+
default=None,
|
43
|
+
description=(
|
44
|
+
"Tool calls are not applicable for completions, "
|
45
|
+
"included for compatibility"
|
46
|
+
)
|
47
|
+
)
|
42
48
|
|
43
49
|
|
44
50
|
class FireworksStructuredCompletionSkill(
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from typing import Type, TypeVar, Optional, List, Dict, Any, Generator
|
2
|
-
from pydantic import BaseModel, Field
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any, Generator, Union
|
2
|
+
from pydantic import BaseModel, Field, create_model
|
3
3
|
import requests
|
4
4
|
import json
|
5
5
|
from loguru import logger
|
@@ -20,7 +20,7 @@ class FireworksStructuredRequestInput(InputSchema):
|
|
20
20
|
default="You are a helpful assistant that provides structured data.",
|
21
21
|
description="System prompt to guide the model's behavior",
|
22
22
|
)
|
23
|
-
conversation_history: List[Dict[str,
|
23
|
+
conversation_history: List[Dict[str, Any]] = Field(
|
24
24
|
default_factory=list,
|
25
25
|
description="List of previous conversation messages",
|
26
26
|
)
|
@@ -34,8 +34,21 @@ class FireworksStructuredRequestInput(InputSchema):
|
|
34
34
|
max_tokens: int = Field(default=4096, description="Maximum tokens in response")
|
35
35
|
response_model: Type[ResponseT]
|
36
36
|
stream: bool = Field(
|
37
|
-
default=False,
|
38
|
-
|
37
|
+
default=False, description="Whether to stream the response token by token"
|
38
|
+
)
|
39
|
+
tools: Optional[List[Dict[str, Any]]] = Field(
|
40
|
+
default=None,
|
41
|
+
description=(
|
42
|
+
"A list of tools the model may use. "
|
43
|
+
"Currently only functions supported."
|
44
|
+
),
|
45
|
+
)
|
46
|
+
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
|
47
|
+
default=None,
|
48
|
+
description=(
|
49
|
+
"Controls which tool is called by the model. "
|
50
|
+
"'none', 'auto', or specific tool."
|
51
|
+
),
|
39
52
|
)
|
40
53
|
|
41
54
|
class Config:
|
@@ -49,6 +62,9 @@ class FireworksStructuredRequestOutput(OutputSchema):
|
|
49
62
|
used_model: str
|
50
63
|
usage: Dict[str, int]
|
51
64
|
reasoning: Optional[str] = None
|
65
|
+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
|
66
|
+
default=None, description="Tool calls generated by the model"
|
67
|
+
)
|
52
68
|
|
53
69
|
|
54
70
|
class FireworksStructuredRequestSkill(
|
@@ -72,7 +88,7 @@ class FireworksStructuredRequestSkill(
|
|
72
88
|
|
73
89
|
def _build_messages(
|
74
90
|
self, input_data: FireworksStructuredRequestInput
|
75
|
-
) -> List[Dict[str,
|
91
|
+
) -> List[Dict[str, Any]]:
|
76
92
|
"""Build messages list from input data including conversation history."""
|
77
93
|
messages = [{"role": "system", "content": input_data.system_prompt}]
|
78
94
|
|
@@ -86,24 +102,24 @@ class FireworksStructuredRequestSkill(
|
|
86
102
|
self, input_data: FireworksStructuredRequestInput
|
87
103
|
) -> Dict[str, Any]:
|
88
104
|
"""Build the request payload."""
|
89
|
-
|
105
|
+
payload = {
|
90
106
|
"model": input_data.model,
|
91
107
|
"messages": self._build_messages(input_data),
|
92
108
|
"temperature": input_data.temperature,
|
93
109
|
"max_tokens": input_data.max_tokens,
|
94
110
|
"stream": input_data.stream,
|
95
|
-
"response_format": {
|
96
|
-
"type": "json_object",
|
97
|
-
"schema": {
|
98
|
-
**input_data.response_model.model_json_schema(),
|
99
|
-
"required": [
|
100
|
-
field
|
101
|
-
for field, _ in input_data.response_model.model_fields.items()
|
102
|
-
],
|
103
|
-
},
|
104
|
-
},
|
111
|
+
"response_format": {"type": "json_object"},
|
105
112
|
}
|
106
113
|
|
114
|
+
# Add tool-related parameters if provided
|
115
|
+
if input_data.tools:
|
116
|
+
payload["tools"] = input_data.tools
|
117
|
+
|
118
|
+
if input_data.tool_choice:
|
119
|
+
payload["tool_choice"] = input_data.tool_choice
|
120
|
+
|
121
|
+
return payload
|
122
|
+
|
107
123
|
def process_stream(
|
108
124
|
self, input_data: FireworksStructuredRequestInput
|
109
125
|
) -> Generator[Dict[str, Any], None, None]:
|
@@ -131,6 +147,10 @@ class FireworksStructuredRequestSkill(
|
|
131
147
|
continue
|
132
148
|
|
133
149
|
# Once complete, parse the full response with think tags
|
150
|
+
if not json_buffer:
|
151
|
+
# If no data was collected, raise error
|
152
|
+
raise ProcessingError("No data received from Fireworks API")
|
153
|
+
|
134
154
|
complete_response = "".join(json_buffer)
|
135
155
|
reasoning, json_str = self._parse_response_content(complete_response)
|
136
156
|
|
@@ -177,37 +197,94 @@ class FireworksStructuredRequestSkill(
|
|
177
197
|
|
178
198
|
if parsed_response is None:
|
179
199
|
raise ProcessingError("Failed to parse streamed response")
|
200
|
+
|
201
|
+
# Make a non-streaming call to get tool calls if tools were provided
|
202
|
+
tool_calls = None
|
203
|
+
if input_data.tools:
|
204
|
+
# Create a non-streaming request to get tool calls
|
205
|
+
non_stream_payload = self._build_payload(input_data)
|
206
|
+
non_stream_payload["stream"] = False
|
207
|
+
|
208
|
+
response = requests.post(
|
209
|
+
self.BASE_URL,
|
210
|
+
headers=self.headers,
|
211
|
+
data=json.dumps(non_stream_payload),
|
212
|
+
)
|
213
|
+
response.raise_for_status()
|
214
|
+
result = response.json()
|
215
|
+
|
216
|
+
# Check for tool calls
|
217
|
+
if (result["choices"][0]["message"].get("tool_calls")):
|
218
|
+
tool_calls = [
|
219
|
+
{
|
220
|
+
"id": tool_call["id"],
|
221
|
+
"type": tool_call["type"],
|
222
|
+
"function": {
|
223
|
+
"name": tool_call["function"]["name"],
|
224
|
+
"arguments": tool_call["function"]["arguments"]
|
225
|
+
}
|
226
|
+
}
|
227
|
+
for tool_call in result["choices"][0]["message"]["tool_calls"]
|
228
|
+
]
|
180
229
|
|
181
230
|
return FireworksStructuredRequestOutput(
|
182
231
|
parsed_response=parsed_response,
|
183
232
|
used_model=input_data.model,
|
184
|
-
usage={}, #
|
233
|
+
usage={"total_tokens": 0}, # Can't get usage stats from streaming
|
185
234
|
reasoning=reasoning,
|
235
|
+
tool_calls=tool_calls,
|
186
236
|
)
|
187
237
|
else:
|
188
238
|
# For non-streaming, use regular request
|
189
239
|
payload = self._build_payload(input_data)
|
240
|
+
payload["stream"] = False # Ensure it's not streaming
|
241
|
+
|
190
242
|
response = requests.post(
|
191
243
|
self.BASE_URL, headers=self.headers, data=json.dumps(payload)
|
192
244
|
)
|
193
245
|
response.raise_for_status()
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
)
|
246
|
+
result = response.json()
|
247
|
+
|
248
|
+
# Get the content from the response
|
249
|
+
if "choices" not in result or not result["choices"]:
|
250
|
+
raise ProcessingError("Invalid response format from Fireworks API")
|
251
|
+
|
252
|
+
content = result["choices"][0]["message"].get("content", "")
|
253
|
+
|
254
|
+
# Check for tool calls
|
255
|
+
tool_calls = None
|
256
|
+
if (result["choices"][0]["message"].get("tool_calls")):
|
257
|
+
tool_calls = [
|
258
|
+
{
|
259
|
+
"id": tool_call["id"],
|
260
|
+
"type": tool_call["type"],
|
261
|
+
"function": {
|
262
|
+
"name": tool_call["function"]["name"],
|
263
|
+
"arguments": tool_call["function"]["arguments"]
|
264
|
+
}
|
265
|
+
}
|
266
|
+
for tool_call in result["choices"][0]["message"]["tool_calls"]
|
267
|
+
]
|
268
|
+
|
269
|
+
# Parse the response content
|
270
|
+
reasoning, json_str = self._parse_response_content(content)
|
271
|
+
try:
|
272
|
+
parsed_response = input_data.response_model.model_validate_json(
|
273
|
+
json_str
|
274
|
+
)
|
275
|
+
except Exception as e:
|
276
|
+
raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
|
205
277
|
|
206
278
|
return FireworksStructuredRequestOutput(
|
207
279
|
parsed_response=parsed_response,
|
208
280
|
used_model=input_data.model,
|
209
|
-
usage=
|
210
|
-
|
281
|
+
usage={
|
282
|
+
"total_tokens": result["usage"]["total_tokens"],
|
283
|
+
"prompt_tokens": result["usage"]["prompt_tokens"],
|
284
|
+
"completion_tokens": result["usage"]["completion_tokens"],
|
285
|
+
},
|
286
|
+
reasoning=reasoning,
|
287
|
+
tool_calls=tool_calls,
|
211
288
|
)
|
212
289
|
|
213
290
|
except Exception as e:
|
@@ -5,6 +5,9 @@ from .skills import (
|
|
5
5
|
OpenAIOutput,
|
6
6
|
OpenAIParserInput,
|
7
7
|
OpenAIParserOutput,
|
8
|
+
OpenAIEmbeddingsSkill,
|
9
|
+
OpenAIEmbeddingsInput,
|
10
|
+
OpenAIEmbeddingsOutput,
|
8
11
|
)
|
9
12
|
from .credentials import OpenAICredentials
|
10
13
|
|
@@ -16,4 +19,7 @@ __all__ = [
|
|
16
19
|
"OpenAIParserOutput",
|
17
20
|
"OpenAICredentials",
|
18
21
|
"OpenAIOutput",
|
22
|
+
"OpenAIEmbeddingsSkill",
|
23
|
+
"OpenAIEmbeddingsInput",
|
24
|
+
"OpenAIEmbeddingsOutput",
|
19
25
|
]
|
@@ -11,6 +11,20 @@ class OpenAIModelConfig(NamedTuple):
|
|
11
11
|
|
12
12
|
|
13
13
|
OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
|
14
|
+
"gpt-4.5-preview": OpenAIModelConfig(
|
15
|
+
display_name="GPT-4.5 Preview",
|
16
|
+
base_model="gpt-4.5-preview",
|
17
|
+
input_price=Decimal("75.00"),
|
18
|
+
cached_input_price=Decimal("37.50"),
|
19
|
+
output_price=Decimal("150.00"),
|
20
|
+
),
|
21
|
+
"gpt-4.5-preview-2025-02-27": OpenAIModelConfig(
|
22
|
+
display_name="GPT-4.5 Preview (2025-02-27)",
|
23
|
+
base_model="gpt-4.5-preview",
|
24
|
+
input_price=Decimal("75.00"),
|
25
|
+
cached_input_price=Decimal("37.50"),
|
26
|
+
output_price=Decimal("150.00"),
|
27
|
+
),
|
14
28
|
"gpt-4o": OpenAIModelConfig(
|
15
29
|
display_name="GPT-4 Optimized",
|
16
30
|
base_model="gpt-4o",
|
@@ -25,69 +39,160 @@ OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
|
|
25
39
|
cached_input_price=Decimal("1.25"),
|
26
40
|
output_price=Decimal("10.00"),
|
27
41
|
),
|
28
|
-
"gpt-4o-
|
29
|
-
display_name="GPT-4 Optimized
|
30
|
-
base_model="gpt-4o",
|
31
|
-
input_price=Decimal("
|
42
|
+
"gpt-4o-audio-preview": OpenAIModelConfig(
|
43
|
+
display_name="GPT-4 Optimized Audio Preview",
|
44
|
+
base_model="gpt-4o-audio-preview",
|
45
|
+
input_price=Decimal("2.50"),
|
32
46
|
cached_input_price=None,
|
33
|
-
output_price=Decimal("
|
47
|
+
output_price=Decimal("10.00"),
|
34
48
|
),
|
35
49
|
"gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
|
36
|
-
display_name="GPT-4 Optimized Audio Preview",
|
50
|
+
display_name="GPT-4 Optimized Audio Preview (2024-12-17)",
|
37
51
|
base_model="gpt-4o-audio-preview",
|
38
52
|
input_price=Decimal("2.50"),
|
39
53
|
cached_input_price=None,
|
40
54
|
output_price=Decimal("10.00"),
|
41
55
|
),
|
42
|
-
"gpt-4o-realtime-preview
|
56
|
+
"gpt-4o-realtime-preview": OpenAIModelConfig(
|
43
57
|
display_name="GPT-4 Optimized Realtime Preview",
|
44
58
|
base_model="gpt-4o-realtime-preview",
|
45
59
|
input_price=Decimal("5.00"),
|
46
60
|
cached_input_price=Decimal("2.50"),
|
47
61
|
output_price=Decimal("20.00"),
|
48
62
|
),
|
49
|
-
"gpt-4o-
|
63
|
+
"gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
|
64
|
+
display_name="GPT-4 Optimized Realtime Preview (2024-12-17)",
|
65
|
+
base_model="gpt-4o-realtime-preview",
|
66
|
+
input_price=Decimal("5.00"),
|
67
|
+
cached_input_price=Decimal("2.50"),
|
68
|
+
output_price=Decimal("20.00"),
|
69
|
+
),
|
70
|
+
"gpt-4o-mini": OpenAIModelConfig(
|
50
71
|
display_name="GPT-4 Optimized Mini",
|
51
72
|
base_model="gpt-4o-mini",
|
52
73
|
input_price=Decimal("0.15"),
|
53
74
|
cached_input_price=Decimal("0.075"),
|
54
75
|
output_price=Decimal("0.60"),
|
55
76
|
),
|
56
|
-
"gpt-4o-mini-
|
77
|
+
"gpt-4o-mini-2024-07-18": OpenAIModelConfig(
|
78
|
+
display_name="GPT-4 Optimized Mini (2024-07-18)",
|
79
|
+
base_model="gpt-4o-mini",
|
80
|
+
input_price=Decimal("0.15"),
|
81
|
+
cached_input_price=Decimal("0.075"),
|
82
|
+
output_price=Decimal("0.60"),
|
83
|
+
),
|
84
|
+
"gpt-4o-mini-audio-preview": OpenAIModelConfig(
|
57
85
|
display_name="GPT-4 Optimized Mini Audio Preview",
|
58
86
|
base_model="gpt-4o-mini-audio-preview",
|
59
87
|
input_price=Decimal("0.15"),
|
60
88
|
cached_input_price=None,
|
61
89
|
output_price=Decimal("0.60"),
|
62
90
|
),
|
63
|
-
"gpt-4o-mini-
|
91
|
+
"gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
|
92
|
+
display_name="GPT-4 Optimized Mini Audio Preview (2024-12-17)",
|
93
|
+
base_model="gpt-4o-mini-audio-preview",
|
94
|
+
input_price=Decimal("0.15"),
|
95
|
+
cached_input_price=None,
|
96
|
+
output_price=Decimal("0.60"),
|
97
|
+
),
|
98
|
+
"gpt-4o-mini-realtime-preview": OpenAIModelConfig(
|
64
99
|
display_name="GPT-4 Optimized Mini Realtime Preview",
|
65
100
|
base_model="gpt-4o-mini-realtime-preview",
|
66
101
|
input_price=Decimal("0.60"),
|
67
102
|
cached_input_price=Decimal("0.30"),
|
68
103
|
output_price=Decimal("2.40"),
|
69
104
|
),
|
70
|
-
"
|
105
|
+
"gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
|
106
|
+
display_name="GPT-4 Optimized Mini Realtime Preview (2024-12-17)",
|
107
|
+
base_model="gpt-4o-mini-realtime-preview",
|
108
|
+
input_price=Decimal("0.60"),
|
109
|
+
cached_input_price=Decimal("0.30"),
|
110
|
+
output_price=Decimal("2.40"),
|
111
|
+
),
|
112
|
+
"o1": OpenAIModelConfig(
|
71
113
|
display_name="O1",
|
72
114
|
base_model="o1",
|
73
115
|
input_price=Decimal("15.00"),
|
74
116
|
cached_input_price=Decimal("7.50"),
|
75
117
|
output_price=Decimal("60.00"),
|
76
118
|
),
|
77
|
-
"
|
119
|
+
"o1-2024-12-17": OpenAIModelConfig(
|
120
|
+
display_name="O1 (2024-12-17)",
|
121
|
+
base_model="o1",
|
122
|
+
input_price=Decimal("15.00"),
|
123
|
+
cached_input_price=Decimal("7.50"),
|
124
|
+
output_price=Decimal("60.00"),
|
125
|
+
),
|
126
|
+
"o3-mini": OpenAIModelConfig(
|
78
127
|
display_name="O3 Mini",
|
79
128
|
base_model="o3-mini",
|
80
129
|
input_price=Decimal("1.10"),
|
81
130
|
cached_input_price=Decimal("0.55"),
|
82
131
|
output_price=Decimal("4.40"),
|
83
132
|
),
|
84
|
-
"
|
133
|
+
"o3-mini-2025-01-31": OpenAIModelConfig(
|
134
|
+
display_name="O3 Mini (2025-01-31)",
|
135
|
+
base_model="o3-mini",
|
136
|
+
input_price=Decimal("1.10"),
|
137
|
+
cached_input_price=Decimal("0.55"),
|
138
|
+
output_price=Decimal("4.40"),
|
139
|
+
),
|
140
|
+
"o1-mini": OpenAIModelConfig(
|
85
141
|
display_name="O1 Mini",
|
86
142
|
base_model="o1-mini",
|
87
143
|
input_price=Decimal("1.10"),
|
88
144
|
cached_input_price=Decimal("0.55"),
|
89
145
|
output_price=Decimal("4.40"),
|
90
146
|
),
|
147
|
+
"o1-mini-2024-09-12": OpenAIModelConfig(
|
148
|
+
display_name="O1 Mini (2024-09-12)",
|
149
|
+
base_model="o1-mini",
|
150
|
+
input_price=Decimal("1.10"),
|
151
|
+
cached_input_price=Decimal("0.55"),
|
152
|
+
output_price=Decimal("4.40"),
|
153
|
+
),
|
154
|
+
"gpt-4o-mini-search-preview": OpenAIModelConfig(
|
155
|
+
display_name="GPT-4 Optimized Mini Search Preview",
|
156
|
+
base_model="gpt-4o-mini-search-preview",
|
157
|
+
input_price=Decimal("0.15"),
|
158
|
+
cached_input_price=None,
|
159
|
+
output_price=Decimal("0.60"),
|
160
|
+
),
|
161
|
+
"gpt-4o-mini-search-preview-2025-03-11": OpenAIModelConfig(
|
162
|
+
display_name="GPT-4 Optimized Mini Search Preview (2025-03-11)",
|
163
|
+
base_model="gpt-4o-mini-search-preview",
|
164
|
+
input_price=Decimal("0.15"),
|
165
|
+
cached_input_price=None,
|
166
|
+
output_price=Decimal("0.60"),
|
167
|
+
),
|
168
|
+
"gpt-4o-search-preview": OpenAIModelConfig(
|
169
|
+
display_name="GPT-4 Optimized Search Preview",
|
170
|
+
base_model="gpt-4o-search-preview",
|
171
|
+
input_price=Decimal("2.50"),
|
172
|
+
cached_input_price=None,
|
173
|
+
output_price=Decimal("10.00"),
|
174
|
+
),
|
175
|
+
"gpt-4o-search-preview-2025-03-11": OpenAIModelConfig(
|
176
|
+
display_name="GPT-4 Optimized Search Preview (2025-03-11)",
|
177
|
+
base_model="gpt-4o-search-preview",
|
178
|
+
input_price=Decimal("2.50"),
|
179
|
+
cached_input_price=None,
|
180
|
+
output_price=Decimal("10.00"),
|
181
|
+
),
|
182
|
+
"computer-use-preview": OpenAIModelConfig(
|
183
|
+
display_name="Computer Use Preview",
|
184
|
+
base_model="computer-use-preview",
|
185
|
+
input_price=Decimal("3.00"),
|
186
|
+
cached_input_price=None,
|
187
|
+
output_price=Decimal("12.00"),
|
188
|
+
),
|
189
|
+
"computer-use-preview-2025-03-11": OpenAIModelConfig(
|
190
|
+
display_name="Computer Use Preview (2025-03-11)",
|
191
|
+
base_model="computer-use-preview",
|
192
|
+
input_price=Decimal("3.00"),
|
193
|
+
cached_input_price=None,
|
194
|
+
output_price=Decimal("12.00"),
|
195
|
+
),
|
91
196
|
}
|
92
197
|
|
93
198
|
|
@@ -1,7 +1,8 @@
|
|
1
|
-
from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator
|
1
|
+
from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
|
2
2
|
from pydantic import Field, BaseModel
|
3
3
|
from openai import OpenAI, AsyncOpenAI
|
4
4
|
from openai.types.chat import ChatCompletionChunk
|
5
|
+
import numpy as np
|
5
6
|
|
6
7
|
from airtrain.core.skills import Skill, ProcessingError
|
7
8
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -232,3 +233,110 @@ class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
|
|
232
233
|
|
233
234
|
except Exception as e:
|
234
235
|
raise ProcessingError(f"OpenAI parsing failed: {str(e)}")
|
236
|
+
|
237
|
+
|
238
|
+
class OpenAIEmbeddingsInput(InputSchema):
|
239
|
+
"""Schema for OpenAI embeddings input"""
|
240
|
+
|
241
|
+
texts: Union[str, List[str]] = Field(
|
242
|
+
..., description="Text or list of texts to generate embeddings for"
|
243
|
+
)
|
244
|
+
model: str = Field(
|
245
|
+
default="text-embedding-3-large", description="OpenAI embeddings model to use"
|
246
|
+
)
|
247
|
+
encoding_format: str = Field(
|
248
|
+
default="float", description="The format of the embeddings: 'float' or 'base64'"
|
249
|
+
)
|
250
|
+
dimensions: Optional[int] = Field(
|
251
|
+
default=None, description="Optional number of dimensions for the embeddings"
|
252
|
+
)
|
253
|
+
|
254
|
+
|
255
|
+
class OpenAIEmbeddingsOutput(OutputSchema):
|
256
|
+
"""Schema for OpenAI embeddings output"""
|
257
|
+
|
258
|
+
embeddings: List[List[float]] = Field(..., description="List of embeddings vectors")
|
259
|
+
used_model: str = Field(..., description="Model used for generating embeddings")
|
260
|
+
tokens_used: int = Field(..., description="Number of tokens used")
|
261
|
+
|
262
|
+
|
263
|
+
class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]):
|
264
|
+
"""Skill for generating embeddings using OpenAI models"""
|
265
|
+
|
266
|
+
input_schema = OpenAIEmbeddingsInput
|
267
|
+
output_schema = OpenAIEmbeddingsOutput
|
268
|
+
|
269
|
+
def __init__(self, credentials: Optional[OpenAICredentials] = None):
|
270
|
+
"""Initialize the skill with optional credentials"""
|
271
|
+
super().__init__()
|
272
|
+
self.credentials = credentials or OpenAICredentials.from_env()
|
273
|
+
self.client = OpenAI(
|
274
|
+
api_key=self.credentials.openai_api_key.get_secret_value(),
|
275
|
+
organization=self.credentials.openai_organization_id,
|
276
|
+
)
|
277
|
+
self.async_client = AsyncOpenAI(
|
278
|
+
api_key=self.credentials.openai_api_key.get_secret_value(),
|
279
|
+
organization=self.credentials.openai_organization_id,
|
280
|
+
)
|
281
|
+
|
282
|
+
def process(self, input_data: OpenAIEmbeddingsInput) -> OpenAIEmbeddingsOutput:
|
283
|
+
"""Generate embeddings for the input text(s)"""
|
284
|
+
try:
|
285
|
+
# Handle single text input
|
286
|
+
texts = (
|
287
|
+
[input_data.texts]
|
288
|
+
if isinstance(input_data.texts, str)
|
289
|
+
else input_data.texts
|
290
|
+
)
|
291
|
+
|
292
|
+
# Create embeddings
|
293
|
+
response = self.client.embeddings.create(
|
294
|
+
model=input_data.model,
|
295
|
+
input=texts,
|
296
|
+
encoding_format=input_data.encoding_format,
|
297
|
+
dimensions=input_data.dimensions,
|
298
|
+
)
|
299
|
+
|
300
|
+
# Extract embeddings
|
301
|
+
embeddings = [data.embedding for data in response.data]
|
302
|
+
|
303
|
+
return OpenAIEmbeddingsOutput(
|
304
|
+
embeddings=embeddings,
|
305
|
+
used_model=response.model,
|
306
|
+
tokens_used=response.usage.total_tokens,
|
307
|
+
)
|
308
|
+
except Exception as e:
|
309
|
+
raise ProcessingError(f"OpenAI embeddings generation failed: {str(e)}")
|
310
|
+
|
311
|
+
async def process_async(
|
312
|
+
self, input_data: OpenAIEmbeddingsInput
|
313
|
+
) -> OpenAIEmbeddingsOutput:
|
314
|
+
"""Async version of the embeddings generation"""
|
315
|
+
try:
|
316
|
+
# Handle single text input
|
317
|
+
texts = (
|
318
|
+
[input_data.texts]
|
319
|
+
if isinstance(input_data.texts, str)
|
320
|
+
else input_data.texts
|
321
|
+
)
|
322
|
+
|
323
|
+
# Create embeddings
|
324
|
+
response = await self.async_client.embeddings.create(
|
325
|
+
model=input_data.model,
|
326
|
+
input=texts,
|
327
|
+
encoding_format=input_data.encoding_format,
|
328
|
+
dimensions=input_data.dimensions,
|
329
|
+
)
|
330
|
+
|
331
|
+
# Extract embeddings
|
332
|
+
embeddings = [data.embedding for data in response.data]
|
333
|
+
|
334
|
+
return OpenAIEmbeddingsOutput(
|
335
|
+
embeddings=embeddings,
|
336
|
+
used_model=response.model,
|
337
|
+
tokens_used=response.usage.total_tokens,
|
338
|
+
)
|
339
|
+
except Exception as e:
|
340
|
+
raise ProcessingError(
|
341
|
+
f"OpenAI async embeddings generation failed: {str(e)}"
|
342
|
+
)
|
@@ -1,10 +1,10 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=fRBP9B8-SN5didcE8h1LLAPTuQTxiODu_-bbKGWLDR4,2099
|
2
2
|
airtrain/__main__.py,sha256=EU8ffFmCdC1G-UcHHt0Oo3lB1PGqfC6kwzH39CnYSwU,72
|
3
3
|
airtrain/builder/__init__.py,sha256=D33sr0k_WAe6FAJkk8rUaivEzFaeVqLXkQgyFWEhfPU,110
|
4
4
|
airtrain/builder/agent_builder.py,sha256=3XnGUAcK_6lWoUDtL0TanliQZuh7u0unhNbnrz1z2-I,5018
|
5
5
|
airtrain/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
airtrain/cli/builder.py,sha256=cI0FZCfRrgXPmjt8lOnHZwCWKrOB2doaOn49kmxVxHs,669
|
7
|
-
airtrain/cli/main.py,sha256=
|
7
|
+
airtrain/cli/main.py,sha256=WGt0WXhfRl7D_UGNtCMRDWiBTBwbXcRbkEZOh9StXOo,3559
|
8
8
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
9
9
|
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
10
10
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
@@ -13,9 +13,10 @@ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
|
|
13
13
|
airtrain/core/credentials.py,sha256=PgQotrQc46J5djidKnkK1znUv3fyNkUFDO-m2Kn_Gzo,4006
|
14
14
|
airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
|
15
15
|
airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
|
16
|
-
airtrain/integrations/__init__.py,sha256
|
17
|
-
airtrain/integrations/anthropic/__init__.py,sha256=
|
16
|
+
airtrain/integrations/__init__.py,sha256=rk9QFl0Dd7Qp4rULhi_u4smwsJwk69Kg_-fv0GQ43iw,1782
|
17
|
+
airtrain/integrations/anthropic/__init__.py,sha256=F4kB5fuj7nYgTVcgzeHGc91LT96FZfsCJVBVCnTRh-k,541
|
18
18
|
airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
|
19
|
+
airtrain/integrations/anthropic/models_config.py,sha256=TZt31hLcT-9YK-NxqiarMyOwvUWMgXAzAcPfSwzDSiQ,3347
|
19
20
|
airtrain/integrations/anthropic/skills.py,sha256=WV-9254H2VqUAq_7Zr1xG5IhejeC_gQSqyH0hwW1_tY,5870
|
20
21
|
airtrain/integrations/aws/__init__.py,sha256=3x7v2NxpAfI-U-YgwQeH5PtsmUrNLPMfLyUGFLiBjbs,155
|
21
22
|
airtrain/integrations/aws/credentials.py,sha256=nN-daKAl7qOb_VdRpsThG8gN5GeSUkx-ji5E_gF_vYw,1444
|
@@ -23,15 +24,16 @@ airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGt
|
|
23
24
|
airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
|
24
25
|
airtrain/integrations/cerebras/credentials.py,sha256=KDEH4r8FGT68L9p34MLZWK65wq_a703pqIF3ODaSbts,694
|
25
26
|
airtrain/integrations/cerebras/skills.py,sha256=hmqcnF-nkFk5YJVf8f-TiKBfb8kYCfnC30W67VZ7CKU,4922
|
26
|
-
airtrain/integrations/fireworks/__init__.py,sha256=
|
27
|
+
airtrain/integrations/fireworks/__init__.py,sha256=GstUg0rYC-7Pg0DVbDXwL5eO1hp3WCSfroWazbGpfi0,545
|
27
28
|
airtrain/integrations/fireworks/completion_skills.py,sha256=G657xWd7izLOxXq7RmqdupBF4DHqXQgXuhQ-MW7mtqc,5613
|
28
29
|
airtrain/integrations/fireworks/conversation_manager.py,sha256=m6VEHijqpYEYawkKhuHtb8RQxw4kxGWFWdbSK6zGuro,3704
|
29
|
-
airtrain/integrations/fireworks/credentials.py,sha256=
|
30
|
-
airtrain/integrations/fireworks/
|
30
|
+
airtrain/integrations/fireworks/credentials.py,sha256=eeV9y_4pTe8LZX02I7kfA_YNY2D7MSbFl7JEZVn22zQ,864
|
31
|
+
airtrain/integrations/fireworks/list_models.py,sha256=o4fP0K3qstBopO7va2LysLp4_KUf5Iz_YROrYkaNtVs,4686
|
32
|
+
airtrain/integrations/fireworks/models.py,sha256=yo4xtweSi4qQftg04r4naRddx3KjU9Jluzqf5C7V9f4,4626
|
31
33
|
airtrain/integrations/fireworks/requests_skills.py,sha256=c84Vy_4EcBrwJfp3jqizzlcja_LsEtvWh59qiaIjukg,8233
|
32
|
-
airtrain/integrations/fireworks/skills.py,sha256=
|
33
|
-
airtrain/integrations/fireworks/structured_completion_skills.py,sha256
|
34
|
-
airtrain/integrations/fireworks/structured_requests_skills.py,sha256=
|
34
|
+
airtrain/integrations/fireworks/skills.py,sha256=o9OY69cC10P8BtBBYRYLCyR_GwxmNlF6YhnrXiNS53o,7154
|
35
|
+
airtrain/integrations/fireworks/structured_completion_skills.py,sha256=-AJTaOFC8vkFiEjHW24VL8ymcNSVbhZp6xb4enkL95U,6620
|
36
|
+
airtrain/integrations/fireworks/structured_requests_skills.py,sha256=FgUdWb6_GI2ZBWhK2wp-WqKZUkwCkKNBBjYcRkHtjog,11850
|
35
37
|
airtrain/integrations/fireworks/structured_skills.py,sha256=BZaLqSOTC11QdZ4kDORS4JnwF_YXBAa-IiwQ5dJiHXw,3895
|
36
38
|
airtrain/integrations/google/__init__.py,sha256=ElwgcXfbg_gGMm6zbkMXCQPFKZUb-yTJk986o19A7Cs,214
|
37
39
|
airtrain/integrations/google/credentials.py,sha256=KSvWNqW8Mjr4MkysRvUqlrOSGdShNIe5u2OPO6vRrWY,2047
|
@@ -42,11 +44,11 @@ airtrain/integrations/groq/skills.py,sha256=qFyxC_2xZYnByAPo5p2aHbrqhdHYCoIdvDRA
|
|
42
44
|
airtrain/integrations/ollama/__init__.py,sha256=zMHBsGzViVrvxAeJmfq6r-ZfSE6Dy5QcKLhe4d5fEcM,164
|
43
45
|
airtrain/integrations/ollama/credentials.py,sha256=D7O4kUAb_VHs5s1ncUN9Ezhu5PvLfgj3RifAkB9sEZk,940
|
44
46
|
airtrain/integrations/ollama/skills.py,sha256=M_Un8D5VJ5XtPEq9IClzqV3jCPBoFTSm2ve6EO8W2JU,1556
|
45
|
-
airtrain/integrations/openai/__init__.py,sha256=
|
47
|
+
airtrain/integrations/openai/__init__.py,sha256=w5V7lxvrKtrrjyqGoppEKg9ORKKQ2cxaLOpgCZdm_H8,541
|
46
48
|
airtrain/integrations/openai/chinese_assistant.py,sha256=MMhv4NBOoEQ0O22ZZtP255rd5ajHC9l6FPWIjpqxBOA,1581
|
47
49
|
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
48
|
-
airtrain/integrations/openai/models_config.py,sha256=
|
49
|
-
airtrain/integrations/openai/skills.py,sha256=
|
50
|
+
airtrain/integrations/openai/models_config.py,sha256=W9mu_z9tCC4ZUKHSJ6Hk4X09TRZLqEhT7TtRY5JEk5g,8007
|
51
|
+
airtrain/integrations/openai/skills.py,sha256=1dvRJYrnU2hOmGRlkHBtyR6P8D7aIwHZfUKmjlReWrQ,12821
|
50
52
|
airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
|
51
53
|
airtrain/integrations/sambanova/credentials.py,sha256=JyN8sbMCoXuXAjim46aI3LTicBijoemS7Ao0rn4yBJU,824
|
52
54
|
airtrain/integrations/sambanova/skills.py,sha256=SZ_GAimMiOCILiNkzyhNflyRR6bdC5r0Tnog19K8geU,4997
|
@@ -63,8 +65,8 @@ airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9
|
|
63
65
|
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
64
66
|
airtrain/integrations/together/skills.py,sha256=8DwkexMJu1Gm6QmNDfNasYStQ31QsXBbFP99zR-YCf0,7598
|
65
67
|
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
66
|
-
airtrain-0.1.
|
67
|
-
airtrain-0.1.
|
68
|
-
airtrain-0.1.
|
69
|
-
airtrain-0.1.
|
70
|
-
airtrain-0.1.
|
68
|
+
airtrain-0.1.40.dist-info/METADATA,sha256=3-KnsKsMdriztxTjdEWoJF599N3oicxpw2x17-mrQaw,5375
|
69
|
+
airtrain-0.1.40.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
70
|
+
airtrain-0.1.40.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
|
71
|
+
airtrain-0.1.40.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
72
|
+
airtrain-0.1.40.dist-info/RECORD,,
|
File without changes
|
File without changes
|