airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,207 @@
1
+ from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
2
+ from pydantic import Field
3
+ import requests
4
+ import json
5
+ from loguru import logger
6
+ import aiohttp
7
+
8
+ from airtrain.core.skills import Skill, ProcessingError
9
+ from airtrain.core.schemas import InputSchema, OutputSchema
10
+ from .credentials import FireworksCredentials
11
+
12
+
13
class FireworksRequestInput(InputSchema):
    """Request parameters for a Fireworks AI chat completion sent over HTTP.

    Only ``user_input`` is required; every other field carries a default.
    """

    # The new user turn; appended after the system prompt and any history.
    user_input: str = Field(..., description="User's input text")

    # Sent as the first message of every request.
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )

    # Earlier turns, inserted between the system prompt and the user input.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages",
        default_factory=list,
    )

    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )

    # Sampling controls; the bounds are enforced by the schema.
    temperature: float = Field(
        description="Temperature for response generation", default=0.7, ge=0, le=1
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=131072)
    top_p: float = Field(
        description="Top p sampling parameter", default=1.0, ge=0, le=1
    )
    top_k: int = Field(description="Top k sampling parameter", default=40, ge=0)
    presence_penalty: float = Field(
        description="Presence penalty", default=0.0, ge=-2.0, le=2.0
    )
    frequency_penalty: float = Field(
        description="Frequency penalty", default=0.0, ge=-2.0, le=2.0
    )

    stream: bool = Field(
        description="Whether to stream the response",
        default=False,
    )
47
+
48
+
49
class FireworksRequestOutput(OutputSchema):
    """Result returned by the requests-based Fireworks AI chat skill."""

    # Text generated by the model.
    response: str
    # Model identifier the request was issued with.
    used_model: str
    # Token counters reported by the API.
    usage: Dict[str, int]
55
+
56
+
57
class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
    """Skill for interacting with Fireworks AI models using requests.

    Blocking calls go through ``requests``; the ``*_async`` variants use
    ``aiohttp``. Every failure is reported as ``ProcessingError`` with the
    original exception attached as ``__cause__``.
    """

    input_schema = FireworksRequestInput
    output_schema = FireworksRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        authorization = (
            f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
        )
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }
        self.stream_headers = {
            "Accept": "text/event-stream",
            "Content-Type": "application/json",
            "Authorization": authorization,
        }

    def _build_messages(
        self, input_data: FireworksRequestInput
    ) -> List[Dict[str, str]]:
        """Build messages list from input data including conversation history.

        Order is: system prompt, prior history (if any), then the user input.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(self, input_data: FireworksRequestInput) -> Dict[str, Any]:
        """Build the request payload from the input schema fields."""
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "stream": input_data.stream,
        }

    @staticmethod
    def _delta_from_sse_line(raw_line: bytes) -> Optional[str]:
        """Return the non-empty delta content carried by one stream line.

        Returns ``None`` for blank lines, for payloads that are not JSON
        (presumably including an OpenAI-style ``[DONE]`` terminator - confirm
        against the Fireworks docs), and for events without delta content.
        """
        text = raw_line.decode("utf-8").strip()
        if not text:
            return None
        text = text.removeprefix("data:").strip()
        try:
            event = json.loads(text)
        except json.JSONDecodeError:
            return None
        if not isinstance(event, dict):
            return None
        choices = event.get("choices")
        if not choices:
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content") or None

    @staticmethod
    def _int_usage(usage: Any) -> Dict[str, int]:
        """Keep only the integer counters of a usage mapping.

        The output schema declares ``Dict[str, int]``; a missing usage object
        or non-integer entries would otherwise fail output validation.
        """
        if not isinstance(usage, dict):
            return {}
        return {key: value for key, value in usage.items() if isinstance(value, int)}

    def process_stream(
        self, input_data: FireworksRequestInput
    ) -> Generator[str, None, None]:
        """Process the input and stream the response.

        Yields:
            Each non-empty content fragment as it arrives.

        Raises:
            ProcessingError: If the HTTP request or stream handling fails.
        """
        try:
            payload = self._build_payload(input_data)
            # This method always streams, regardless of input_data.stream
            # (which defaults to False); otherwise the body is not a stream.
            payload["stream"] = True
            response = requests.post(
                self.BASE_URL,
                headers=self.stream_headers,
                data=json.dumps(payload),
                stream=True,
            )
            try:
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line:
                        continue
                    content = self._delta_from_sse_line(line)
                    if content:
                        yield content
            finally:
                # Release the connection even if the consumer stops early.
                response.close()

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set the streamed fragments are joined
        and ``usage`` is empty.

        Raises:
            ProcessingError: If the request fails or the response is malformed.
        """
        try:
            if input_data.stream:
                # For streaming, collect the entire response
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # Usage stats not collected in streaming mode
            else:
                # For non-streaming, use regular request
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["message"]["content"]
                usage = self._int_usage(data.get("usage"))

            return FireworksRequestOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already carries a specific message (e.g. from process_stream).
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    async def process_async(
        self, input_data: FireworksRequestInput
    ) -> FireworksRequestOutput:
        """Async version of process method using aiohttp.

        Raises:
            ProcessingError: If the request fails or the response is malformed.
        """
        try:
            async with aiohttp.ClientSession() as session:
                payload = self._build_payload(input_data)
                async with session.post(
                    self.BASE_URL, headers=self.headers, json=payload
                ) as response:
                    response.raise_for_status()
                    data = await response.json()

                    return FireworksRequestOutput(
                        response=data["choices"][0]["message"]["content"],
                        used_model=input_data.model,
                        usage=self._int_usage(data.get("usage")),
                    )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(f"Async Fireworks request failed: {str(e)}") from e

    async def process_stream_async(
        self, input_data: FireworksRequestInput
    ) -> AsyncGenerator[str, None]:
        """Async version of stream processor using aiohttp.

        Yields:
            Each non-empty content fragment as it arrives.

        Raises:
            ProcessingError: If the HTTP request or stream handling fails.
        """
        try:
            async with aiohttp.ClientSession() as session:
                payload = self._build_payload(input_data)
                # Always request a streamed body here (see process_stream).
                payload["stream"] = True
                async with session.post(
                    self.BASE_URL, headers=self.stream_headers, json=payload
                ) as response:
                    response.raise_for_status()

                    async for line in response.content:
                        if not line.startswith(b"data: "):
                            continue
                        # Non-JSON events are skipped instead of aborting
                        # an otherwise successful stream.
                        content = self._delta_from_sse_line(line)
                        if content:
                            yield content

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Async Fireworks streaming failed: {str(e)}"
            ) from e
