airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +146 -6
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +62 -44
- airtrain/core/skills.py +102 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.3.dist-info/METADATA +0 -106
- airtrain-0.1.3.dist-info/RECORD +0 -9
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,207 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
import aiohttp
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import FireworksCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class FireworksRequestInput(InputSchema):
    """Input accepted by the raw-HTTP (``requests``/``aiohttp``) Fireworks chat skill.

    Message content comes from ``system_prompt``, ``conversation_history`` and
    ``user_input``; the remaining fields are sampling/transport options.
    """

    # The new user turn; it is appended after any conversation history.
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior messages, inserted between the system prompt and the user turn.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages",
        default_factory=list,
    )
    model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        0.7,
        ge=0,
        le=1,
        description="Temperature for response generation",
    )
    max_tokens: int = Field(131072, description="Maximum tokens in response")
    top_p: float = Field(
        1.0,
        ge=0,
        le=1,
        description="Top p sampling parameter",
    )
    top_k: int = Field(40, ge=0, description="Top k sampling parameter")
    presence_penalty: float = Field(
        0.0,
        ge=-2.0,
        le=2.0,
        description="Presence penalty",
    )
    frequency_penalty: float = Field(
        0.0,
        ge=-2.0,
        le=2.0,
        description="Frequency penalty",
    )
    stream: bool = Field(False, description="Whether to stream the response")
|
47
|
+
|
48
|
+
|
49
|
+
class FireworksRequestOutput(OutputSchema):
    """Result produced by the raw-HTTP Fireworks AI chat skill."""

    # Complete assistant reply text.
    response: str = Field(...)
    # Model identifier taken from the request input.
    used_model: str = Field(...)
    # Token accounting reported by the API; the skill leaves it empty when streaming.
    usage: Dict[str, int] = Field(...)
