airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,175 @@
|
|
1
|
+
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import FireworksCredentials
|
9
|
+
|
10
|
+
# Type variable standing in for the caller-supplied response model class
# (used below as ``Type[ResponseT]`` on the input schema).
ResponseT = TypeVar("ResponseT")
|
11
|
+
|
12
|
+
|
13
|
+
class FireworksStructuredCompletionInput(InputSchema):
    """Request parameters for a Fireworks AI structured completion.

    ``response_model`` is a class, not an instance: the skill reads its JSON
    schema to build the request and later validates the returned JSON text
    against it.
    """

    # Raw prompt text; required, no default.
    prompt: str = Field(..., description="Input prompt for completion")
    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    # Constrained to the closed interval [0, 1] by the ge/le validators.
    temperature: float = Field(
        ge=0,
        le=1,
        default=0.7,
        description="Temperature for response generation",
    )
    max_tokens: int = Field(
        description="Maximum tokens in response",
        default=131072,
    )
    # Class whose schema describes the expected JSON answer.
    response_model: Type[ResponseT]
    stream: bool = Field(
        description="Whether to stream the response token by token",
        default=False,
    )

    class Config:
        # Needed so the arbitrary ``Type[...]`` field above is accepted.
        arbitrary_types_allowed = True
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksStructuredCompletionOutput(OutputSchema):
    """Result of a Fireworks AI structured completion."""

    # Instance produced by validating the returned JSON against the
    # caller's response model.
    parsed_response: Any
    # Model identifier that was requested.
    used_model: str
    # Token accounting; the skill passes an empty dict in streaming mode.
    usage: Dict[str, int]
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        description=(
            "Tool calls are not applicable for completions, "
            "included for compatibility"
        ),
        default=None,
    )
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksStructuredCompletionSkill(
    Skill[FireworksStructuredCompletionInput, FireworksStructuredCompletionOutput]
):
    """Skill for getting structured completion responses from Fireworks AI.

    Posts a prompt to ``BASE_URL`` with a ``json_object`` response format
    derived from the caller's ``response_model`` and validates the returned
    text against that model.
    """

    input_schema = FireworksStructuredCompletionInput
    output_schema = FireworksStructuredCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        api_key = self.credentials.fireworks_api_key.get_secret_value()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def _build_payload(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Dict[str, Any]:
        """Build the request payload.

        The response model's JSON schema is sent along with every model field
        listed under ``required``, overriding whatever ``required`` list the
        schema itself carries.
        """
        response_model = input_data.response_model
        return {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": input_data.stream,
            "response_format": {
                "type": "json_object",
                "schema": {
                    **response_model.model_json_schema(),
                    "required": list(response_model.model_fields),
                },
            },
        }

    def process_stream(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Generator[Dict[str, Any], None, None]:
        """Process the input and stream the response.

        Yields:
            ``{"chunk": text}`` for every non-empty text fragment, followed by
            one ``{"complete": parsed}`` holding the validated response model.

        Raises:
            ProcessingError: If the request fails or the accumulated text
                cannot be validated against the response model.
        """
        try:
            payload = self._build_payload(input_data)
            # This method parses line-delimited events, so always request a
            # streamed body regardless of ``input_data.stream``.
            payload["stream"] = True

            json_buffer: List[str] = []
            # The context manager releases the connection even when the
            # consumer abandons the generator early.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(
                            line.decode("utf-8").removeprefix("data: ")
                        )
                    except json.JSONDecodeError:
                        # Non-JSON lines (e.g. the end-of-stream marker) are
                        # skipped.
                        continue

                    choices = data.get("choices") if isinstance(data, dict) else None
                    if choices and choices[0].get("text"):
                        content = choices[0]["text"]
                        json_buffer.append(content)
                        yield {"chunk": content}

            # Once complete, parse the full JSON
            complete_json = "".join(json_buffer)
            try:
                parsed_response = input_data.response_model.model_validate_json(
                    complete_json
                )
            except Exception as e:
                raise ProcessingError(
                    f"Failed to parse JSON response: {str(e)}"
                ) from e
            yield {"complete": parsed_response}

        except ProcessingError:
            # Already descriptive; avoid stacking a second message prefix.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksStructuredCompletionInput
    ) -> FireworksStructuredCompletionOutput:
        """Process the input and return structured response.

        Raises:
            ProcessingError: If the request fails or the response cannot be
                parsed into the response model.
        """
        try:
            if input_data.stream:
                # For streaming, drain the generator and keep the final
                # parsed object; individual chunks are not needed here.
                parsed_response = None
                for event in self.process_stream(input_data):
                    if "complete" in event:
                        parsed_response = event["complete"]

                if parsed_response is None:
                    raise ProcessingError("Failed to parse streamed response")

                return FireworksStructuredCompletionOutput(
                    parsed_response=parsed_response,
                    used_model=input_data.model,
                    usage={},  # Usage stats not available in streaming mode
                )

            # For non-streaming, use regular request
            payload = self._build_payload(input_data)
            response = requests.post(
                self.BASE_URL, headers=self.headers, data=json.dumps(payload)
            )
            response.raise_for_status()
            data = response.json()

            response_text = data["choices"][0]["text"]
            parsed_response = input_data.response_model.model_validate_json(
                response_text
            )

            # Keep only integer counters so the Dict[str, int] output field
            # validates even if the API adds non-integer usage entries.
            usage = {
                key: value
                for key, value in (data.get("usage") or {}).items()
                if isinstance(value, int)
            }

            return FireworksStructuredCompletionOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks structured completion failed: {str(e)}"
            ) from e