@@ -0,0 +1,181 @@
1
+ from typing import List, Optional, Dict, Any, Generator, Union
2
+ from pydantic import Field
3
+ from openai import OpenAI
4
+
5
+ from airtrain.core.skills import Skill, ProcessingError
6
+ from airtrain.core.schemas import InputSchema, OutputSchema
7
+ from .credentials import FireworksCredentials
8
+
9
+
10
class FireworksInput(InputSchema):
    """Request parameters for the OpenAI-client based Fireworks AI chat skill.

    Only ``user_input`` is required; every other field carries a default.
    """

    # The new user turn; appended after the system prompt and any history.
    user_input: str = Field(..., description="User's input text")

    # Sent as the first message of every request.
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )

    # Earlier turns, inserted between the system prompt and the user input.
    conversation_history: List[Dict[str, Any]] = Field(
        description="List of previous conversation messages",
        default_factory=list,
    )

    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    temperature: float = Field(
        description="Temperature for response generation", default=0.7, ge=0, le=1
    )
    max_tokens: Optional[int] = Field(
        description="Maximum tokens in response", default=131072
    )
    context_length_exceeded_behavior: str = Field(
        description="Behavior when context length is exceeded", default="truncate"
    )
    stream: bool = Field(
        description="Whether to stream the response token by token",
        default=False,
    )

    # Tool-calling options; left as None when the caller supplies nothing.
    tools: Optional[List[Dict[str, Any]]] = Field(
        description=(
            "A list of tools the model may use. "
            "Currently only functions supported."
        ),
        default=None,
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        description=(
            "Controls which tool is called by the model. "
            "'none', 'auto', or specific tool."
        ),
        default=None,
    )
