airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,128 @@
|
|
1
|
+
from typing import Optional, List
|
2
|
+
import requests
|
3
|
+
from pydantic import Field
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import FireworksCredentials
|
8
|
+
from .models import FireworksModel
|
9
|
+
|
10
|
+
|
11
|
+
class FireworksListModelsInput(InputSchema):
    """Input for listing the models of a Fireworks AI account.

    The fields mirror the query parameters of the Fireworks ``ListModels``
    endpoint; ``account_id`` is placed in the request path instead.
    """

    # Account whose models are listed.
    account_id: str = Field(..., description="The Account Id")
    # Validated client-side with ``le=200``: larger values raise a validation
    # error here, so the description must not promise server-side coercion.
    page_size: Optional[int] = Field(
        default=50,
        description=(
            "The maximum number of models to return. The maximum page_size is 200; "
            "larger values are rejected by validation."
        ),
        le=200,
    )
    page_token: Optional[str] = Field(
        default=None,
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
    )
    # NOTE: shadows the ``filter`` builtin inside the class namespace, but the
    # name is part of the public schema and is kept.
    filter: Optional[str] = Field(
        default=None,
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
    )
    order_by: Optional[str] = Field(
        default=None,
        description=(
            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
            "The default sort order is ascending. To specify a descending order for a "
            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
            "If not specified, the default order is by \"name\"."
        ),
    )
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksListModelsOutput(OutputSchema):
    """One page of a Fireworks ListModels response plus its paging metadata."""

    # Parsed model entries for this page; empty when none were returned.
    models: List[FireworksModel] = Field(
        default_factory=list, description="List of Fireworks models"
    )
    # Token to pass back as ``page_token`` for the next page, if any.
    next_page_token: Optional[str] = Field(
        None, description="Token for retrieving the next page of results"
    )
    # Total count as reported by the API, if it reported one.
    total_size: Optional[int] = Field(
        None, description="Total number of models available"
    )