|
@@ -0,0 +1,291 @@
|
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any, Generator, Union
|
2
|
+
from pydantic import BaseModel, Field, create_model
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
import re
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import FireworksCredentials
|
11
|
+
|
12
|
+
# Type variable for the caller-supplied response model class; bound to
# pydantic's BaseModel.
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
13
|
+
|
14
|
+
|
15
|
+
class FireworksStructuredRequestInput(InputSchema):
    """Request parameters for a Fireworks AI structured chat request.

    ``response_model`` is a class, not an instance; the skill validates the
    JSON returned by the model against it.
    """

    # Text of the current user turn; required.
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant that provides structured data.",
    )
    # Earlier messages, inserted between the system prompt and user_input.
    conversation_history: List[Dict[str, Any]] = Field(
        description="List of previous conversation messages",
        default_factory=list,
    )
    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    # Constrained to the closed interval [0, 1] by the ge/le validators.
    temperature: float = Field(
        ge=0,
        le=1,
        default=0.7,
        description="Temperature for response generation",
    )
    max_tokens: int = Field(
        description="Maximum tokens in response",
        default=131072,
    )
    response_model: Type[ResponseT]
    stream: bool = Field(
        description="Whether to stream the response token by token",
        default=False,
    )
    # Only added to the request payload when truthy.
    tools: Optional[List[Dict[str, Any]]] = Field(
        description=(
            "A list of tools the model may use. "
            "Currently only functions supported."
        ),
        default=None,
    )
    # Only added to the request payload when truthy.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        description=(
            "Controls which tool is called by the model. "
            "'none', 'auto', or specific tool."
        ),
        default=None,
    )

    class Config:
        # Needed so the arbitrary ``Type[...]`` field above is accepted.
        arbitrary_types_allowed = True
|
56
|
+
|
57
|
+
|
58
|
+
class FireworksStructuredRequestOutput(OutputSchema):
    """Result of a Fireworks AI structured chat request."""

    # Instance produced by validating the returned JSON against the
    # caller's response model.
    parsed_response: Any
    # Model identifier that was requested.
    used_model: str
    # Token counters reported for the request.
    usage: Dict[str, int]
    # Text found inside a <think>...</think> section, when one was present.
    reasoning: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        description="Tool calls generated by the model",
        default=None,
    )