|
55
|
+
|
56
|
+
|
57
|
+
class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
    """Skill for interacting with Fireworks AI chat models over raw HTTP.

    Synchronous calls use ``requests`` and asynchronous calls use ``aiohttp``;
    both post directly to the chat-completions endpoint in ``BASE_URL``.
    """

    input_schema = FireworksRequestInput
    output_schema = FireworksRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. Loaded with
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        authorization = (
            f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
        )
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }
        # Same as ``headers`` but advertising a server-sent-event response.
        self.stream_headers = {
            "Accept": "text/event-stream",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }

    def _build_messages(
        self, input_data: FireworksRequestInput
    ) -> List[Dict[str, str]]:
        """Build the message list: system prompt, history, then the user turn."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(
        self, input_data: FireworksRequestInput, stream: Optional[bool] = None
    ) -> Dict[str, Any]:
        """Build the JSON request body.

        Args:
            input_data: Validated skill input.
            stream: When not ``None``, overrides ``input_data.stream``. The
                streaming methods force ``True`` and ``process_async`` forces
                ``False`` so the wire format matches how the reply is read.

        Returns:
            The chat-completions request payload.
        """
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "stream": input_data.stream if stream is None else stream,
        }

    @staticmethod
    def _extract_delta_content(line: bytes) -> Optional[str]:
        """Return the text delta carried by one server-sent-event line.

        Blank lines, lines that are not JSON (SSE comments, the ``[DONE]``
        sentinel) and chunks without ``choices[0].delta.content`` give
        ``None``.
        """
        text = line.decode("utf-8").strip()
        if text.startswith("data:"):
            text = text[len("data:"):].strip()
        if not text or text == "[DONE]":
            return None
        try:
            chunk = json.loads(text)
        except json.JSONDecodeError:
            return None
        if not isinstance(chunk, dict):
            return None
        choices = chunk.get("choices") or []
        if not choices:
            # e.g. a trailing usage-only chunk
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content") or None

    def process_stream(
        self, input_data: FireworksRequestInput
    ) -> Generator[str, None, None]:
        """Stream the response text piece by piece.

        The request is always sent with ``stream=True``, whatever
        ``input_data.stream`` says, because the reply is read as an event
        stream.

        Yields:
            Non-empty content deltas in arrival order.

        Raises:
            ProcessingError: If the request or decoding fails.
        """
        try:
            payload = self._build_payload(input_data, stream=True)
            # ``with`` releases the pooled connection even if the consumer
            # stops iterating early.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    content = self._extract_delta_content(line)
                    if content:
                        yield content

        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed deltas are concatenated
        and ``usage`` is empty. Otherwise a single JSON request is made.

        Raises:
            ProcessingError: If the request fails or the reply is malformed.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # not available in streaming mode
            else:
                payload = self._build_payload(input_data, stream=False)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                # content may be null (e.g. tool-call replies); keep ``response: str`` valid
                response_text = data["choices"][0]["message"]["content"] or ""
                usage = data.get("usage") or {}

            return FireworksRequestOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    async def process_async(
        self, input_data: FireworksRequestInput
    ) -> FireworksRequestOutput:
        """Async, non-streaming version of ``process`` using ``aiohttp``.

        The request is always sent with ``stream=False`` because the reply is
        read as a single JSON document.

        Raises:
            ProcessingError: If the request fails or the reply is malformed.
        """
        try:
            payload = self._build_payload(input_data, stream=False)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.BASE_URL, headers=self.headers, json=payload
                ) as response:
                    response.raise_for_status()
                    data = await response.json()

                    return FireworksRequestOutput(
                        response=data["choices"][0]["message"]["content"] or "",
                        used_model=input_data.model,
                        usage=data.get("usage") or {},
                    )

        except Exception as e:
            raise ProcessingError(f"Async Fireworks request failed: {str(e)}") from e

    async def process_stream_async(
        self, input_data: FireworksRequestInput
    ) -> AsyncGenerator[str, None]:
        """Async version of ``process_stream`` using ``aiohttp``.

        Non-JSON lines, including the terminating ``data: [DONE]`` sentinel,
        are skipped rather than raising.

        Yields:
            Non-empty content deltas in arrival order.

        Raises:
            ProcessingError: If the request or decoding fails.
        """
        try:
            payload = self._build_payload(input_data, stream=True)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.BASE_URL, headers=self.stream_headers, json=payload
                ) as response:
                    response.raise_for_status()

                    async for line in response.content:
                        content = self._extract_delta_content(line)
                        if content:
                            yield content

        except Exception as e:
            raise ProcessingError(f"Async Fireworks streaming failed: {str(e)}") from e
|
@@ -0,0 +1,181 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
|
+
from pydantic import Field
|
3
|
+
from openai import OpenAI
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import FireworksCredentials
|
8
|
+
|
9
|
+
|
10
|
+
class FireworksInput(InputSchema):
    """Input accepted by the OpenAI-client based Fireworks AI chat skill.

    Message content comes from ``system_prompt``, ``conversation_history`` and
    ``user_input``; ``tools``/``tool_choice`` are optional function-calling
    settings.
    """

    # The new user turn; it is appended after any conversation history.
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior messages, inserted between the system prompt and the user turn.
    conversation_history: List[Dict[str, Any]] = Field(
        description="List of previous conversation messages",
        default_factory=list,
    )
    model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        0.7,
        ge=0,
        le=1,
        description="Temperature for response generation",
    )
    max_tokens: Optional[int] = Field(
        131072,
        description="Maximum tokens in response",
    )
    # NOTE(review): FireworksChatSkill in this module never sends this value
    # to the API; confirm whether it should.
    context_length_exceeded_behavior: str = Field(
        "truncate",
        description="Behavior when context length is exceeded",
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token",
    )
    tools: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="A list of tools the model may use. Currently only functions supported.",
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        None,
        description="Controls which tool is called by the model. 'none', 'auto', or specific tool.",
    )
|
53
|
+
|
54
|
+
|
55
|
+
class FireworksOutput(OutputSchema):
    """Result produced by the OpenAI-client based Fireworks AI chat skill."""

    response: str = Field(description="Model's response text")
    used_model: str = Field(description="Model used for generation")
    # Token counts; defaults to an empty dict.
    usage: Dict[str, int] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
    # OpenAI-style tool-call dicts, or None when the model made no calls.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Tool calls generated by the model",
    )
|
64
|
+
|
65
|
+
|
66
|
+
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
    """Skill for interacting with Fireworks AI models.

    Uses the ``openai`` client pointed at Fireworks' OpenAI-compatible
    inference endpoint.
    """

    input_schema = FireworksInput
    output_schema = FireworksOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. Loaded with
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=self.credentials.fireworks_api_key.get_secret_value(),
        )

    def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
        """Build the message list: system prompt, history, then the user turn."""
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": input_data.system_prompt}
        ]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_params(
        self, input_data: FireworksInput, *, stream: bool
    ) -> Dict[str, Any]:
        """Return the completion keyword arguments shared by every request."""
        # NOTE(review): ``context_length_exceeded_behavior`` is not forwarded.
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": stream,
        }

    @staticmethod
    def _chunk_content(chunk: Any) -> Optional[str]:
        """Return a streamed chunk's text delta, or None when it carries none."""
        if not chunk.choices:
            # e.g. a trailing usage-only chunk
            return None
        return chunk.choices[0].delta.content

    @staticmethod
    def _usage_to_dict(usage: Any) -> Dict[str, int]:
        """Convert the client's usage object to a plain dict ({} when absent)."""
        if usage is None:
            return {}
        return {
            "total_tokens": usage.total_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
        }

    @staticmethod
    def _extract_tool_calls(message: Any) -> Optional[List[Dict[str, Any]]]:
        """Convert a message's tool calls to dicts, or return None if there are none."""
        raw_calls = getattr(message, "tool_calls", None)
        if not raw_calls:
            return None
        return [
            {
                "id": tool_call.id,
                "type": tool_call.type,
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in raw_calls
        ]

    def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        ``tools``/``tool_choice`` are not sent in streaming mode.

        Yields:
            Every non-None content delta in arrival order.

        Raises:
            ProcessingError: If the streaming call fails.
        """
        try:
            stream = self.client.chat.completions.create(
                **self._build_params(input_data, stream=True)
            )
            for chunk in stream:
                content = self._chunk_content(chunk)
                if content is not None:
                    yield content

        except Exception as e:
            raise ProcessingError(f"Fireworks streaming failed: {str(e)}") from e

    def _process_streaming(self, input_data: FireworksInput) -> FireworksOutput:
        """Consume a single streamed completion, collecting its text and any usage.

        Only one request is made. ``usage`` is filled in only if a chunk
        reports it, and is ``{}`` otherwise.
        """
        parts: List[str] = []
        usage: Dict[str, int] = {}
        stream = self.client.chat.completions.create(
            **self._build_params(input_data, stream=True)
        )
        for chunk in stream:
            content = self._chunk_content(chunk)
            if content is not None:
                parts.append(content)
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage = self._usage_to_dict(chunk_usage)

        return FireworksOutput(
            response="".join(parts),
            used_model=input_data.model,
            usage=usage,
        )

    def process(self, input_data: FireworksInput) -> FireworksOutput:
        """Process the input and return the complete response.

        In streaming mode the reply is assembled from a single streamed call.
        The previous implementation issued a second, non-streaming completion
        just to obtain usage, which billed twice and reported usage for a
        different generation.

        Raises:
            ProcessingError: If the API call fails or the reply is malformed.
        """
        try:
            if input_data.stream:
                return self._process_streaming(input_data)

            api_params = self._build_params(input_data, stream=False)
            if input_data.tools:
                api_params["tools"] = input_data.tools
            if input_data.tool_choice:
                api_params["tool_choice"] = input_data.tool_choice

            completion = self.client.chat.completions.create(**api_params)
            message = completion.choices[0].message

            return FireworksOutput(
                response=message.content or "",
                used_model=input_data.model,
                usage=self._usage_to_dict(completion.usage),
                tool_calls=self._extract_tool_calls(message),
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks chat failed: {str(e)}") from e
|
@@ -0,0 +1,175 @@
|
|
1
|
+
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import FireworksCredentials
|
9
|
+
|
10
|
+
# Pydantic model class that the completion is parsed into. Bound to BaseModel
# because the skill calls ``model_json_schema``/``model_validate_json`` on it;
# with the bound, pydantic checks ``response_model`` is a BaseModel subclass
# when the input is validated, not when the request is being built.
ResponseT = TypeVar("ResponseT", bound=BaseModel)


class FireworksStructuredCompletionInput(InputSchema):
    """Schema for Fireworks AI structured completion input.

    The output of ``prompt`` is constrained to the JSON schema of
    ``response_model``.
    """

    prompt: str = Field(..., description="Input prompt for completion")
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    max_tokens: int = Field(default=131072, description="Maximum tokens in response")
    # Model class (not instance) whose schema drives JSON mode and that parses the reply.
    response_model: Type[ResponseT]
    stream: bool = Field(
        default=False,
        description="Whether to stream the response token by token",
    )

    class Config:
        arbitrary_types_allowed = True
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksStructuredCompletionOutput(OutputSchema):
    """Result of a structured Fireworks AI completion."""

    # Instance of the requested response model, parsed from the completion text.
    parsed_response: Any = Field(...)
    # Model identifier taken from the request input.
    used_model: str = Field(...)
    # Token accounting reported by the API; the skill leaves it empty when streaming.
    usage: Dict[str, int] = Field(...)
    # Always None from this skill; kept so the shape matches the chat outputs.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Tool calls are not applicable for completions, included for compatibility",
    )
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksStructuredCompletionSkill(
    Skill[FireworksStructuredCompletionInput, FireworksStructuredCompletionOutput]
):
    """Skill for getting structured completion responses from Fireworks AI.

    Posts to the text-completions endpoint in JSON mode with the response
    model's schema, then validates the returned text against that model.
    """

    input_schema = FireworksStructuredCompletionInput
    output_schema = FireworksStructuredCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. Loaded with
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(
        self,
        input_data: FireworksStructuredCompletionInput,
        stream: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """Build the JSON request body.

        Args:
            input_data: Validated skill input.
            stream: When not ``None``, overrides ``input_data.stream``;
                ``process_stream`` forces ``True``.

        Returns:
            The completions payload with a ``json_object`` response format.
        """
        response_model = input_data.response_model
        schema = {
            **response_model.model_json_schema(),
            # Every declared field is listed, overriding the schema's own
            # ``required`` list.
            "required": list(response_model.model_fields),
        }
        return {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": input_data.stream if stream is None else stream,
            "response_format": {"type": "json_object", "schema": schema},
        }

    @staticmethod
    def _extract_stream_text(line: bytes) -> Optional[str]:
        """Return the text fragment carried by one server-sent-event line.

        Blank lines, lines that are not JSON (SSE comments, the ``[DONE]``
        sentinel) and chunks without ``choices[0].text`` give ``None``.
        """
        text = line.decode("utf-8").strip()
        if text.startswith("data:"):
            text = text[len("data:"):].strip()
        if not text or text == "[DONE]":
            return None
        try:
            chunk = json.loads(text)
        except json.JSONDecodeError:
            return None
        if not isinstance(chunk, dict):
            return None
        choices = chunk.get("choices") or []
        if not choices:
            return None
        return choices[0].get("text") or None

    def process_stream(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream the completion, then yield the parsed result.

        The request is always sent with ``stream=True``, whatever
        ``input_data.stream`` says.

        Yields:
            ``{"chunk": text}`` for each non-empty fragment, then a single
            ``{"complete": parsed}`` holding the validated response model.

        Raises:
            ProcessingError: If the request fails or the joined text does not
                validate against ``response_model``.
        """
        try:
            payload = self._build_payload(input_data, stream=True)
            json_buffer: List[str] = []
            # ``with`` releases the pooled connection even if the consumer
            # stops iterating early.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    content = self._extract_stream_text(line)
                    if content:
                        json_buffer.append(content)
                        yield {"chunk": content}

            # Once complete, parse the full JSON
            complete_json = "".join(json_buffer)
            try:
                parsed_response = input_data.response_model.model_validate_json(
                    complete_json
                )
            except Exception as e:
                raise ProcessingError(
                    f"Failed to parse JSON response: {str(e)}"
                ) from e
            yield {"complete": parsed_response}

        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksStructuredCompletionInput
    ) -> FireworksStructuredCompletionOutput:
        """Process the input and return the structured response.

        In streaming mode ``usage`` is empty. Otherwise a single request is
        made and the API's usage (or ``{}`` if absent) is returned.

        Raises:
            ProcessingError: If the request fails or parsing fails.
        """
        try:
            if input_data.stream:
                parsed_response = None
                for event in self.process_stream(input_data):
                    if "complete" in event:
                        parsed_response = event["complete"]

                if parsed_response is None:
                    raise ProcessingError("Failed to parse streamed response")

                return FireworksStructuredCompletionOutput(
                    parsed_response=parsed_response,
                    used_model=input_data.model,
                    usage={},  # Usage stats not available in streaming mode
                )

            payload = self._build_payload(input_data, stream=False)
            response = requests.post(
                self.BASE_URL, headers=self.headers, data=json.dumps(payload)
            )
            response.raise_for_status()
            data = response.json()

            response_text = data["choices"][0]["text"]
            parsed_response = input_data.response_model.model_validate_json(
                response_text
            )

            return FireworksStructuredCompletionOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage=data.get("usage") or {},
            )

        except Exception as e:
            raise ProcessingError(
                f"Fireworks structured completion failed: {str(e)}"
            ) from e
|