|
65
|
+
|
66
|
+
|
67
|
+
class FireworksListModelsSkill(
    Skill[FireworksListModelsInput, FireworksListModelsOutput]
):
    """Skill that lists the models of a Fireworks AI account.

    Issues ``GET {base_url}/accounts/{account_id}/models`` and converts each
    returned entry into a :class:`FireworksModel`.
    """

    input_schema = FireworksListModelsInput
    output_schema = FireworksListModelsOutput

    # Seconds to wait on the listing request; without a timeout ``requests``
    # can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.base_url = "https://api.fireworks.ai/v1"

    @staticmethod
    def _to_snake_case(key: str) -> str:
        """Convert a lowerCamelCase key (``displayName``) to snake_case.

        Keys that are already snake_case come back unchanged.
        """
        parts = []
        for index, char in enumerate(key):
            if char.isupper():
                if index:
                    parts.append("_")
                parts.append(char.lower())
            else:
                parts.append(char)
        return "".join(parts)

    def process(
        self, input_data: FireworksListModelsInput
    ) -> FireworksListModelsOutput:
        """List one page of models for ``input_data.account_id``.

        Args:
            input_data: Account id plus optional paging, filter and ordering.

        Returns:
            The parsed models together with ``next_page_token`` and
            ``total_size`` as reported by the API.

        Raises:
            ProcessingError: On any HTTP, decoding or validation failure.
        """
        try:
            url = f"{self.base_url}/accounts/{input_data.account_id}/models"

            # Only forward the parameters the caller set, using the API's
            # camelCase query names.
            candidates = {
                "pageSize": input_data.page_size,
                "pageToken": input_data.page_token,
                "filter": input_data.filter,
                "orderBy": input_data.order_by,
            }
            params = {name: value for name, value in candidates.items() if value}

            headers = {
                "Authorization": (
                    f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
                )
            }

            response = requests.get(
                url, headers=headers, params=params, timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            result = response.json()

            # The API responds with camelCase keys (as ``nextPageToken`` below
            # shows), while FireworksModel declares snake_case fields. Passing
            # the raw keys through would silently drop every multi-word field,
            # so the top-level keys are converted first.
            models = [
                FireworksModel(
                    **{
                        self._to_snake_case(key): value
                        for key, value in model_data.items()
                    }
                )
                for model_data in result.get("models", [])
            ]

            return FireworksListModelsOutput(
                models=models,
                next_page_token=result.get("nextPageToken"),
                total_size=result.get("totalSize"),
            )

        except requests.RequestException as e:
            raise ProcessingError(f"Failed to list Fireworks models: {str(e)}") from e
        except Exception as e:
            raise ProcessingError(f"Error listing Fireworks models: {str(e)}") from e
|
@@ -0,0 +1,139 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class FireworksMessage(BaseModel):
    """A single chat message sent to or received from a Fireworks model."""

    # Message text.
    content: str
    # Must be exactly one of the three roles; the regex is anchored at both ends.
    role: str = Field(..., pattern="^(system|user|assistant)$")
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksUsage(BaseModel):
    """Token accounting attached to a Fireworks API response."""

    # All three counters are required integers.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
18
|
+
|
19
|
+
|
20
|
+
class FireworksResponse(BaseModel):
    """Top-level envelope of a Fireworks chat-completion response."""

    id: str
    # Choice objects are kept as untyped dicts.
    choices: List[Dict[str, Any]]
    # Integer creation timestamp; presumably Unix seconds — confirm against the API.
    created: int
    model: str
    # Token counts, validated into FireworksUsage.
    usage: FireworksUsage
|
28
|
+
|
29
|
+
|
30
|
+
class FireworksModelStatus(BaseModel):
    """Placeholder schema for a Fireworks model's status object.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the status payload of the API is mapped.
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksModelBaseDetails(BaseModel):
    """Placeholder schema for a Fireworks model's base-model details.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the base-model details payload is mapped.
|
38
|
+
|
39
|
+
|
40
|
+
class FireworksPeftDetails(BaseModel):
    """Placeholder schema for a Fireworks model's PEFT details.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the PEFT details payload is mapped.
|
43
|
+
|
44
|
+
|
45
|
+
class FireworksConversationConfig(BaseModel):
    """Placeholder schema for a Fireworks model's conversation configuration.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the conversation config payload is mapped.
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksModelDeployedRef(BaseModel):
    """Placeholder schema for a reference to a deployed Fireworks model.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the deployed-model reference payload is mapped.
|
53
|
+
|
54
|
+
|
55
|
+
class FireworksDeprecationDate(BaseModel):
    """Placeholder schema for a Fireworks model's deprecation date.

    No fields are declared yet, so validating into this type keeps nothing.
    """

    # TODO: declare fields once the deprecation date payload is mapped.
|
58
|
+
|
59
|
+
|
60
|
+
class FireworksModel(BaseModel):
    """Description of one Fireworks model as returned by the model-listing API.

    Only ``name`` is required; every other attribute is optional and defaults
    to ``None``. Nested API objects are stored as plain dicts rather than the
    placeholder sub-schemas.

    NOTE(review): field names are snake_case. If the API payload uses
    camelCase keys, they must be converted before construction, or they will
    not populate these fields — confirm against the caller.
    """

    name: str

    # Descriptive text and authorship.
    display_name: Optional[str] = None
    description: Optional[str] = None
    create_time: Optional[str] = None
    created_by: Optional[str] = None

    # State and classification.
    state: Optional[str] = None
    status: Optional[Dict[str, Any]] = None
    kind: Optional[str] = None

    # External source links.
    github_url: Optional[str] = None
    hugging_face_url: Optional[str] = None

    # Nested detail objects, kept as raw dicts.
    base_model_details: Optional[Dict[str, Any]] = None
    peft_details: Optional[Dict[str, Any]] = None
    teft_details: Optional[Dict[str, Any]] = None

    public: Optional[bool] = None
    conversation_config: Optional[Dict[str, Any]] = None
    context_length: Optional[int] = None

    # Capability flags.
    supports_image_input: Optional[bool] = None
    supports_tools: Optional[bool] = None

    # Provenance and draft-model settings.
    imported_from: Optional[str] = None
    fine_tuning_job: Optional[str] = None
    default_draft_model: Optional[str] = None
    default_draft_token_count: Optional[int] = None

    precisions: Optional[List[str]] = None
    deployed_model_refs: Optional[List[Dict[str, Any]]] = None
    cluster: Optional[str] = None
    deprecation_date: Optional[Dict[str, Any]] = None

    # Remaining boolean flags.
    calibrated: Optional[bool] = None
    tunable: Optional[bool] = None
    supports_lora: Optional[bool] = None
    use_hf_apply_chat_template: Optional[bool] = None
|
93
|
+
|
94
|
+
|
95
|
+
class ListModelsInput(BaseModel):
    """Input for listing Fireworks models (plain pydantic variant).

    The fields mirror the query parameters of the Fireworks ``ListModels``
    endpoint; ``account_id`` identifies the account whose models are listed.
    """

    account_id: str = Field(..., description="The Account Id")
    # Validated client-side with ``le=200``: larger values raise a validation
    # error here, so the description must not promise server-side coercion.
    page_size: Optional[int] = Field(
        default=50,
        description=(
            "The maximum number of models to return. The maximum page_size is 200; "
            "larger values are rejected by validation."
        ),
        le=200,
    )
    page_token: Optional[str] = Field(
        default=None,
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
    )
    # NOTE: shadows the ``filter`` builtin inside the class namespace, but the
    # name is part of the public schema and is kept.
    filter: Optional[str] = Field(
        default=None,
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
    )
    order_by: Optional[str] = Field(
        default=None,
        description=(
            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
            "The default sort order is ascending. To specify a descending order for a "
            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
            "If not specified, the default order is by \"name\"."
        ),
    )
|
132
|
+
|
133
|
+
|
134
|
+
class ListModelsOutput(BaseModel):
    """One page of Fireworks models plus paging metadata (plain pydantic variant)."""

    # Required: the models on this page.
    models: List[FireworksModel]
    # Paging metadata; both are absent unless the API supplied them.
    next_page_token: Optional[str] = None
    total_size: Optional[int] = None
|
@@ -0,0 +1,207 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
import aiohttp
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import FireworksCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class FireworksRequestInput(InputSchema):
    """Chat request parameters for the raw-HTTP Fireworks skill.

    Everything except ``user_input`` has a default; sampling parameters are
    range-checked by pydantic before a request is built.
    """

    # The new user turn; appended after any conversation history.
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior turns as role/content dicts, inserted between system and user turns.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages",
    )
    model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    # NOTE(review): the upper bound of 1 may be stricter than the API accepts — confirm.
    temperature: float = Field(
        0.7, description="Temperature for response generation", ge=0, le=1
    )
    max_tokens: int = Field(131072, description="Maximum tokens in response")
    top_p: float = Field(1.0, description="Top p sampling parameter", ge=0, le=1)
    top_k: int = Field(40, description="Top k sampling parameter", ge=0)
    presence_penalty: float = Field(
        0.0, description="Presence penalty", ge=-2.0, le=2.0
    )
    frequency_penalty: float = Field(
        0.0, description="Frequency penalty", ge=-2.0, le=2.0
    )
    stream: bool = Field(False, description="Whether to stream the response")
|
47
|
+
|
48
|
+
|
49
|
+
class FireworksRequestOutput(OutputSchema):
    """Result of a raw-HTTP Fireworks chat call."""

    # Full assistant text.
    response: str
    # Model id that was requested.
    used_model: str
    # Token counts as returned by the API; may be empty when not reported.
    usage: Dict[str, int]
|
55
|
+
|
56
|
+
|
57
|
+
class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
    """Skill for Fireworks AI chat completions over plain HTTP.

    Synchronous calls use ``requests`` and asynchronous calls use ``aiohttp``.
    Both have buffered and streamed variants and target ``BASE_URL``. Streamed
    responses are read as server-sent events (``data: ...`` lines).
    """

    input_schema = FireworksRequestInput
    output_schema = FireworksRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        authorization = (
            f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
        )
        # Headers for buffered JSON responses.
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }
        # Headers for streamed (server-sent event) responses.
        self.stream_headers = {
            "Accept": "text/event-stream",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }

    def _build_messages(
        self, input_data: FireworksRequestInput
    ) -> List[Dict[str, str]]:
        """Return system prompt, then conversation history, then the user turn."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(
        self, input_data: FireworksRequestInput, stream: Optional[bool] = None
    ) -> Dict[str, Any]:
        """Build the JSON request body.

        Args:
            input_data: Validated request parameters.
            stream: When not None, overrides ``input_data.stream``. The
                streaming and buffered methods force the flag so that the
                server's response format matches how it is parsed.
        """
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "stream": input_data.stream if stream is None else stream,
        }

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[str]:
        """Return the delta text carried by one server-sent-event line.

        Returns None for blank or non-``data:`` lines, the ``[DONE]``
        sentinel, undecodable JSON, and chunks without content, such as
        chunks whose ``choices`` list is empty.
        """
        text = line.strip()
        if not text.startswith("data:"):
            return None
        body = text[len("data:"):].strip()
        if not body or body == "[DONE]":
            return None
        try:
            chunk = json.loads(body)
        except json.JSONDecodeError:
            return None
        if not isinstance(chunk, dict):
            return None
        choices = chunk.get("choices") or []
        if not choices:
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content") or None

    def process_stream(
        self, input_data: FireworksRequestInput
    ) -> Generator[str, None, None]:
        """Yield response text fragments as they arrive.

        Raises:
            ProcessingError: On any HTTP or transport failure.
        """
        try:
            # Always request a stream here, whatever ``input_data.stream`` says;
            # a buffered JSON body would otherwise yield nothing.
            payload = self._build_payload(input_data, stream=True)
            # The context manager releases the connection even if the consumer
            # stops iterating early.
            with requests.post(
                self.BASE_URL,
                headers=self.stream_headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue
                    content = self._parse_sse_line(raw_line.decode("utf-8"))
                    if content:
                        yield content

        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
        """Return the complete response, streaming internally if requested.

        When ``input_data.stream`` is true, the streamed fragments are joined
        and ``usage`` is empty because it is not collected from the stream.

        Raises:
            ProcessingError: On any HTTP, decoding or validation failure.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # Not collected in streaming mode
            else:
                payload = self._build_payload(input_data, stream=False)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["message"]["content"]
                usage = data.get("usage", {})

            return FireworksRequestOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already wrapped by process_stream; avoid double-prefixing.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    async def process_async(
        self, input_data: FireworksRequestInput
    ) -> FireworksRequestOutput:
        """Async, buffered variant of :meth:`process` using aiohttp.

        Raises:
            ProcessingError: On any HTTP, decoding or validation failure.
        """
        try:
            # Force a buffered response: ``response.json()`` cannot parse an
            # event-stream body.
            payload = self._build_payload(input_data, stream=False)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.BASE_URL, headers=self.headers, json=payload
                ) as response:
                    response.raise_for_status()
                    data = await response.json()

            return FireworksRequestOutput(
                response=data["choices"][0]["message"]["content"],
                used_model=input_data.model,
                usage=data.get("usage", {}),
            )

        except Exception as e:
            raise ProcessingError(f"Async Fireworks request failed: {str(e)}") from e

    async def process_stream_async(
        self, input_data: FireworksRequestInput
    ) -> AsyncGenerator[str, None]:
        """Async variant of :meth:`process_stream` using aiohttp.

        Raises:
            ProcessingError: On any HTTP or transport failure.
        """
        try:
            payload = self._build_payload(input_data, stream=True)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.BASE_URL, headers=self.stream_headers, json=payload
                ) as response:
                    response.raise_for_status()

                    async for raw_line in response.content:
                        # _parse_sse_line skips the terminal ``data: [DONE]``
                        # line, which is not JSON.
                        content = self._parse_sse_line(raw_line.decode("utf-8"))
                        if content:
                            yield content

        except Exception as e:
            raise ProcessingError(f"Async Fireworks streaming failed: {str(e)}") from e
|
@@ -0,0 +1,181 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
|
+
from pydantic import Field
|
3
|
+
from openai import OpenAI
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import FireworksCredentials
|
8
|
+
|
9
|
+
|
10
|
+
class FireworksInput(InputSchema):
    """Chat request parameters for the OpenAI-SDK-based Fireworks skill.

    Only ``user_input`` is required. ``tools`` and ``tool_choice`` follow the
    OpenAI function-calling request format.
    """

    # The new user turn; appended after any conversation history.
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior turns as message dicts, inserted between system and user turns.
    conversation_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of previous conversation messages",
    )
    model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    # NOTE(review): the upper bound of 1 may be stricter than the API accepts — confirm.
    temperature: float = Field(
        0.7, description="Temperature for response generation", ge=0, le=1
    )
    max_tokens: Optional[int] = Field(
        131072, description="Maximum tokens in response"
    )
    # NOTE(review): FireworksChatSkill in this module does not send this value
    # to the API; confirm whether it should be forwarded.
    context_length_exceeded_behavior: str = Field(
        "truncate", description="Behavior when context length is exceeded"
    )
    stream: bool = Field(
        False, description="Whether to stream the response token by token"
    )
    tools: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="A list of tools the model may use. "
        "Currently only functions supported.",
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        None,
        description="Controls which tool is called by the model. "
        "'none', 'auto', or specific tool.",
    )
|
53
|
+
|
54
|
+
|
55
|
+
class FireworksOutput(OutputSchema):
    """Result of a Fireworks chat call made through the OpenAI SDK."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # Token counts; an empty dict when none were reported.
    usage: Dict[str, int] = Field(
        default_factory=dict, description="Usage statistics"
    )
    # Tool calls flattened to plain dicts; None when the model made none.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        None, description="Tool calls generated by the model"
    )