|
68
|
+
|
69
|
+
|
70
|
+
class FireworksStructuredRequestSkill(
    Skill[FireworksStructuredRequestInput, FireworksStructuredRequestOutput]
):
    """Skill for getting structured responses from Fireworks AI using requests.

    Sends a chat-completions request with a ``json_object`` response format,
    separates an optional ``<think>...</think>`` section from the JSON body
    and validates the JSON against the caller's ``response_model``.
    """

    input_schema = FireworksStructuredRequestInput
    output_schema = FireworksStructuredRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        api_key = self.credentials.fireworks_api_key.get_secret_value()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def _build_messages(
        self, input_data: FireworksStructuredRequestInput
    ) -> List[Dict[str, Any]]:
        """Build messages list from input data including conversation history.

        Order: system prompt, prior history (if any), current user input.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(
        self, input_data: FireworksStructuredRequestInput
    ) -> Dict[str, Any]:
        """Build the request payload."""
        payload = {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": input_data.stream,
            "response_format": {"type": "json_object"},
        }

        # Add tool-related parameters if provided
        if input_data.tools:
            payload["tools"] = input_data.tools

        if input_data.tool_choice:
            payload["tool_choice"] = input_data.tool_choice

        return payload

    @staticmethod
    def _extract_tool_calls(
        result: Dict[str, Any]
    ) -> Optional[List[Dict[str, Any]]]:
        """Return the first choice's tool calls from *result*, or None.

        Each call is reduced to its ``id``, ``type`` and the function's
        ``name`` and ``arguments``.
        """
        choices = result.get("choices") or []
        if not choices:
            return None

        raw_calls = (choices[0].get("message") or {}).get("tool_calls")
        if not raw_calls:
            return None

        return [
            {
                "id": call["id"],
                "type": call["type"],
                "function": {
                    "name": call["function"]["name"],
                    "arguments": call["function"]["arguments"],
                },
            }
            for call in raw_calls
        ]

    def process_stream(
        self, input_data: FireworksStructuredRequestInput
    ) -> Generator[Dict[str, Any], None, None]:
        """Process the input and stream the response.

        Yields:
            ``{"chunk": text}`` for every non-empty content delta, followed by
            one ``{"complete": parsed, "reasoning": reasoning}``.

        Raises:
            ProcessingError: If the request fails, no content is received, or
                the accumulated text cannot be validated.
        """
        try:
            payload = self._build_payload(input_data)
            # This method parses streamed delta events, so always request a
            # streamed body regardless of ``input_data.stream``.
            payload["stream"] = True

            json_buffer: List[str] = []
            # The context manager releases the connection even when the
            # consumer abandons the generator early.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(
                            line.decode("utf-8").removeprefix("data: ")
                        )
                    except json.JSONDecodeError:
                        # Non-JSON lines (e.g. the end-of-stream marker) are
                        # skipped.
                        continue

                    # Events without choices or without a delta carry no
                    # text; skip them instead of failing the whole stream.
                    choices = data.get("choices") if isinstance(data, dict) else None
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    if content:
                        json_buffer.append(content)
                        yield {"chunk": content}

            # Once complete, parse the full response with think tags
            if not json_buffer:
                # If no data was collected, raise error
                raise ProcessingError("No data received from Fireworks API")

            complete_response = "".join(json_buffer)
            reasoning, json_str = self._parse_response_content(complete_response)

            try:
                parsed_response = input_data.response_model.model_validate_json(
                    json_str
                )
            except Exception as e:
                raise ProcessingError(
                    f"Failed to parse JSON response: {str(e)}"
                ) from e
            yield {"complete": parsed_response, "reasoning": reasoning}

        except ProcessingError:
            # Already descriptive; avoid stacking a second message prefix.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def _parse_response_content(self, content: str) -> tuple[Optional[str], str]:
        """Parse response content to extract reasoning and JSON.

        Returns:
            ``(reasoning, json_str)``. ``reasoning`` is the stripped text of
            the first ``<think>...</think>`` section, or None. ``json_str``
            is the ``{...}`` span following ``</think>``, or the whole
            content when that pattern is absent.
        """
        # Extract reasoning if present
        reasoning_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else None

        # Extract JSON
        json_match = re.search(r"</think>\s*(\{.*\})", content, re.DOTALL)
        json_str = json_match.group(1).strip() if json_match else content

        return reasoning, json_str

    def process(
        self, input_data: FireworksStructuredRequestInput
    ) -> FireworksStructuredRequestOutput:
        """Process the input and return structured response.

        When the reply consists only of tool calls (no message content),
        ``parsed_response`` is None and the calls are in ``tool_calls``.

        Raises:
            ProcessingError: If the request fails or the response cannot be
                parsed into the response model.
        """
        try:
            if input_data.stream:
                # For streaming, drain the generator and keep the final
                # parsed object; individual chunks are not needed here.
                parsed_response = None
                reasoning = None

                for event in self.process_stream(input_data):
                    if "complete" in event:
                        parsed_response = event["complete"]
                        reasoning = event.get("reasoning")

                if parsed_response is None:
                    raise ProcessingError("Failed to parse streamed response")

                # Make a non-streaming call to get tool calls if tools were provided
                # NOTE(review): this is a second, independent generation, so
                # it costs extra tokens and its tool calls may not match the
                # streamed answer.
                tool_calls = None
                if input_data.tools:
                    non_stream_payload = self._build_payload(input_data)
                    non_stream_payload["stream"] = False

                    response = requests.post(
                        self.BASE_URL,
                        headers=self.headers,
                        data=json.dumps(non_stream_payload),
                    )
                    response.raise_for_status()
                    tool_calls = self._extract_tool_calls(response.json())

                return FireworksStructuredRequestOutput(
                    parsed_response=parsed_response,
                    used_model=input_data.model,
                    usage={"total_tokens": 0},  # Can't get usage stats from streaming
                    reasoning=reasoning,
                    tool_calls=tool_calls,
                )

            # For non-streaming, use regular request
            payload = self._build_payload(input_data)
            payload["stream"] = False  # Ensure it's not streaming

            response = requests.post(
                self.BASE_URL, headers=self.headers, data=json.dumps(payload)
            )
            response.raise_for_status()
            result = response.json()

            # Get the content from the response
            if "choices" not in result or not result["choices"]:
                raise ProcessingError("Invalid response format from Fireworks API")

            # ``content`` can be present but null (e.g. a tool-call reply);
            # normalise to "" so the regex helpers receive a string.
            content = result["choices"][0]["message"].get("content") or ""

            # Check for tool calls
            tool_calls = self._extract_tool_calls(result)

            # Parse the response content
            reasoning, json_str = self._parse_response_content(content)
            if tool_calls and not json_str.strip():
                # Tool-call-only reply: there is no JSON body to validate.
                parsed_response = None
            else:
                try:
                    parsed_response = input_data.response_model.model_validate_json(
                        json_str
                    )
                except Exception as e:
                    raise ProcessingError(
                        f"Failed to parse JSON response: {str(e)}"
                    ) from e

            # Missing counters default to 0 instead of raising KeyError.
            usage = result.get("usage") or {}
            return FireworksStructuredRequestOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage={
                    "total_tokens": usage.get("total_tokens") or 0,
                    "prompt_tokens": usage.get("prompt_tokens") or 0,
                    "completion_tokens": usage.get("completion_tokens") or 0,
                },
                reasoning=reasoning,
                tool_calls=tool_calls,
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks structured request failed: {str(e)}"
            ) from e
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from openai import OpenAI
|
4
|
+
from airtrain.core.skills import Skill, ProcessingError
|
5
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
6
|
+
from .credentials import FireworksCredentials
|
7
|
+
import re
|
8
|
+
|
9
|
+
# Generic type variable for Pydantic response models
|
10
|
+
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
11
|
+
|
12
|
+
|
13
|
+
class FireworksParserInput(InputSchema):
    """Request parameters for the Fireworks structured-output parser skill.

    ``response_model`` is a class, not an instance; the skill sends its JSON
    schema with the request and parses the reply into it.
    """

    # Text of the current user turn; required.
    user_input: str
    system_prompt: str = "You are a helpful assistant that provides structured data."
    model: str = "accounts/fireworks/models/deepseek-r1"
    temperature: float = 0.7
    max_tokens: Optional[int] = 131072
    response_model: Type[ResponseT]
    # Earlier messages, inserted between the system prompt and user_input.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )

    class Config:
        # Needed so the arbitrary ``Type[...]`` field above is accepted.
        arbitrary_types_allowed = True
|
29
|
+
|
30
|
+
|
31
|
+
class FireworksParserOutput(OutputSchema):
    """Result of the Fireworks structured-output parser skill."""

    # Instance of the caller's response model built from the reply.
    parsed_response: BaseModel
    # Model name as reported by the completion object.
    used_model: str
    # Total token count reported by the completion's usage data.
    tokens_used: int
    # Text found inside a <think>...</think> section, when one was present.
    reasoning: Optional[str] = None
|
38
|
+
|
39
|
+
|
40
|
+
class FireworksParserSkill(Skill[FireworksParserInput, FireworksParserOutput]):
    """Skill for getting structured responses from Fireworks.

    Uses the OpenAI client pointed at the Fireworks inference endpoint,
    requests a ``json_object`` reply constrained by the response model's
    schema, and parses the reply into that model.
    """

    input_schema = FireworksParserInput
    output_schema = FireworksParserOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=self.credentials.fireworks_api_key.get_secret_value(),
        )

    def process(self, input_data: FireworksParserInput) -> FireworksParserOutput:
        """Send the conversation to Fireworks and parse the structured reply.

        Returns:
            The parsed response model together with the reported model name,
            total token count and any ``<think>`` reasoning text.

        Raises:
            ProcessingError: If the API call fails, the reply has no content,
                or the content cannot be validated against the response model.
        """
        try:
            # Build messages list including conversation history
            messages = [{"role": "system", "content": input_data.system_prompt}]

            # Add conversation history if present
            if input_data.conversation_history:
                messages.extend(input_data.conversation_history)

            # Add current user input
            messages.append({"role": "user", "content": input_data.user_input})

            # Make API call with JSON schema
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                response_format={
                    "type": "json_object",
                    "schema": input_data.response_model.model_json_schema(),
                },
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            response_content = completion.choices[0].message.content
            if not response_content:
                # Fail with a clear message instead of a TypeError from the
                # regex calls below.
                raise ValueError("response contained no message content")

            # Extract reasoning if present
            reasoning_match = re.search(
                r"<think>(.*?)</think>", response_content, re.DOTALL
            )
            reasoning = reasoning_match.group(1).strip() if reasoning_match else None

            # Extract JSON
            json_match = re.search(r"</think>\s*(\{.*\})", response_content, re.DOTALL)
            json_str = json_match.group(1).strip() if json_match else response_content

            # Parse the response into the specified model. model_validate_json
            # is the Pydantic v2 replacement for the deprecated parse_raw and
            # matches the v2 model_json_schema call above.
            parsed_response = input_data.response_model.model_validate_json(json_str)

            # Usage data may be absent; report 0 rather than failing.
            usage = completion.usage
            return FireworksParserOutput(
                parsed_response=parsed_response,
                used_model=completion.model,
                tokens_used=usage.total_tokens if usage else 0,
                reasoning=reasoning,
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks parsing failed: {str(e)}") from e
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import google.genai as genai
|
4
|
+
from google.cloud import storage
|
5
|
+
import os
|
6
|
+
|
7
|
+
# from google.cloud import storage
|
8
|
+
|
9
|
+
|
10
|
+
class GoogleCloudCredentials(BaseCredentials):
    """Google Cloud credentials.

    ``service_account_key`` holds the service-account key as JSON text.
    """

    project_id: str = Field(..., description="Google Cloud Project ID")
    service_account_key: SecretStr = Field(..., description="Service Account Key JSON")

    _required_credentials = {"project_id", "service_account_key"}

    async def validate_credentials(self) -> bool:
        """Validate Google Cloud credentials.

        Builds a Cloud Storage client from the service-account key and
        fetches at most one bucket to prove the key is accepted.

        Returns:
            True when the test call succeeds.

        Raises:
            CredentialValidationError: If the key is not valid JSON, the
                client cannot be built, or the test call fails.
        """
        # Local import keeps this block self-contained.
        import json

        try:
            # from_service_account_info expects the parsed key mapping, not
            # the raw JSON text stored in the secret.
            key_info = json.loads(self.service_account_key.get_secret_value())
            storage_client = storage.Client.from_service_account_info(key_info)

            # list_buckets() returns a lazy iterator; pull one item so a
            # request is actually sent.
            # NOTE(review): this is a blocking call inside an async method.
            next(iter(storage_client.list_buckets(max_results=1)), None)
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Google Cloud credentials: {str(e)}"
            ) from e
|
32
|
+
|
33
|
+
|
34
|
+
class GeminiCredentials(BaseCredentials):
    """Gemini API credentials."""

    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")

    _required_credentials = {"gemini_api_key"}

    @classmethod
    def from_env(cls) -> "GeminiCredentials":
        """Create credentials from environment variables.

        Reads ``GEMINI_API_KEY``; an unset variable yields an empty key.
        """
        return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))

    async def validate_credentials(self) -> bool:
        """Validate Gemini API credentials.

        Sends a minimal generation request with the configured key.

        Returns:
            True when the test call succeeds.

        Raises:
            CredentialValidationError: If the client cannot be created or the
                test call fails.
        """
        try:
            # The module imported as ``genai`` is google.genai, whose entry
            # point is a Client object; configure()/GenerativeModel belong to
            # the older google.generativeai package and do not exist here.
            client = genai.Client(api_key=self.gemini_api_key.get_secret_value())

            # Test API call with a simple model
            # NOTE(review): this is a blocking call inside an async method.
            client.models.generate_content(model="gemini-1.5-flash", contents="test")

            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Gemini credentials: {str(e)}"
            ) from e
|