airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/fireworks/structured_completion_skills.py
@@ -0,0 +1,175 @@
+ from typing import Any, Dict, Generator, List, Optional, Type, TypeVar
+ from pydantic import BaseModel, Field
+ import requests
+ import json
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+
+ ResponseT = TypeVar("ResponseT")
+
+
+ class FireworksStructuredCompletionInput(InputSchema):
+     """Schema for Fireworks AI structured completion input"""
+
+     prompt: str = Field(..., description="Input prompt for completion")
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     response_model: Type[ResponseT]
+     stream: bool = Field(
+         default=False,
+         description="Whether to stream the response token by token",
+     )
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class FireworksStructuredCompletionOutput(OutputSchema):
+     """Schema for Fireworks AI structured completion output"""
+
+     parsed_response: Any
+     used_model: str
+     usage: Dict[str, int]
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Tool calls are not applicable for completions, "
+             "included for compatibility"
+         ),
+     )
+
+
+ class FireworksStructuredCompletionSkill(
+     Skill[FireworksStructuredCompletionInput, FireworksStructuredCompletionOutput]
+ ):
+     """Skill for getting structured completion responses from Fireworks AI"""
+
+     input_schema = FireworksStructuredCompletionInput
+     output_schema = FireworksStructuredCompletionOutput
+     BASE_URL = "https://api.fireworks.ai/inference/v1/completions"
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.headers = {
+             "Accept": "application/json",
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+         }
+
+     def _build_payload(
+         self, input_data: FireworksStructuredCompletionInput
+     ) -> Dict[str, Any]:
+         """Build the request payload."""
+         return {
+             "model": input_data.model,
+             "prompt": input_data.prompt,
+             "temperature": input_data.temperature,
+             "max_tokens": input_data.max_tokens,
+             "stream": input_data.stream,
+             "response_format": {
+                 "type": "json_object",
+                 "schema": {
+                     **input_data.response_model.model_json_schema(),
+                     "required": [
+                         field
+                         for field, _ in input_data.response_model.model_fields.items()
+                     ],
+                 },
+             },
+         }
+
+     def process_stream(
+         self, input_data: FireworksStructuredCompletionInput
+     ) -> Generator[Dict[str, Any], None, None]:
+         """Process the input and stream the response."""
+         try:
+             payload = self._build_payload(input_data)
+             response = requests.post(
+                 self.BASE_URL,
+                 headers=self.headers,
+                 data=json.dumps(payload),
+                 stream=True,
+             )
+             response.raise_for_status()
+
+             json_buffer = []
+             for line in response.iter_lines():
+                 if line:
+                     try:
+                         data = json.loads(line.decode("utf-8").removeprefix("data: "))
+                         if data.get("choices") and data["choices"][0].get("text"):
+                             content = data["choices"][0]["text"]
+                             json_buffer.append(content)
+                             yield {"chunk": content}
+                     except json.JSONDecodeError:
+                         continue
+
+             # Once complete, parse the full JSON
+             complete_json = "".join(json_buffer)
+             try:
+                 parsed_response = input_data.response_model.model_validate_json(
+                     complete_json
+                 )
+                 yield {"complete": parsed_response}
+             except Exception as e:
+                 raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks streaming request failed: {str(e)}")
+
+     def process(
+         self, input_data: FireworksStructuredCompletionInput
+     ) -> FireworksStructuredCompletionOutput:
+         """Process the input and return structured response."""
+         try:
+             if input_data.stream:
+                 # For streaming, collect and parse the entire response
+                 json_buffer = []
+                 parsed_response = None
+
+                 for chunk in self.process_stream(input_data):
+                     if "chunk" in chunk:
+                         json_buffer.append(chunk["chunk"])
+                     elif "complete" in chunk:
+                         parsed_response = chunk["complete"]
+
+                 if parsed_response is None:
+                     raise ProcessingError("Failed to parse streamed response")
+
+                 return FireworksStructuredCompletionOutput(
+                     parsed_response=parsed_response,
+                     used_model=input_data.model,
+                     usage={},  # Usage stats not available in streaming mode
+                 )
+             else:
+                 # For non-streaming, use regular request
+                 payload = self._build_payload(input_data)
+                 response = requests.post(
+                     self.BASE_URL, headers=self.headers, data=json.dumps(payload)
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 response_text = data["choices"][0]["text"]
+                 parsed_response = input_data.response_model.model_validate_json(
+                     response_text
+                 )
+
+                 return FireworksStructuredCompletionOutput(
+                     parsed_response=parsed_response,
+                     used_model=input_data.model,
+                     usage=data["usage"],
+                 )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks structured completion failed: {str(e)}")
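Editor's note: a minimal usage sketch of the new structured completion skill, assuming airtrain 0.1.4 is installed and that FireworksCredentials.from_env() picks up a FIREWORKS_API_KEY environment variable; the MovieReview model below is a hypothetical example, not part of the package.

from pydantic import BaseModel
from airtrain.integrations.fireworks.structured_completion_skills import (
    FireworksStructuredCompletionInput,
    FireworksStructuredCompletionSkill,
)


# Hypothetical response model, used only for illustration
class MovieReview(BaseModel):
    title: str
    rating: int
    summary: str


skill = FireworksStructuredCompletionSkill()  # credentials loaded from the environment
result = skill.process(
    FireworksStructuredCompletionInput(
        prompt="Review the film 'Arrival' as JSON with title, rating (1-10), and summary.",
        response_model=MovieReview,
        max_tokens=1024,
    )
)
print(result.parsed_response)  # validated MovieReview instance
print(result.usage)            # token usage reported by the API (empty when streaming)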
airtrain/integrations/fireworks/structured_requests_skills.py
@@ -0,0 +1,291 @@
+ from typing import Type, TypeVar, Optional, List, Dict, Any, Generator, Union
+ from pydantic import BaseModel, Field, create_model
+ import requests
+ import json
+ from loguru import logger
+ import re
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+
+ ResponseT = TypeVar("ResponseT", bound=BaseModel)
+
+
+ class FireworksStructuredRequestInput(InputSchema):
+     """Schema for Fireworks AI structured output input using requests"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant that provides structured data.",
+         description="System prompt to guide the model's behavior",
+     )
+     conversation_history: List[Dict[str, Any]] = Field(
+         default_factory=list,
+         description="List of previous conversation messages",
+     )
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     response_model: Type[ResponseT]
+     stream: bool = Field(
+         default=False, description="Whether to stream the response token by token"
+     )
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "A list of tools the model may use. "
+             "Currently only functions supported."
+         ),
+     )
+     tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Controls which tool is called by the model. "
+             "'none', 'auto', or specific tool."
+         ),
+     )
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class FireworksStructuredRequestOutput(OutputSchema):
+     """Schema for Fireworks AI structured output"""
+
+     parsed_response: Any
+     used_model: str
+     usage: Dict[str, int]
+     reasoning: Optional[str] = None
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None, description="Tool calls generated by the model"
+     )
+
+
+ class FireworksStructuredRequestSkill(
+     Skill[FireworksStructuredRequestInput, FireworksStructuredRequestOutput]
+ ):
+     """Skill for getting structured responses from Fireworks AI using requests"""
+
+     input_schema = FireworksStructuredRequestInput
+     output_schema = FireworksStructuredRequestOutput
+     BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.headers = {
+             "Accept": "application/json",
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+         }
+
+     def _build_messages(
+         self, input_data: FireworksStructuredRequestInput
+     ) -> List[Dict[str, Any]]:
+         """Build messages list from input data including conversation history."""
+         messages = [{"role": "system", "content": input_data.system_prompt}]
+
+         if input_data.conversation_history:
+             messages.extend(input_data.conversation_history)
+
+         messages.append({"role": "user", "content": input_data.user_input})
+         return messages
+
+     def _build_payload(
+         self, input_data: FireworksStructuredRequestInput
+     ) -> Dict[str, Any]:
+         """Build the request payload."""
+         payload = {
+             "model": input_data.model,
+             "messages": self._build_messages(input_data),
+             "temperature": input_data.temperature,
+             "max_tokens": input_data.max_tokens,
+             "stream": input_data.stream,
+             "response_format": {"type": "json_object"},
+         }
+
+         # Add tool-related parameters if provided
+         if input_data.tools:
+             payload["tools"] = input_data.tools
+
+         if input_data.tool_choice:
+             payload["tool_choice"] = input_data.tool_choice
+
+         return payload
+
+     def process_stream(
+         self, input_data: FireworksStructuredRequestInput
+     ) -> Generator[Dict[str, Any], None, None]:
+         """Process the input and stream the response."""
+         try:
+             payload = self._build_payload(input_data)
+             response = requests.post(
+                 self.BASE_URL,
+                 headers=self.headers,
+                 data=json.dumps(payload),
+                 stream=True,
+             )
+             response.raise_for_status()
+
+             json_buffer = []
+             for line in response.iter_lines():
+                 if line:
+                     try:
+                         data = json.loads(line.decode("utf-8").removeprefix("data: "))
+                         if data["choices"][0]["delta"].get("content"):
+                             content = data["choices"][0]["delta"]["content"]
+                             json_buffer.append(content)
+                             yield {"chunk": content}
+                     except json.JSONDecodeError:
+                         continue
+
+             # Once complete, parse the full response with think tags
+             if not json_buffer:
+                 # If no data was collected, raise error
+                 raise ProcessingError("No data received from Fireworks API")
+
+             complete_response = "".join(json_buffer)
+             reasoning, json_str = self._parse_response_content(complete_response)
+
+             try:
+                 parsed_response = input_data.response_model.model_validate_json(
+                     json_str
+                 )
+                 yield {"complete": parsed_response, "reasoning": reasoning}
+             except Exception as e:
+                 raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks streaming request failed: {str(e)}")
+
+     def _parse_response_content(self, content: str) -> tuple[Optional[str], str]:
+         """Parse response content to extract reasoning and JSON."""
+         # Extract reasoning if present
+         reasoning_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
+         reasoning = reasoning_match.group(1).strip() if reasoning_match else None
+
+         # Extract JSON
+         json_match = re.search(r"</think>\s*(\{.*\})", content, re.DOTALL)
+         json_str = json_match.group(1).strip() if json_match else content
+
+         return reasoning, json_str
+
+     def process(
+         self, input_data: FireworksStructuredRequestInput
+     ) -> FireworksStructuredRequestOutput:
+         """Process the input and return structured response."""
+         try:
+             if input_data.stream:
+                 # For streaming, collect and parse the entire response
+                 json_buffer = []
+                 parsed_response = None
+                 reasoning = None
+
+                 for chunk in self.process_stream(input_data):
+                     if "chunk" in chunk:
+                         json_buffer.append(chunk["chunk"])
+                     elif "complete" in chunk:
+                         parsed_response = chunk["complete"]
+                         reasoning = chunk.get("reasoning")
+
+                 if parsed_response is None:
+                     raise ProcessingError("Failed to parse streamed response")
+
+                 # Make a non-streaming call to get tool calls if tools were provided
+                 tool_calls = None
+                 if input_data.tools:
+                     # Create a non-streaming request to get tool calls
+                     non_stream_payload = self._build_payload(input_data)
+                     non_stream_payload["stream"] = False
+
+                     response = requests.post(
+                         self.BASE_URL,
+                         headers=self.headers,
+                         data=json.dumps(non_stream_payload),
+                     )
+                     response.raise_for_status()
+                     result = response.json()
+
+                     # Check for tool calls
+                     if result["choices"][0]["message"].get("tool_calls"):
+                         tool_calls = [
+                             {
+                                 "id": tool_call["id"],
+                                 "type": tool_call["type"],
+                                 "function": {
+                                     "name": tool_call["function"]["name"],
+                                     "arguments": tool_call["function"]["arguments"],
+                                 },
+                             }
+                             for tool_call in result["choices"][0]["message"]["tool_calls"]
+                         ]
+
+                 return FireworksStructuredRequestOutput(
+                     parsed_response=parsed_response,
+                     used_model=input_data.model,
+                     usage={"total_tokens": 0},  # Can't get usage stats from streaming
+                     reasoning=reasoning,
+                     tool_calls=tool_calls,
+                 )
+             else:
+                 # For non-streaming, use regular request
+                 payload = self._build_payload(input_data)
+                 payload["stream"] = False  # Ensure it's not streaming
+
+                 response = requests.post(
+                     self.BASE_URL, headers=self.headers, data=json.dumps(payload)
+                 )
+                 response.raise_for_status()
+                 result = response.json()
+
+                 # Get the content from the response
+                 if "choices" not in result or not result["choices"]:
+                     raise ProcessingError("Invalid response format from Fireworks API")
+
+                 content = result["choices"][0]["message"].get("content", "")
+
+                 # Check for tool calls
+                 tool_calls = None
+                 if result["choices"][0]["message"].get("tool_calls"):
+                     tool_calls = [
+                         {
+                             "id": tool_call["id"],
+                             "type": tool_call["type"],
+                             "function": {
+                                 "name": tool_call["function"]["name"],
+                                 "arguments": tool_call["function"]["arguments"],
+                             },
+                         }
+                         for tool_call in result["choices"][0]["message"]["tool_calls"]
+                     ]
+
+                 # Parse the response content
+                 reasoning, json_str = self._parse_response_content(content)
+                 try:
+                     parsed_response = input_data.response_model.model_validate_json(
+                         json_str
+                     )
+                 except Exception as e:
+                     raise ProcessingError(f"Failed to parse JSON response: {str(e)}")
+
+                 return FireworksStructuredRequestOutput(
+                     parsed_response=parsed_response,
+                     used_model=input_data.model,
+                     usage={
+                         "total_tokens": result["usage"]["total_tokens"],
+                         "prompt_tokens": result["usage"]["prompt_tokens"],
+                         "completion_tokens": result["usage"]["completion_tokens"],
+                     },
+                     reasoning=reasoning,
+                     tool_calls=tool_calls,
+                 )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks structured request failed: {str(e)}")
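Editor's note: a hedged sketch of driving the chat-based skill above in streaming mode, under the same assumptions as before (FIREWORKS_API_KEY exported); the WeatherReport model is hypothetical and only for illustration.

from pydantic import BaseModel
from airtrain.integrations.fireworks.structured_requests_skills import (
    FireworksStructuredRequestInput,
    FireworksStructuredRequestSkill,
)


class WeatherReport(BaseModel):  # hypothetical response model
    city: str
    temperature_c: float
    conditions: str


skill = FireworksStructuredRequestSkill()
request = FireworksStructuredRequestInput(
    user_input="Give me a weather report for Paris as JSON.",
    response_model=WeatherReport,
    stream=True,
    max_tokens=2048,
)

# process_stream yields {"chunk": ...} pieces, then a final {"complete": ...}
for event in skill.process_stream(request):
    if "chunk" in event:
        print(event["chunk"], end="", flush=True)
    elif "complete" in event:
        report = event["complete"]          # validated WeatherReport instance
        print("\nreasoning:", event.get("reasoning"))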
airtrain/integrations/fireworks/structured_skills.py
@@ -0,0 +1,102 @@
+ from typing import Type, TypeVar, Optional, List, Dict, Any
+ from pydantic import BaseModel, Field
+ from openai import OpenAI
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+ import re
+
+ # Generic type variable for Pydantic response models
+ ResponseT = TypeVar("ResponseT", bound=BaseModel)
+
+
+ class FireworksParserInput(InputSchema):
+     """Schema for Fireworks structured output input"""
+
+     user_input: str
+     system_prompt: str = "You are a helpful assistant that provides structured data."
+     model: str = "accounts/fireworks/models/deepseek-r1"
+     temperature: float = 0.7
+     max_tokens: Optional[int] = 131072
+     response_model: Type[ResponseT]
+     conversation_history: List[Dict[str, str]] = Field(
+         default_factory=list,
+         description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+     )
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class FireworksParserOutput(OutputSchema):
+     """Schema for Fireworks structured output"""
+
+     parsed_response: BaseModel
+     used_model: str
+     tokens_used: int
+     reasoning: Optional[str] = None
+
+
+ class FireworksParserSkill(Skill[FireworksParserInput, FireworksParserOutput]):
+     """Skill for getting structured responses from Fireworks"""
+
+     input_schema = FireworksParserInput
+     output_schema = FireworksParserOutput
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.client = OpenAI(
+             base_url="https://api.fireworks.ai/inference/v1",
+             api_key=self.credentials.fireworks_api_key.get_secret_value(),
+         )
+
+     def process(self, input_data: FireworksParserInput) -> FireworksParserOutput:
+         try:
+             # Build messages list including conversation history
+             messages = [{"role": "system", "content": input_data.system_prompt}]
+
+             # Add conversation history if present
+             if input_data.conversation_history:
+                 messages.extend(input_data.conversation_history)
+
+             # Add current user input
+             messages.append({"role": "user", "content": input_data.user_input})
+
+             # Make API call with JSON schema
+             completion = self.client.chat.completions.create(
+                 model=input_data.model,
+                 messages=messages,
+                 response_format={
+                     "type": "json_object",
+                     "schema": input_data.response_model.model_json_schema(),
+                 },
+                 temperature=input_data.temperature,
+                 max_tokens=input_data.max_tokens,
+             )
+
+             response_content = completion.choices[0].message.content
+
+             # Extract reasoning if present
+             reasoning_match = re.search(
+                 r"<think>(.*?)</think>", response_content, re.DOTALL
+             )
+             reasoning = reasoning_match.group(1).strip() if reasoning_match else None
+
+             # Extract JSON
+             json_match = re.search(r"</think>\s*(\{.*\})", response_content, re.DOTALL)
+             json_str = json_match.group(1).strip() if json_match else response_content
+
+             # Parse the response into the specified model
+             # (model_validate_json is the Pydantic v2 API used elsewhere in the package)
+             parsed_response = input_data.response_model.model_validate_json(json_str)
+
+             return FireworksParserOutput(
+                 parsed_response=parsed_response,
+                 used_model=completion.model,
+                 tokens_used=completion.usage.total_tokens,
+                 reasoning=reasoning,
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks parsing failed: {str(e)}")
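Editor's note: for comparison, the OpenAI-client-based parser skill might be used like this (same caveats as the sketches above; TaskList is an illustrative model, not part of the package).

from typing import List
from pydantic import BaseModel
from airtrain.integrations.fireworks.structured_skills import (
    FireworksParserInput,
    FireworksParserSkill,
)


class TaskList(BaseModel):  # illustrative only
    tasks: List[str]


skill = FireworksParserSkill()
output = skill.process(
    FireworksParserInput(
        user_input="Plan three tasks for releasing a Python package, as JSON.",
        response_model=TaskList,
        max_tokens=2048,
    )
)
print(output.parsed_response.tasks)
print(output.reasoning)  # text inside <think>...</think>, if the model emitted any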
airtrain/integrations/google/__init__.py
@@ -0,0 +1,7 @@
+ """Google Cloud integration module"""
+
+ from .credentials import GoogleCloudCredentials
+ from .skills import GoogleChatSkill
+ # from .skills import VertexAISkill
+
+ __all__ = ["GoogleCloudCredentials", "GoogleChatSkill"]
airtrain/integrations/google/credentials.py
@@ -0,0 +1,58 @@
+ from pydantic import Field, SecretStr
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+ import google.generativeai as genai
+ from google.cloud import storage
+ import json
+ import os
+
+
+ class GoogleCloudCredentials(BaseCredentials):
+     """Google Cloud credentials"""
+
+     project_id: str = Field(..., description="Google Cloud Project ID")
+     service_account_key: SecretStr = Field(..., description="Service Account Key JSON")
+
+     _required_credentials = {"project_id", "service_account_key"}
+
+     async def validate_credentials(self) -> bool:
+         """Validate Google Cloud credentials"""
+         try:
+             # Initialize with the service account key
+             # (from_service_account_info expects the parsed JSON dict)
+             storage_client = storage.Client.from_service_account_info(
+                 json.loads(self.service_account_key.get_secret_value())
+             )
+             # Test API call
+             storage_client.list_buckets(max_results=1)
+             return True
+         except Exception as e:
+             raise CredentialValidationError(
+                 f"Invalid Google Cloud credentials: {str(e)}"
+             )
+
+
+ class GeminiCredentials(BaseCredentials):
+     """Gemini API credentials"""
+
+     gemini_api_key: SecretStr = Field(..., description="Gemini API Key")
+
+     _required_credentials = {"gemini_api_key"}
+
+     @classmethod
+     def from_env(cls) -> "GeminiCredentials":
+         """Create credentials from environment variables"""
+         return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))
+
+     async def validate_credentials(self) -> bool:
+         """Validate Gemini API credentials"""
+         try:
+             # Configure Gemini with API key
+             genai.configure(api_key=self.gemini_api_key.get_secret_value())
+
+             # Test API call with a simple model
+             model = genai.GenerativeModel("gemini-1.5-flash")
+             model.generate_content("test")
+
+             return True
+         except Exception as e:
+             raise CredentialValidationError(f"Invalid Gemini credentials: {str(e)}")
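Editor's note: since validate_credentials is declared async, a caller would need to await it. A minimal sketch, assuming GEMINI_API_KEY is exported in the environment:

import asyncio

from airtrain.integrations.google.credentials import GeminiCredentials


async def main() -> None:
    creds = GeminiCredentials.from_env()      # reads GEMINI_API_KEY
    ok = await creds.validate_credentials()   # raises CredentialValidationError on failure
    print("Gemini credentials valid:", ok)


if __name__ == "__main__":
    asyncio.run(main())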