airtrain 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/requests_skills.py +152 -0
- airtrain/integrations/fireworks/skills.py +53 -52
- airtrain/integrations/fireworks/structured_completion_skills.py +169 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +209 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/openai/skills.py +69 -85
- {airtrain-0.1.26.dist-info → airtrain-0.1.28.dist-info}/METADATA +16 -2
- {airtrain-0.1.26.dist-info → airtrain-0.1.28.dist-info}/RECORD +12 -7
- {airtrain-0.1.26.dist-info → airtrain-0.1.28.dist-info}/WHEEL +0 -0
- {airtrain-0.1.26.dist-info → airtrain-0.1.28.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -0,0 +1,147 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import FireworksCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksCompletionInput(InputSchema):
    """Schema for Fireworks AI completion input using requests"""

    # NOTE(review): the class docstring is kept verbatim on purpose. If
    # InputSchema is a pydantic model (presumably -- confirm), the docstring is
    # emitted as the description of the generated JSON schema.

    # --- What to complete, and with which model --------------------------
    prompt: str = Field(
        ...,
        description="Input prompt for completion",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )

    # --- Sampling controls ------------------------------------------------
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    top_p: float = Field(
        default=1.0,
        description="Top p sampling parameter",
        ge=0,
        le=1,
    )
    top_k: int = Field(
        default=50,
        description="Top k sampling parameter",
        ge=0,
    )

    # --- Penalties ----------------------------------------------------------
    presence_penalty: float = Field(
        default=0.0,
        description="Presence penalty",
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        default=0.0,
        description="Frequency penalty",
        ge=-2.0,
        le=2.0,
    )
    repetition_penalty: float = Field(
        default=1.0,
        description="Repetition penalty",
        ge=0.0,
    )

    # --- Output shaping -----------------------------------------------------
    # Either a single stop string or a list of them; omitted from the request
    # payload by the skill when falsy.
    stop: Optional[Union[str, List[str]]] = Field(
        default=None,
        description="Stop sequences",
    )
    echo: bool = Field(
        default=False,
        description="Echo the prompt in the response",
    )
    stream: bool = Field(
        default=False,
        description="Whether to stream the response",
    )
|
42
|
+
|
43
|
+
|
44
|
+
class FireworksCompletionOutput(OutputSchema):
    """Schema for Fireworks AI completion output"""

    # Generated text (concatenation of all chunks when the request streamed).
    response: str
    # Model identifier copied from the input that produced this output.
    used_model: str
    # Token accounting taken from the API reply; the skill passes an empty
    # dict in streaming mode.
    usage: Dict[str, int]
|
50
|
+
|
51
|
+
|
52
|
+
class FireworksCompletionSkill(
    Skill[FireworksCompletionInput, FireworksCompletionOutput]
):
    """Text-completion skill that posts to the Fireworks AI completions URL.

    ``process`` returns the whole completion in one output object, while
    ``process_stream`` yields the text piece by piece as it arrives.
    """

    input_schema = FireworksCompletionInput
    output_schema = FireworksCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and pre-build the HTTP headers.

        Args:
            credentials: Fireworks credentials. When omitted (or falsy) they
                are obtained from ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
        """Translate the input schema into the JSON body of the request.

        ``stop`` is only included when it is set, so an empty value never
        reaches the API.
        """
        payload = {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "max_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "repetition_penalty": input_data.repetition_penalty,
            "echo": input_data.echo,
            "stream": input_data.stream,
        }

        if input_data.stop:
            payload["stop"] = input_data.stop

        return payload

    def process_stream(
        self, input_data: FireworksCompletionInput
    ) -> Generator[str, None, None]:
        """Yield the completion text chunk by chunk.

        Args:
            input_data: Completion request parameters.

        Yields:
            Each non-empty ``text`` fragment found in the streamed events.

        Raises:
            ProcessingError: If the HTTP request fails or an event cannot be
                handled.
        """
        try:
            # Always request a streamed body here, whatever ``input_data.stream``
            # says: this method's contract is incremental output.
            payload = {**self._build_payload(input_data), "stream": True}

            # The context manager releases the connection even when the
            # consumer stops iterating before the stream is exhausted.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Lines that are not JSON after removing the
                        # "data: " prefix are skipped.
                        continue
                    choices = data.get("choices")
                    if choices and choices[0].get("text"):
                        yield choices[0]["text"]

        except Exception as e:
            raise ProcessingError(
                f"Fireworks completion streaming failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksCompletionInput
    ) -> FireworksCompletionOutput:
        """Run the completion and return the full response.

        Args:
            input_data: Completion request parameters. When
                ``input_data.stream`` is true the streamed chunks are
                collected and joined before returning.

        Returns:
            The generated text, the model name and the usage statistics
            (empty in streaming mode).

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # Usage stats not available in streaming mode
            else:
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["text"]
                usage = data["usage"]

            return FireworksCompletionOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already a domain error (raised by process_stream); do not wrap
            # it a second time.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks completion failed: {str(e)}") from e
|
@@ -0,0 +1,152 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import FireworksCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksRequestInput(InputSchema):
    """Schema for Fireworks AI chat input using requests"""

    # NOTE(review): the class docstring is kept verbatim on purpose. If
    # InputSchema is a pydantic model (presumably -- confirm), the docstring is
    # emitted as the description of the generated JSON schema.

    # --- Conversation -------------------------------------------------------
    user_input: str = Field(
        ...,
        description="User's input text",
    )
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Earlier messages; the skill places them between the system prompt and
    # the new user message.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages",
    )

    # --- Model and sampling controls ---------------------------------------
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )
    top_p: float = Field(
        default=1.0,
        description="Top p sampling parameter",
        ge=0,
        le=1,
    )
    top_k: int = Field(
        default=40,
        description="Top k sampling parameter",
        ge=0,
    )
    presence_penalty: float = Field(
        default=0.0,
        description="Presence penalty",
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        default=0.0,
        description="Frequency penalty",
        ge=-2.0,
        le=2.0,
    )

    # --- Transport ----------------------------------------------------------
    stream: bool = Field(
        default=False,
        description="Whether to stream the response",
    )
|
46
|
+
|
47
|
+
|
48
|
+
class FireworksRequestOutput(OutputSchema):
    """Schema for Fireworks AI chat output"""

    # Assistant reply (concatenation of all chunks when the request streamed).
    response: str
    # Model identifier copied from the input that produced this output.
    used_model: str
    # Token accounting taken from the API reply; the skill passes an empty
    # dict in streaming mode.
    usage: Dict[str, int]
|
54
|
+
|
55
|
+
|
56
|
+
class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
    """Chat skill that posts to the Fireworks AI chat-completions URL.

    ``process`` returns the whole reply in one output object, while
    ``process_stream`` yields the reply text piece by piece as it arrives.
    """

    input_schema = FireworksRequestInput
    output_schema = FireworksRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and pre-build the HTTP headers.

        Args:
            credentials: Fireworks credentials. When omitted (or falsy) they
                are obtained from ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_messages(
        self, input_data: FireworksRequestInput
    ) -> List[Dict[str, str]]:
        """Return the chat messages: system prompt, history, then user input."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(self, input_data: FireworksRequestInput) -> Dict[str, Any]:
        """Translate the input schema into the JSON body of the request."""
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "stream": input_data.stream,
        }

    def process_stream(
        self, input_data: FireworksRequestInput
    ) -> Generator[str, None, None]:
        """Yield the assistant reply chunk by chunk.

        Args:
            input_data: Chat request parameters.

        Yields:
            Each non-empty ``delta.content`` fragment found in the streamed
            events.

        Raises:
            ProcessingError: If the HTTP request fails or an event cannot be
                handled.
        """
        try:
            # Always request a streamed body: the parsing below reads
            # ``delta`` objects, which a non-streaming reply (``message``)
            # does not contain.
            payload = {**self._build_payload(input_data), "stream": True}

            # The context manager releases the connection even when the
            # consumer stops iterating before the stream is exhausted.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Lines that are not JSON after removing the
                        # "data: " prefix are skipped.
                        continue
                    choices = data.get("choices")
                    if not choices:
                        # An event without choices carries no text; skipping
                        # it avoids an IndexError that would end the stream.
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    if content:
                        yield content

        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
        """Run the chat request and return the full reply.

        Args:
            input_data: Chat request parameters. When ``input_data.stream``
                is true the streamed chunks are collected and joined before
                returning.

        Returns:
            The reply text, the model name and the usage statistics (empty
            in streaming mode).

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # Usage stats not available in streaming mode
            else:
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["message"]["content"]
                usage = data["usage"]

            return FireworksRequestOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already a domain error (raised by process_stream); do not wrap
            # it a second time.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e
|
@@ -1,7 +1,9 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
2
|
from pydantic import Field
|
3
3
|
import requests
|
4
4
|
from loguru import logger
|
5
|
+
from openai import OpenAI
|
6
|
+
from openai.types.chat import ChatCompletionChunk
|
5
7
|
|
6
8
|
from airtrain.core.skills import Skill, ProcessingError
|
7
9
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -34,10 +36,14 @@ class FireworksInput(InputSchema):
|
|
34
36
|
context_length_exceeded_behavior: str = Field(
|
35
37
|
default="truncate", description="Behavior when context length is exceeded"
|
36
38
|
)
|
39
|
+
stream: bool = Field(
|
40
|
+
default=False,
|
41
|
+
description="Whether to stream the response token by token",
|
42
|
+
)
|
37
43
|
|
38
44
|
|
39
45
|
class FireworksOutput(OutputSchema):
|
40
|
-
"""Schema for Fireworks AI output"""
|
46
|
+
"""Schema for Fireworks AI chat output"""
|
41
47
|
|
42
48
|
response: str = Field(..., description="Model's response text")
|
43
49
|
used_model: str = Field(..., description="Model used for generation")
|
@@ -54,76 +60,71 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
54
60
|
"""Initialize the skill with optional credentials"""
|
55
61
|
super().__init__()
|
56
62
|
self.credentials = credentials or FireworksCredentials.from_env()
|
57
|
-
self.
|
63
|
+
self.client = OpenAI(
|
64
|
+
base_url="https://api.fireworks.ai/inference/v1",
|
65
|
+
api_key=self.credentials.fireworks_api_key.get_secret_value(),
|
66
|
+
)
|
58
67
|
|
59
68
|
def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, str]]:
|
60
|
-
"""
|
61
|
-
Build messages list from input data including conversation history.
|
62
|
-
|
63
|
-
Args:
|
64
|
-
input_data: The input data containing system prompt, conversation history, and user input
|
65
|
-
|
66
|
-
Returns:
|
67
|
-
List[Dict[str, str]]: List of messages in the format required by Fireworks AI
|
68
|
-
"""
|
69
|
+
"""Build messages list from input data including conversation history."""
|
69
70
|
messages = [{"role": "system", "content": input_data.system_prompt}]
|
70
71
|
|
71
|
-
# Add conversation history if present
|
72
72
|
if input_data.conversation_history:
|
73
73
|
messages.extend(input_data.conversation_history)
|
74
74
|
|
75
|
-
# Add current user input
|
76
75
|
messages.append({"role": "user", "content": input_data.user_input})
|
77
|
-
|
78
76
|
return messages
|
79
77
|
|
80
|
-
def
|
81
|
-
"""Process the input
|
78
|
+
def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
    """Process the input and stream the response token by token.

    Args:
        input_data: Chat request parameters (model, temperature, max_tokens
            and the fields used by ``_build_messages``).

    Yields:
        Each ``delta.content`` value that is not ``None``.

    Raises:
        ProcessingError: If the API call or stream iteration fails.
    """
    try:
        messages = self._build_messages(input_data)

        stream = self.client.chat.completions.create(
            model=input_data.model,
            messages=messages,
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
        )

        for chunk in stream:
            if not chunk.choices:
                # A chunk without choices carries no text; indexing it would
                # raise IndexError and abort the whole stream.
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content

    except Exception as e:
        raise ProcessingError(f"Fireworks streaming failed: {str(e)}") from e