|
64
|
+
|
65
|
+
|
66
|
+
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
    """Skill for chatting with Fireworks AI models through the OpenAI SDK.

    The OpenAI client is pointed at the Fireworks inference base URL, so calls
    go through ``client.chat.completions.create``.
    """

    input_schema = FireworksInput
    output_schema = FireworksOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=self.credentials.fireworks_api_key.get_secret_value(),
        )

    def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
        """Return system prompt, then conversation history, then the user turn."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _completion_params(
        self, input_data: FireworksInput, stream: bool
    ) -> Dict[str, Any]:
        """Return the ``chat.completions.create`` arguments shared by every call."""
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": stream,
        }

    @staticmethod
    def _serialize_tool_calls(message: Any) -> Optional[List[Dict[str, Any]]]:
        """Convert SDK tool-call objects on ``message`` to plain dicts, or None."""
        raw_calls = getattr(message, "tool_calls", None)
        if not raw_calls:
            return None
        return [
            {
                "id": call.id,
                "type": call.type,
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            }
            for call in raw_calls
        ]

    def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
        """Yield response text fragments as the model produces them.

        ``tools`` and ``tool_choice`` are not forwarded in streaming mode.

        Raises:
            ProcessingError: On any SDK or transport failure.
        """
        try:
            stream = self.client.chat.completions.create(
                **self._completion_params(input_data, stream=True)
            )

            for chunk in stream:
                # Some chunks (e.g. a trailing usage-only chunk) have no choices.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            raise ProcessingError(f"Fireworks streaming failed: {str(e)}") from e

    def process(self, input_data: FireworksInput) -> FireworksOutput:
        """Return the complete response, streaming internally if requested.

        In streaming mode the fragments are joined. ``usage`` stays empty and
        ``tool_calls`` stays None: the previous implementation made a second,
        non-streaming completion just to obtain usage. That doubled cost and
        reported tokens for a generation the caller never saw.

        Raises:
            ProcessingError: On any SDK, transport or validation failure.
        """
        try:
            if input_data.stream:
                response = "".join(self.process_stream(input_data))
                return FireworksOutput(response=response, used_model=input_data.model)

            api_params = self._completion_params(input_data, stream=False)
            if input_data.tools:
                api_params["tools"] = input_data.tools
            if input_data.tool_choice:
                api_params["tool_choice"] = input_data.tool_choice

            completion = self.client.chat.completions.create(**api_params)
            message = completion.choices[0].message

            # The SDK exposes ``usage`` as optional; guard against None.
            usage: Dict[str, int] = {}
            if completion.usage is not None:
                usage = {
                    "total_tokens": completion.usage.total_tokens,
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                }

            return FireworksOutput(
                response=message.content or "",
                used_model=input_data.model,
                usage=usage,
                tool_calls=self._serialize_tool_calls(message),
            )

        except ProcessingError:
            # Already wrapped by process_stream; avoid double-prefixing.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks chat failed: {str(e)}") from e
|