53
+
54
+
55
class FireworksOutput(OutputSchema):
    """Result returned by the OpenAI-client based Fireworks AI chat skill."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # Defaults to an empty mapping when no counters are supplied.
    usage: Dict[str, int] = Field(
        description="Usage statistics", default_factory=dict
    )
    # None unless tool calls are supplied by the skill.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        description="Tool calls generated by the model", default=None
    )
64
+
65
+
66
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
    """Skill for interacting with Fireworks AI models.

    Uses the ``OpenAI`` client pointed at the Fireworks inference endpoint.
    Every failure is reported as ``ProcessingError`` with the original
    exception attached as ``__cause__``.
    """

    input_schema = FireworksInput
    output_schema = FireworksOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=self.credentials.fireworks_api_key.get_secret_value(),
        )

    def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
        """Build messages list from input data including conversation history.

        Order is: system prompt, prior history (if any), then the user input.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _open_stream(self, input_data: FireworksInput) -> Any:
        """Start a streamed chat completion and return the chunk iterator."""
        return self.client.chat.completions.create(
            model=input_data.model,
            messages=self._build_messages(input_data),
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
        )

    @staticmethod
    def _chunk_text(chunk: Any) -> Optional[str]:
        """Return the delta content of a streamed chunk, or None if absent.

        Chunks with an empty ``choices`` list are tolerated rather than
        raising ``IndexError``.
        """
        choices = getattr(chunk, "choices", None)
        if not choices:
            return None
        return choices[0].delta.content

    @staticmethod
    def _usage_to_dict(usage: Any) -> Dict[str, int]:
        """Convert a usage object into the three counters the output exposes.

        Returns an empty dict when no usage object is available.
        """
        if usage is None:
            return {}
        return {
            key: int(getattr(usage, key, 0) or 0)
            for key in ("total_tokens", "prompt_tokens", "completion_tokens")
        }

    @staticmethod
    def _serialize_tool_calls(message: Any) -> Optional[List[Dict[str, Any]]]:
        """Convert the message's tool calls to plain dicts; None if there are none."""
        calls = getattr(message, "tool_calls", None)
        if not calls:
            return None
        return [
            {
                "id": call.id,
                "type": call.type,
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            }
            for call in calls
        ]

    def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        Yields:
            Each content fragment whose value is not ``None``.

        Raises:
            ProcessingError: If the streamed completion fails.
        """
        try:
            for chunk in self._open_stream(input_data):
                text = self._chunk_text(chunk)
                if text is not None:
                    yield text

        except Exception as e:
            raise ProcessingError(f"Fireworks streaming failed: {str(e)}") from e

    def process(self, input_data: FireworksInput) -> FireworksOutput:
        """Process the input and return the complete response.

        In streaming mode the fragments of a single streamed completion are
        joined. No second, non-streaming request is issued just to obtain
        usage numbers: that would bill the prompt twice and report counters
        for a different generation.

        Raises:
            ProcessingError: If the completion request fails.
        """
        try:
            # Defined up front so both branches return a well-formed output.
            tool_calls: Optional[List[Dict[str, Any]]] = None
            usage: Dict[str, int] = {}

            if input_data.stream:
                pieces: List[str] = []
                for chunk in self._open_stream(input_data):
                    text = self._chunk_text(chunk)
                    if text is not None:
                        pieces.append(text)
                    # NOTE(review): usage is recorded only if the server
                    # attaches it to a chunk - confirm Fireworks behavior.
                    chunk_usage = getattr(chunk, "usage", None)
                    if chunk_usage is not None:
                        usage = self._usage_to_dict(chunk_usage)
                response = "".join(pieces)
            else:
                # Prepare API call parameters
                api_params = {
                    "model": input_data.model,
                    "messages": self._build_messages(input_data),
                    "temperature": input_data.temperature,
                    "max_tokens": input_data.max_tokens,
                    "stream": False,
                }

                # Add tools and tool_choice if provided
                if input_data.tools:
                    api_params["tools"] = input_data.tools

                if input_data.tool_choice:
                    api_params["tool_choice"] = input_data.tool_choice

                completion = self.client.chat.completions.create(**api_params)
                message = completion.choices[0].message
                response = message.content or ""
                tool_calls = self._serialize_tool_calls(message)
                usage = self._usage_to_dict(completion.usage)

            return FireworksOutput(
                response=response,
                used_model=input_data.model,
                usage=usage,
                tool_calls=tool_calls,
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks chat failed: {str(e)}") from e
@@ -0,0 +1,175 @@
1
+ from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
2
+ from pydantic import BaseModel, Field
3
+ import requests
4
+ import json
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from .credentials import FireworksCredentials
9
+
10
# Bound to BaseModel because the skill calls model_json_schema(),
# model_fields and model_validate_json() on the supplied class.
ResponseT = TypeVar("ResponseT", bound=BaseModel)


class FireworksStructuredCompletionInput(InputSchema):
    """Schema for Fireworks AI structured completion input.

    ``prompt`` and ``response_model`` are required; ``response_model`` is the
    pydantic model class the completion text is validated against.
    """

    prompt: str = Field(..., description="Input prompt for completion")
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    max_tokens: int = Field(default=131072, description="Maximum tokens in response")
    # A class (not an instance) describing the expected JSON structure.
    response_model: Type[ResponseT]
    stream: bool = Field(
        default=False,
        description="Whether to stream the response token by token",
    )

    class Config:
        arbitrary_types_allowed = True
33
+
34
+
35
class FireworksStructuredCompletionOutput(OutputSchema):
    """Result returned by the Fireworks AI structured completion skill."""

    # Instance produced by validating the completion text against the
    # caller-supplied response model.
    parsed_response: Any
    # Model identifier the request was issued with.
    used_model: str
    # Token counters reported by the API.
    usage: Dict[str, int]
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        description=(
            "Tool calls are not applicable for completions, "
            "included for compatibility"
        ),
        default=None,
    )
48
+
49
+
50
class FireworksStructuredCompletionSkill(
    Skill[FireworksStructuredCompletionInput, FireworksStructuredCompletionOutput]
):
    """Skill for getting structured completion responses from Fireworks AI.

    The completion is requested with a JSON ``response_format`` derived from
    ``input_data.response_model`` and the returned text is validated against
    that model. Failures are reported as ``ProcessingError`` with the
    original exception attached as ``__cause__``.
    """

    input_schema = FireworksStructuredCompletionInput
    output_schema = FireworksStructuredCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Fireworks credentials; when omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Dict[str, Any]:
        """Build the request payload.

        The response model's JSON schema is sent as the response format, with
        ``required`` overridden to list every field of the model.
        """
        response_model = input_data.response_model
        return {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": input_data.stream,
            "response_format": {
                "type": "json_object",
                "schema": {
                    **response_model.model_json_schema(),
                    "required": list(response_model.model_fields),
                },
            },
        }

    @staticmethod
    def _int_usage(usage: Any) -> Dict[str, int]:
        """Keep only the integer counters of a usage mapping.

        The output schema declares ``Dict[str, int]``; a missing usage object
        or non-integer entries would otherwise fail output validation.
        """
        if not isinstance(usage, dict):
            return {}
        return {key: value for key, value in usage.items() if isinstance(value, int)}

    def process_stream(
        self, input_data: FireworksStructuredCompletionInput
    ) -> Generator[Dict[str, Any], None, None]:
        """Process the input and stream the response.

        Yields:
            ``{"chunk": text}`` for every non-empty text fragment, then a
            final ``{"complete": parsed}`` with the validated model instance.

        Raises:
            ProcessingError: If the request fails or the assembled text does
                not validate against the response model.
        """
        try:
            payload = self._build_payload(input_data)
            # This method always streams, regardless of input_data.stream
            # (which defaults to False).
            payload["stream"] = True
            response = requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            )

            fragments: List[str] = []
            try:
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Lines that are not JSON are ignored.
                        continue
                    choices = data.get("choices") if isinstance(data, dict) else None
                    if choices and choices[0].get("text"):
                        content = choices[0]["text"]
                        fragments.append(content)
                        yield {"chunk": content}
            finally:
                # Release the connection even if the consumer stops early.
                response.close()

            # Once complete, parse the full JSON
            try:
                parsed_response = input_data.response_model.model_validate_json(
                    "".join(fragments)
                )
            except Exception as e:
                raise ProcessingError(
                    f"Failed to parse JSON response: {str(e)}"
                ) from e
            yield {"complete": parsed_response}

        except ProcessingError:
            # Keep the specific message instead of wrapping it again.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks streaming request failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksStructuredCompletionInput
    ) -> FireworksStructuredCompletionOutput:
        """Process the input and return structured response.

        When ``input_data.stream`` is set, the stream is consumed to its
        final ``complete`` event and ``usage`` is empty.

        Raises:
            ProcessingError: If the request fails or the response cannot be
                validated against the response model.
        """
        try:
            if input_data.stream:
                # Only the final "complete" event matters here; the text
                # fragments are already assembled inside process_stream.
                parsed_response = None
                for event in self.process_stream(input_data):
                    if "complete" in event:
                        parsed_response = event["complete"]

                if parsed_response is None:
                    raise ProcessingError("Failed to parse streamed response")

                return FireworksStructuredCompletionOutput(
                    parsed_response=parsed_response,
                    used_model=input_data.model,
                    usage={},  # Usage stats not collected in streaming mode
                )

            # For non-streaming, use regular request
            payload = self._build_payload(input_data)
            response = requests.post(
                self.BASE_URL, headers=self.headers, data=json.dumps(payload)
            )
            response.raise_for_status()
            data = response.json()

            response_text = data["choices"][0]["text"]
            parsed_response = input_data.response_model.model_validate_json(
                response_text
            )

            return FireworksStructuredCompletionOutput(
                parsed_response=parsed_response,
                used_model=input_data.model,
                usage=self._int_usage(data.get("usage")),
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Fireworks structured completion failed: {str(e)}"
            ) from e