|
114
97
|
|
115
|
-
|
98
|
+
def process(self, input_data: FireworksInput) -> FireworksOutput:
    """Process the input and return the complete response.

    Args:
        input_data: Chat request parameters. When ``input_data.stream`` is
            true the streamed chunks are collected and joined.

    Returns:
        The reply text, the model name and the token usage. Usage is an
        empty dict in streaming mode because no completion object (and so no
        usage figures) exists on that path.

    Raises:
        ProcessingError: If the request or response handling fails.
    """
    try:
        if input_data.stream:
            response = "".join(self.process_stream(input_data))
            # There is no completion object on this path; reading
            # ``completion.usage`` here would fail on an unbound name.
            # NOTE(review): assumes FireworksOutput.usage accepts an empty
            # dict -- confirm against the schema.
            usage = {}
        else:
            messages = self._build_messages(input_data)
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=False,
            )
            response = completion.choices[0].message.content
            usage = {
                "total_tokens": completion.usage.total_tokens,
                "prompt_tokens": completion.usage.prompt_tokens,
                "completion_tokens": completion.usage.completion_tokens,
            }

        return FireworksOutput(
            response=response,
            used_model=input_data.model,
            usage=usage,
        )

    except ProcessingError:
        # Already a domain error (raised by process_stream); do not wrap it
        # a second time.
        raise
    except Exception as e:
        raise ProcessingError(f"Fireworks chat failed: {str(e)}") from e
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any, Generator
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import FireworksCredentials
|
10
|
+
|
11
|
+
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
12
|
+
|
13
|
+
|
14
|
+
class FireworksStructuredCompletionInput(InputSchema):
    """Schema for Fireworks AI structured completion input"""

    # NOTE(review): the class docstring is kept verbatim on purpose. If
    # InputSchema is a pydantic model (presumably -- confirm), the docstring is
    # emitted as the description of the generated JSON schema.

    prompt: str = Field(
        ...,
        description="Input prompt for completion",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )
    # The model class the reply is validated against; the skill also derives
    # the JSON schema sent to the API from it.
    response_model: Type[ResponseT]
    stream: bool = Field(
        default=False,
        description="Whether to stream the response",
    )

    class Config:
        arbitrary_types_allowed = True
|
34
|
+
|
35
|
+
|
36
|
+
class FireworksStructuredCompletionOutput(OutputSchema):
    """Schema for Fireworks AI structured completion output"""

    # Instance produced by validating the reply against the requested
    # response model.
    parsed_response: Any
    # Model identifier copied from the input that produced this output.
    used_model: str
    # Token accounting taken from the API reply; the skill passes an empty
    # dict in streaming mode.
    usage: Dict[str, int]
|
42
|
+
|
43
|
+
|
44
|
+
class FireworksStructuredCompletionSkill(
    Skill[FireworksStructuredCompletionInput, FireworksStructuredCompletionOutput]
):
    """Completion skill whose reply is validated against a response model.

    The JSON schema of ``input_data.response_model`` is sent to the Fireworks
    AI completions URL as the response format, and the returned text is
    validated with ``model_validate_json``.
    """

    input_schema = FireworksStructuredCompletionInput
    output_schema = FireworksStructuredCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and pre-build the HTTP headers.

        Args:
            credentials: Fireworks credentials. When omitted (or falsy) they
                are obtained from ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Dict[str, Any]:
        """Translate the input schema into the JSON body of the request.

        The response model's JSON schema is embedded with its ``required``
        list overridden so that every declared field is listed.
        """
        return {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": input_data.stream,
            "response_format": {
                "type": "json_object",
                "schema": {
                    **input_data.response_model.model_json_schema(),
                    "required": list(input_data.response_model.model_fields),
                },
            },
        }

    def process_stream(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream the raw text, then yield the validated object.

        Args:
            input_data: Structured completion request parameters.

        Yields:
            ``{"chunk": text}`` for every non-empty text fragment, followed
            by one final ``{"complete": parsed_model}``.

        Raises:
            ProcessingError: If the HTTP request fails or the collected text
                cannot be validated against the response model.
        """
        try:
            # Always request a streamed body here, whatever
            # ``input_data.stream`` says.
            payload = {**self._build_payload(input_data), "stream": True}
            json_buffer = []

            # The context manager releases the connection even when the
            # consumer stops iterating before the stream is exhausted.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Lines that are not JSON after removing the
                        # "data: " prefix are skipped.
                        continue
                    choices = data.get("choices")
                    if choices and choices[0].get("text"):
                        content = choices[0]["text"]
                        json_buffer.append(content)
                        yield {"chunk": content}

            # Once complete, parse the full JSON
            complete_json = "".join(json_buffer)
            try:
                parsed_response = input_data.response_model.model_validate_json(
                    complete_json
                )
            except Exception as e:
                raise ProcessingError(
                    f"Failed to parse JSON response: {str(e)}"
                ) from e
            yield {"complete": parsed_response}

        except ProcessingError:
            # Keep the specific parse error instead of re-wrapping it as a
            # generic streaming failure.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksStructuredCompletionInput
    ) -> FireworksStructuredCompletionOutput:
        """Run the request and return the validated response.

        Args:
            input_data: Structured completion request parameters. When
                ``input_data.stream`` is true the stream is consumed and its
                final validated object is returned.

        Returns:
            The validated object, the model name and the usage statistics
            (empty in streaming mode).

        Raises:
            ProcessingError: If the request, response handling or validation
                fails.
        """
        try:
            if input_data.stream:
                parsed_response = None

                # Only the final "complete" event matters here; the text
                # chunks are already accumulated inside process_stream.
                for chunk in self.process_stream(input_data):
                    if "complete" in chunk:
                        parsed_response = chunk["complete"]

                if parsed_response is None:
                    raise ProcessingError("Failed to parse streamed response")

                return FireworksStructuredCompletionOutput(
                    parsed_response=parsed_response,
                    used_model=input_data.model,
                    usage={},  # Usage stats not available in streaming mode
                )

            payload = self._build_payload(input_data)
            response = requests.post(
                self.BASE_URL, headers=self.headers, data=json.dumps(payload)
            )
            response.raise_for_status()
            data = response.json()

            response_text = data["choices"][0]["text"]
            parsed_response = input_data.response_model.model_validate_json(
                response_text
            )

            return FireworksStructuredCompletionOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage=data["usage"],
            )

        except ProcessingError:
            # Already a domain error; do not wrap it a second time.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks structured completion failed: {str(e)}"
            ) from e
|
@@ -0,0 +1,209 @@
|
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any, Generator
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
import re
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import FireworksCredentials
|
11
|
+
|
12
|
+
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
13
|
+
|
14
|
+
|
15
|
+
class FireworksStructuredRequestInput(InputSchema):
    """Schema for Fireworks AI structured output input using requests"""

    # NOTE(review): the class docstring is kept verbatim on purpose. If
    # InputSchema is a pydantic model (presumably -- confirm), the docstring is
    # emitted as the description of the generated JSON schema.

    # --- Conversation -------------------------------------------------------
    user_input: str = Field(
        ...,
        description="User's input text",
    )
    system_prompt: str = Field(
        default="You are a helpful assistant that provides structured data.",
        description="System prompt to guide the model's behavior",
    )
    # Earlier messages; the skill places them between the system prompt and
    # the new user message.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages",
    )

    # --- Model and sampling controls ---------------------------------------
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )

    # --- Output shaping -----------------------------------------------------
    # The model class the reply is validated against; the skill also derives
    # the JSON schema sent to the API from it.
    response_model: Type[ResponseT]
    stream: bool = Field(
        default=False,
        description="Whether to stream the response",
    )

    class Config:
        arbitrary_types_allowed = True
|
43
|
+
|
44
|
+
|
45
|
+
class FireworksStructuredRequestOutput(OutputSchema):
    """Schema for Fireworks AI structured output"""

    # Instance produced by validating the reply against the requested
    # response model.
    parsed_response: Any
    # Model identifier copied from the input that produced this output.
    used_model: str
    # Token accounting taken from the API reply; the skill passes an empty
    # dict in streaming mode.
    usage: Dict[str, int]
    # Text found inside a <think>...</think> section of the reply, if any.
    reasoning: Optional[str] = None
|
52
|
+
|
53
|
+
|
54
|
+
class FireworksStructuredRequestSkill(
|
55
|
+
Skill[FireworksStructuredRequestInput, FireworksStructuredRequestOutput]
|
56
|
+
):
|
57
|
+
"""Skill for getting structured responses from Fireworks AI using requests"""
|
58
|
+
|
59
|
+
input_schema = FireworksStructuredRequestInput
|
60
|
+
output_schema = FireworksStructuredRequestOutput
|
61
|
+
BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
|
62
|
+
|
63
|
+
def __init__(self, credentials: Optional[FireworksCredentials] = None):
|
64
|
+
"""Initialize the skill with optional credentials"""
|
65
|
+
super().__init__()
|
66
|
+
self.credentials = credentials or FireworksCredentials.from_env()
|
67
|
+
self.headers = {
|
68
|
+
"Accept": "application/json",
|
69
|
+
"Content-Type": "application/json",
|
70
|
+
"Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
|
71
|
+
}
|
72
|
+
|
73
|
+
def _build_messages(
|
74
|
+
self, input_data: FireworksStructuredRequestInput
|
75
|
+
) -> List[Dict[str, str]]:
|
76
|
+
"""Build messages list from input data including conversation history."""
|
77
|
+
messages = [{"role": "system", "content": input_data.system_prompt}]
|
78
|
+
|
79
|
+
if input_data.conversation_history:
|
80
|
+
messages.extend(input_data.conversation_history)
|
81
|
+
|
82
|
+
messages.append({"role": "user", "content": input_data.user_input})
|
83
|
+
return messages
|
84
|
+
|
85
|
+
def _build_payload(
|
86
|
+
self, input_data: FireworksStructuredRequestInput
|
87
|
+
) -> Dict[str, Any]:
|
88
|
+
"""Build the request payload."""
|
89
|
+
return {
|
90
|
+
"model": input_data.model,
|
91
|
+
"messages": self._build_messages(input_data),
|
92
|
+
"temperature": input_data.temperature,
|
93
|
+
"max_tokens": input_data.max_tokens,
|
94
|
+
"stream": input_data.stream,
|
95
|
+
"response_format": {
|
96
|
+
"type": "json_object",
|
97
|
+
"schema": {
|
98
|
+
**input_data.response_model.model_json_schema(),
|
99
|
+
"required": [
|
100
|
+
field
|
101
|
+
for field, _ in input_data.response_model.model_fields.items()
|
102
|
+
],
|
103
|
+
},
|
104
|
+
},
|
105
|
+
}
|
106
|
+
|
107
|
+
def process_stream(
|
108
|
+
self, input_data: FireworksStructuredRequestInput
|
109
|
+
) -> Generator[Dict[str, Any], None, None]:
|
110
|
+
"""Process the input and stream the response."""
|
111
|
+
try:
|
112
|
+
payload = self._build_payload(input_data)
|
113
|
+
response = requests.post(
|
114
|
+
self.BASE_URL,
|
115
|
+
headers=self.headers,
|
116
|
+
data=json.dumps(payload),
|
117
|
+
stream=True,
|
118
|
+
)
|
119
|
+
response.raise_for_status()
|
120
|
+
|
121
|
+
json_buffer = []
|
122
|
+
for line in response.iter_lines():
|
123
|
+
if line:
|
124
|
+
try:
|
125
|
+
data = json.loads(line.decode("utf-8").removeprefix("data: "))
|
126
|
+
if data["choices"][0]["delta"].get("content"):
|
127
|
+
content = data["choices"][0]["delta"]["content"]
|
128
|
+
json_buffer.append(content)
|
129
|
+
yield {"chunk": content}
|
130
|
+
except json.JSONDecodeError:
|
131
|
+
continue
|
132
|
+
|
133
|
+
# Once complete, parse the full JSON
|
134
|
+
complete_json = "".join(json_buffer)
|
135
|
+
try:
|
136
|
+
parsed_response = input_data.response_model.model_validate_json(
|
137
|
+
complete_json
|
138
|
+
)
|
139
|
+
yield {"complete": parsed_response}
|
140
|
+
except Exception as e:
|
141
|
+
raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
|
142
|
+
|
143
|
+
except Exception as e:
|
144
|
+
raise ProcessingError(f"Fireworks streaming request failed: {str(e)}")
|
145
|
+
|
146
|
+
def _parse_response_content(self, content: str) -> tuple[Optional[str], str]:
|
147
|
+
"""Parse response content to extract reasoning and JSON."""
|
148
|
+
# Extract reasoning if present
|
149
|
+
reasoning_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
150
|
+
reasoning = reasoning_match.group(1).strip() if reasoning_match else None
|
151
|
+
|
152
|
+
# Extract JSON
|
153
|
+
json_match = re.search(r"</think>\s*(\{.*\})", content, re.DOTALL)
|
154
|
+
json_str = json_match.group(1).strip() if json_match else content
|
155
|
+
|
156
|
+
return reasoning, json_str
|
157
|
+
|
158
|
+
def process(
    self, input_data: FireworksStructuredRequestInput
) -> FireworksStructuredRequestOutput:
    """Run a structured request and return the validated response.

    When ``input_data.stream`` is true the request is delegated to
    ``process_stream`` and the final ``"complete"`` event supplies the parsed
    model. Otherwise a single blocking request is made and the reply is split
    into reasoning and JSON before validation.

    Args:
        input_data: Request parameters, including the target
            ``response_model``.

    Returns:
        The parsed response together with the model name, usage statistics
        (empty in streaming mode) and, in non-streaming mode, any reasoning
        text found in the reply.

    Raises:
        ProcessingError: If the request, the parsing, or the validation fails.
    """
    try:
        if input_data.stream:
            # Only the final "complete" event matters here; the incremental
            # "chunk" events are consumed and discarded.
            parsed_response = None
            for event in self.process_stream(input_data):
                if "complete" in event:
                    parsed_response = event["complete"]

            if parsed_response is None:
                raise ProcessingError("Failed to parse streamed response")

            return FireworksStructuredRequestOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage={},  # Usage stats not available in streaming mode
            )

        # Non-streaming: one blocking request.
        payload = self._build_payload(input_data)
        response = requests.post(
            self.BASE_URL, headers=self.headers, data=json.dumps(payload)
        )
        response.raise_for_status()
        data = response.json()

        response_content = data["choices"][0]["message"]["content"]

        # Separate the optional <think> reasoning from the JSON payload.
        reasoning, json_str = self._parse_response_content(response_content)

        parsed_response = input_data.response_model.model_validate_json(json_str)

        return FireworksStructuredRequestOutput(
            parsed_response=parsed_response,
            used_model=input_data.model,
            usage=data["usage"],
            reasoning=reasoning,
        )

    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise ProcessingError(
            f"Fireworks structured request failed: {str(e)}"
        ) from e
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import Type, TypeVar, Optional, List, Dict, Any
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from openai import OpenAI
|
4
|
+
from airtrain.core.skills import Skill, ProcessingError
|
5
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
6
|
+
from .credentials import FireworksCredentials
|
7
|
+
import re
|
8
|
+
|
9
|
+
# Generic type variable for Pydantic response models
|
10
|
+
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
11
|
+
|
12
|
+
|
13
|
+
class FireworksParserInput(InputSchema):
    """Input for a Fireworks structured-output request."""

    # The current user message.
    user_input: str = Field(...)
    system_prompt: str = Field(
        default="You are a helpful assistant that provides structured data."
    )
    model: str = Field(default="accounts/fireworks/models/deepseek-r1")
    temperature: float = Field(default=0.7)
    max_tokens: Optional[int] = Field(default=None)
    # Pydantic model class the reply is validated against.
    response_model: Type[ResponseT] = Field(...)
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )

    class Config:
        arbitrary_types_allowed = True
|
29
|
+
|
30
|
+
|
31
|
+
class FireworksParserOutput(OutputSchema):
    """Result of a Fireworks structured-output request."""

    # Instance of the caller-supplied response model.
    parsed_response: BaseModel = Field(...)
    used_model: str = Field(...)
    tokens_used: int = Field(...)
    # Text of the model's <think> block, when one was present.
    reasoning: Optional[str] = Field(default=None)
|
38
|
+
|
39
|
+
|
40
|
+
class FireworksParserSkill(Skill[FireworksParserInput, FireworksParserOutput]):
    """Skill for getting structured responses from Fireworks.

    Sends a chat completion request with a JSON-schema response format through
    the OpenAI-compatible client, then validates the reply against the
    caller-supplied Pydantic model.
    """

    input_schema = FireworksParserInput
    output_schema = FireworksParserOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; loaded via
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=self.credentials.fireworks_api_key.get_secret_value(),
        )

    def process(self, input_data: FireworksParserInput) -> FireworksParserOutput:
        """Request a structured reply and parse it into ``response_model``.

        Args:
            input_data: Prompts, sampling settings, conversation history and
                the Pydantic model class describing the expected output.

        Returns:
            The validated model instance, the model name reported by the API,
            the total token count (0 when the API reports no usage) and any
            ``<think>`` reasoning text found in the reply.

        Raises:
            ProcessingError: If the API call fails, the reply is empty, or the
                reply does not validate against ``response_model``.
        """
        try:
            # System prompt first, then prior turns, then the new user turn.
            messages = [{"role": "system", "content": input_data.system_prompt}]
            if input_data.conversation_history:
                messages.extend(input_data.conversation_history)
            messages.append({"role": "user", "content": input_data.user_input})

            # Make API call with JSON schema
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                response_format={
                    "type": "json_object",
                    "schema": input_data.response_model.model_json_schema(),
                },
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            response_content = completion.choices[0].message.content
            if response_content is None:
                # Fail with a clear message instead of a TypeError from
                # re.search; the handler below wraps it in ProcessingError.
                raise ValueError("Fireworks returned no message content")

            # Extract reasoning if present
            reasoning_match = re.search(
                r"<think>(.*?)</think>", response_content, re.DOTALL
            )
            reasoning = reasoning_match.group(1).strip() if reasoning_match else None

            # Extract JSON
            json_match = re.search(r"</think>\s*(\{.*\})", response_content, re.DOTALL)
            json_str = json_match.group(1).strip() if json_match else response_content

            # model_validate_json is the Pydantic v2 replacement for the
            # deprecated parse_raw; this file already calls the v2
            # model_json_schema above.
            parsed_response = input_data.response_model.model_validate_json(json_str)

            usage = completion.usage
            return FireworksParserOutput(
                parsed_response=parsed_response,
                used_model=completion.model,
                tokens_used=usage.total_tokens if usage else 0,
                reasoning=reasoning,
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks parsing failed: {str(e)}") from e
|
@@ -1,9 +1,10 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any, TypeVar, Type
|
1
|
+
from typing import List, Optional, Dict, Any, TypeVar, Type, Generator
|
2
2
|
from pydantic import Field, BaseModel
|
3
3
|
from openai import OpenAI
|
4
4
|
import base64
|
5
5
|
from pathlib import Path
|
6
6
|
from loguru import logger
|
7
|
+
from openai.types.chat import ChatCompletionChunk
|
7
8
|
|
8
9
|
from airtrain.core.skills import Skill, ProcessingError
|
9
10
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -18,40 +19,36 @@ class OpenAIInput(InputSchema):
|
|
18
19
|
default="You are a helpful assistant.",
|
19
20
|
description="System prompt to guide the model's behavior",
|
20
21
|
)
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
22
|
+
conversation_history: List[Dict[str, str]] = Field(
|
23
|
+
default_factory=list,
|
24
|
+
description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
|
25
|
+
)
|
26
|
+
model: str = Field(
|
27
|
+
default="gpt-4o",
|
28
|
+
description="OpenAI model to use",
|
25
29
|
)
|
26
|
-
|
27
|
-
default=
|
28
|
-
description="Optional list of image paths to include in the message",
|
30
|
+
temperature: float = Field(
|
31
|
+
default=0.7, description="Temperature for response generation", ge=0, le=1
|
29
32
|
)
|
30
|
-
|
31
|
-
default=None,
|
32
|
-
description="Optional function definitions for function calling",
|
33
|
+
max_tokens: Optional[int] = Field(
|
34
|
+
default=None, description="Maximum tokens in response"
|
33
35
|
)
|
34
|
-
|
35
|
-
default=
|
36
|
-
description="
|
36
|
+
stream: bool = Field(
|
37
|
+
default=False,
|
38
|
+
description="Whether to stream the response token by token",
|
37
39
|
)
|
38
40
|
|
39
41
|
|
40
42
|
class OpenAIOutput(OutputSchema):
    """Result of an OpenAI chat request."""

    # Full text of the assistant reply.
    response: str = Field(...)
    used_model: str = Field(...)
    # Token counts keyed by name.
    usage: Dict[str, int] = Field(...)
|
51
48
|
|
52
49
|
|
53
50
|
class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
54
|
-
"""Skill for interacting with OpenAI
|
51
|
+
"""Skill for interacting with OpenAI models"""
|
55
52
|
|
56
53
|
input_schema = OpenAIInput
|
57
54
|
output_schema = OpenAIOutput
|
@@ -65,82 +62,69 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
65
62
|
organization=self.credentials.openai_organization_id,
|
66
63
|
)
|
67
64
|
|
68
|
-
def
|
69
|
-
"""
|
70
|
-
|
71
|
-
if not image_path.exists():
|
72
|
-
raise FileNotFoundError(f"Image file not found: {image_path}")
|
73
|
-
|
74
|
-
with open(image_path, "rb") as img_file:
|
75
|
-
encoded = base64.b64encode(img_file.read()).decode()
|
76
|
-
return {
|
77
|
-
"type": "image_url",
|
78
|
-
"image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
|
79
|
-
}
|
80
|
-
except Exception as e:
|
81
|
-
logger.error(f"Failed to encode image {image_path}: {str(e)}")
|
82
|
-
raise ProcessingError(f"Image encoding failed: {str(e)}")
|
65
|
+
def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
|
66
|
+
"""Build messages list from input data including conversation history."""
|
67
|
+
messages = [{"role": "system", "content": input_data.system_prompt}]
|
83
68
|
|
84
|
-
|
85
|
-
|
86
|
-
try:
|
87
|
-
logger.info(f"Processing request with model {input_data.model}")
|
69
|
+
if input_data.conversation_history:
|
70
|
+
messages.extend(input_data.conversation_history)
|
88
71
|
|
89
|
-
|
90
|
-
|
72
|
+
messages.append({"role": "user", "content": input_data.user_input})
|
73
|
+
return messages
|
91
74
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
if input_data.images:
|
97
|
-
logger.debug(f"Processing {len(input_data.images)} images")
|
98
|
-
for image_path in input_data.images:
|
99
|
-
content.append(self._encode_image(image_path))
|
100
|
-
|
101
|
-
# Prepare messages
|
102
|
-
messages = [
|
103
|
-
{"role": "system", "content": input_data.system_prompt},
|
104
|
-
{"role": "user", "content": content},
|
105
|
-
]
|
106
|
-
|
107
|
-
# Create completion parameters
|
108
|
-
params = {
|
109
|
-
"model": input_data.model,
|
110
|
-
"messages": messages,
|
111
|
-
"temperature": input_data.temperature,
|
112
|
-
"max_tokens": input_data.max_tokens,
|
113
|
-
}
|
75
|
+
def process_stream(self, input_data: OpenAIInput) -> Generator[str, None, None]:
    """Stream the model's reply piece by piece.

    Args:
        input_data: Prompt, history and sampling settings for the request.

    Yields:
        Each non-empty content fragment in the order it arrives.

    Raises:
        ProcessingError: If the request or the stream iteration fails.
    """
    try:
        stream = self.client.chat.completions.create(
            model=input_data.model,
            messages=self._build_messages(input_data),
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
        )

        for chunk in stream:
            # A chunk can arrive with an empty choices list; indexing [0]
            # on it would abort the whole stream with an IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content

    except Exception as e:
        raise ProcessingError(f"OpenAI streaming failed: {str(e)}") from e
|
127
94
|
|
128
|
-
|
95
|
+
def process(self, input_data: OpenAIInput) -> OpenAIOutput:
    """Run a chat request and return the complete reply.

    When ``input_data.stream`` is true the reply is assembled from
    ``process_stream``; otherwise a single blocking completion is requested.

    Args:
        input_data: Prompt, history and sampling settings for the request.

    Returns:
        The reply text, the requested model name and token usage. Usage is
        an empty dict in streaming mode, because no completion object with
        usage counts exists on that path.

    Raises:
        ProcessingError: If the request fails.
    """
    try:
        if input_data.stream:
            # Collect the streamed fragments into the full reply.
            response = "".join(self.process_stream(input_data))
            # No `completion` object exists on this path. Referencing it
            # here raised UnboundLocalError, which made every streaming call
            # fail. Report empty usage instead.
            usage = {}
        else:
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=False,
            )
            response = completion.choices[0].message.content
            usage = {
                "total_tokens": completion.usage.total_tokens,
                "prompt_tokens": completion.usage.prompt_tokens,
                "completion_tokens": completion.usage.completion_tokens,
            }

        return OpenAIOutput(
            response=response,
            used_model=input_data.model,
            usage=usage,
        )

    except Exception as e:
        raise ProcessingError(f"OpenAI chat failed: {str(e)}") from e
|
144
128
|
|
145
129
|
|
146
130
|
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: airtrain
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.28
|
4
4
|
Summary: A platform for building and deploying AI agents with structured skills
|
5
5
|
Home-page: https://github.com/rosaboyle/airtrain.dev
|
6
6
|
Author: Dheeraj Pai
|
@@ -171,7 +171,21 @@ This project is licensed under the MIT License - see the LICENSE file for detail
|
|
171
171
|
## Changelog
|
172
172
|
|
173
173
|
|
174
|
-
|
174
|
+
|
175
|
+
## 0.1.28
|
176
|
+
|
177
|
+
- Bug fix: reasoning to Fireworks structured output.
|
178
|
+
- Added reasoning to Fireworks structured output.
|
179
|
+
|
180
|
+
## 0.1.27
|
181
|
+
|
182
|
+
- Added structured completion skills for Fireworks AI
|
183
|
+
- Added Completion skills for Fireworks AI.
|
184
|
+
- Added Combination skill for Groq and Fireworks AI.
|
185
|
+
- Add completion streaming.
|
186
|
+
- Added structured output streaming for Fireworks AI.
|
187
|
+
|
188
|
+
## 0.1.23
|
175
189
|
|
176
190
|
- Added conversation support for Deepseek, Together AI, Fireworks AI, Gemini, Groq, Cerebras and Sambanova.
|
177
191
|
- Added Change Log
|
@@ -1,4 +1,4 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=ZUseQpkG-otzfnOZRjQZHpJYG6w2R-c-k872-UMaRqI,2099
|
2
2
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
3
3
|
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
4
4
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
@@ -18,10 +18,15 @@ airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURb
|
|
18
18
|
airtrain/integrations/cerebras/credentials.py,sha256=KDEH4r8FGT68L9p34MLZWK65wq_a703pqIF3ODaSbts,694
|
19
19
|
airtrain/integrations/cerebras/skills.py,sha256=Ksggq_s5wHWlf_xQIOO8MFoNTYV0cm9SHZ7GESOd2YE,3527
|
20
20
|
airtrain/integrations/fireworks/__init__.py,sha256=9pJvP0u1FJbNtB0oHa09mHVJLctELf_c27LOYyDk2ZI,271
|
21
|
+
airtrain/integrations/fireworks/completion_skills.py,sha256=G657xWd7izLOxXq7RmqdupBF4DHqXQgXuhQ-MW7mtqc,5613
|
21
22
|
airtrain/integrations/fireworks/conversation_manager.py,sha256=m6VEHijqpYEYawkKhuHtb8RQxw4kxGWFWdbSK6zGuro,3704
|
22
23
|
airtrain/integrations/fireworks/credentials.py,sha256=UpcwR9V5Hbk5sJbjFDJDbHMRqc90IQSqAvrtJCOvwEo,524
|
23
24
|
airtrain/integrations/fireworks/models.py,sha256=F-MddbLCLAsTjwRr1l6IpJxOegyY4pD7jN9ySPiypSo,593
|
24
|
-
airtrain/integrations/fireworks/
|
25
|
+
airtrain/integrations/fireworks/requests_skills.py,sha256=zPIR70l0KdSGmA5WyDEopFAdKHSPltAttJBzWyHu6Bk,5878
|
26
|
+
airtrain/integrations/fireworks/skills.py,sha256=OB4epD4CSTxExUCC1oMJ_8rHLOoftlxf0FUoIVrd4mA,5163
|
27
|
+
airtrain/integrations/fireworks/structured_completion_skills.py,sha256=IXG4gsZDSfuscrmKIHfnyHkBaCV7zlPInaWXb95iC5k,6428
|
28
|
+
airtrain/integrations/fireworks/structured_requests_skills.py,sha256=5FllptmuUewS1OX3Jf0vK6kQjMV4QJmisExbs0ElWPY,8150
|
29
|
+
airtrain/integrations/fireworks/structured_skills.py,sha256=BZaLqSOTC11QdZ4kDORS4JnwF_YXBAa-IiwQ5dJiHXw,3895
|
25
30
|
airtrain/integrations/google/__init__.py,sha256=ElwgcXfbg_gGMm6zbkMXCQPFKZUb-yTJk986o19A7Cs,214
|
26
31
|
airtrain/integrations/google/credentials.py,sha256=KSvWNqW8Mjr4MkysRvUqlrOSGdShNIe5u2OPO6vRrWY,2047
|
27
32
|
airtrain/integrations/google/skills.py,sha256=ytsoksCY4qbfRO9Brnxhc2694fAj0ytnHX20SXS_FOM,4547
|
@@ -35,7 +40,7 @@ airtrain/integrations/openai/__init__.py,sha256=K-NY2_T1T6SEOgkpbUA55cWvK2nr2NOJ
|
|
35
40
|
airtrain/integrations/openai/chinese_assistant.py,sha256=MMhv4NBOoEQ0O22ZZtP255rd5ajHC9l6FPWIjpqxBOA,1581
|
36
41
|
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
37
42
|
airtrain/integrations/openai/models_config.py,sha256=bzosqqpDy2AJxu2vGdk2H4voqEGlv7LORR6fpJLhNic,3962
|
38
|
-
airtrain/integrations/openai/skills.py,sha256=
|
43
|
+
airtrain/integrations/openai/skills.py,sha256=SEWpwfWPsDEmf7IcSzWZuEzb13YT4gzQIHO8_O1bmN4,6936
|
39
44
|
airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
|
40
45
|
airtrain/integrations/sambanova/credentials.py,sha256=JyN8sbMCoXuXAjim46aI3LTicBijoemS7Ao0rn4yBJU,824
|
41
46
|
airtrain/integrations/sambanova/skills.py,sha256=SDFY-ZzhOEIxQgTkQJzX9gN7UDqqnCBJdK7I2JydIoY,3625
|
@@ -52,7 +57,7 @@ airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9
|
|
52
57
|
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
53
58
|
airtrain/integrations/together/skills.py,sha256=mUoHc2r5TYQi5iGzwz2aDuUeROGq7teCtNrOlNApef4,6276
|
54
59
|
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
55
|
-
airtrain-0.1.
|
56
|
-
airtrain-0.1.
|
57
|
-
airtrain-0.1.
|
58
|
-
airtrain-0.1.
|
60
|
+
airtrain-0.1.28.dist-info/METADATA,sha256=svDrTc0qzBYsuYKGR0VKg7AZSIEMQpUht1M2csmWGH0,5243
|
61
|
+
airtrain-0.1.28.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
62
|
+
airtrain-0.1.28.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
63
|
+
airtrain-0.1.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|