airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/fireworks/list_models.py (new file)
@@ -0,0 +1,128 @@
+ from typing import Optional, List
+ import requests
+ from pydantic import Field
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+ from .models import FireworksModel
+
+
+ class FireworksListModelsInput(InputSchema):
+     """Schema for Fireworks AI list models input"""
+
+     account_id: str = Field(..., description="The Account Id")
+     page_size: Optional[int] = Field(
+         default=50,
+         description=(
+             "The maximum number of models to return. The maximum page_size is 200, "
+             "values above 200 will be coerced to 200."
+         ),
+         le=200
+     )
+     page_token: Optional[str] = Field(
+         default=None,
+         description=(
+             "A page token, received from a previous ListModels call. Provide this "
+             "to retrieve the subsequent page. When paginating, all other parameters "
+             "provided to ListModels must match the call that provided the page token."
+         )
+     )
+     filter: Optional[str] = Field(
+         default=None,
+         description=(
+             "Only model satisfying the provided filter (if specified) will be "
+             "returned. See https://google.aip.dev/160 for the filter grammar."
+         )
+     )
+     order_by: Optional[str] = Field(
+         default=None,
+         description=(
+             "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
+             "The default sort order is ascending. To specify a descending order for a "
+             "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
+             "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
+             "If not specified, the default order is by \"name\"."
+         )
+     )
+
+
+ class FireworksListModelsOutput(OutputSchema):
+     """Schema for Fireworks AI list models output"""
+
+     models: List[FireworksModel] = Field(
+         default_factory=list,
+         description="List of Fireworks models"
+     )
+     next_page_token: Optional[str] = Field(
+         default=None,
+         description="Token for retrieving the next page of results"
+     )
+     total_size: Optional[int] = Field(
+         default=None,
+         description="Total number of models available"
+     )
+
+
+ class FireworksListModelsSkill(
+     Skill[FireworksListModelsInput, FireworksListModelsOutput]
+ ):
+     """Skill for listing Fireworks AI models"""
+
+     input_schema = FireworksListModelsInput
+     output_schema = FireworksListModelsOutput
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.base_url = "https://api.fireworks.ai/v1"
+
+     def process(
+         self, input_data: FireworksListModelsInput
+     ) -> FireworksListModelsOutput:
+         """Process the input and return a list of models."""
+         try:
+             # Build the URL
+             url = f"{self.base_url}/accounts/{input_data.account_id}/models"
+
+             # Prepare query parameters
+             params = {}
+             if input_data.page_size:
+                 params["pageSize"] = input_data.page_size
+             if input_data.page_token:
+                 params["pageToken"] = input_data.page_token
+             if input_data.filter:
+                 params["filter"] = input_data.filter
+             if input_data.order_by:
+                 params["orderBy"] = input_data.order_by
+
+             # Make the request
+             headers = {
+                 "Authorization": (
+                     f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
+                 )
+             }
+
+             response = requests.get(url, headers=headers, params=params)
+             response.raise_for_status()
+
+             # Parse the response
+             result = response.json()
+
+             # Convert the models to FireworksModel objects
+             models = []
+             for model_data in result.get("models", []):
+                 models.append(FireworksModel(**model_data))
+
+             # Return the output
+             return FireworksListModelsOutput(
+                 models=models,
+                 next_page_token=result.get("nextPageToken"),
+                 total_size=result.get("totalSize")
+             )
+
+         except requests.RequestException as e:
+             raise ProcessingError(f"Failed to list Fireworks models: {str(e)}")
+         except Exception as e:
+             raise ProcessingError(f"Error listing Fireworks models: {str(e)}")
airtrain/integrations/fireworks/models.py (new file)
@@ -0,0 +1,139 @@
+ from typing import List, Optional, Dict, Any
+ from pydantic import Field, BaseModel
+
+
+ class FireworksMessage(BaseModel):
+     """Schema for Fireworks chat message"""
+
+     content: str
+     role: str = Field(..., pattern="^(system|user|assistant)$")
+
+
+ class FireworksUsage(BaseModel):
+     """Schema for Fireworks API usage statistics"""
+
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+
+ class FireworksResponse(BaseModel):
+     """Schema for Fireworks API response"""
+
+     id: str
+     choices: List[Dict[str, Any]]
+     created: int
+     model: str
+     usage: FireworksUsage
+
+
+ class FireworksModelStatus(BaseModel):
+     """Schema for Fireworks model status"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksModelBaseDetails(BaseModel):
+     """Schema for Fireworks base model details"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksPeftDetails(BaseModel):
+     """Schema for Fireworks PEFT details"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksConversationConfig(BaseModel):
+     """Schema for Fireworks conversation configuration"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksModelDeployedRef(BaseModel):
+     """Schema for Fireworks deployed model reference"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksDeprecationDate(BaseModel):
+     """Schema for Fireworks deprecation date"""
+     # This would be filled with actual fields from the API response
+
+
+ class FireworksModel(BaseModel):
+     """Schema for a Fireworks model"""
+
+     name: str
+     display_name: Optional[str] = None
+     description: Optional[str] = None
+     create_time: Optional[str] = None
+     created_by: Optional[str] = None
+     state: Optional[str] = None
+     status: Optional[Dict[str, Any]] = None
+     kind: Optional[str] = None
+     github_url: Optional[str] = None
+     hugging_face_url: Optional[str] = None
+     base_model_details: Optional[Dict[str, Any]] = None
+     peft_details: Optional[Dict[str, Any]] = None
+     teft_details: Optional[Dict[str, Any]] = None
+     public: Optional[bool] = None
+     conversation_config: Optional[Dict[str, Any]] = None
+     context_length: Optional[int] = None
+     supports_image_input: Optional[bool] = None
+     supports_tools: Optional[bool] = None
+     imported_from: Optional[str] = None
+     fine_tuning_job: Optional[str] = None
+     default_draft_model: Optional[str] = None
+     default_draft_token_count: Optional[int] = None
+     precisions: Optional[List[str]] = None
+     deployed_model_refs: Optional[List[Dict[str, Any]]] = None
+     cluster: Optional[str] = None
+     deprecation_date: Optional[Dict[str, Any]] = None
+     calibrated: Optional[bool] = None
+     tunable: Optional[bool] = None
+     supports_lora: Optional[bool] = None
+     use_hf_apply_chat_template: Optional[bool] = None
+
+
+ class ListModelsInput(BaseModel):
+     """Schema for listing Fireworks models input"""
+
+     account_id: str = Field(..., description="The Account Id")
+     page_size: Optional[int] = Field(
+         default=50,
+         description=(
+             "The maximum number of models to return. The maximum page_size is 200, "
+             "values above 200 will be coerced to 200."
+         ),
+         le=200
+     )
+     page_token: Optional[str] = Field(
+         default=None,
+         description=(
+             "A page token, received from a previous ListModels call. Provide this "
+             "to retrieve the subsequent page. When paginating, all other parameters "
+             "provided to ListModels must match the call that provided the page token."
+         )
+     )
+     filter: Optional[str] = Field(
+         default=None,
+         description=(
+             "Only model satisfying the provided filter (if specified) will be "
+             "returned. See https://google.aip.dev/160 for the filter grammar."
+         )
+     )
+     order_by: Optional[str] = Field(
+         default=None,
+         description=(
+             "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
+             "The default sort order is ascending. To specify a descending order for a "
+             "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
+             "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
+             "If not specified, the default order is by \"name\"."
+         )
+     )
+
+
+ class ListModelsOutput(BaseModel):
+     """Schema for listing Fireworks models output"""
+
+     models: List[FireworksModel]
+     next_page_token: Optional[str] = None
+     total_size: Optional[int] = None
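Because almost every field on FireworksModel is Optional, partial API payloads validate cleanly. A small illustrative sketch (the payload below is made up, not taken from the API):

from airtrain.integrations.fireworks.models import FireworksModel

# Abbreviated, hypothetical payload in the shape returned by the ListModels endpoint.
payload = {
    "name": "accounts/fireworks/models/deepseek-r1",
    "context_length": 131072,
    "supports_tools": False,
    "state": "READY",
}
model = FireworksModel(**payload)
print(model.name, model.context_length, model.display_name)  # display_name stays None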
airtrain/integrations/fireworks/requests_skills.py (new file)
@@ -0,0 +1,207 @@
+ from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
+ from pydantic import Field
+ import requests
+ import json
+ from loguru import logger
+ import aiohttp
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+
+
+ class FireworksRequestInput(InputSchema):
+     """Schema for Fireworks AI chat input using requests"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt to guide the model's behavior",
+     )
+     conversation_history: List[Dict[str, str]] = Field(
+         default_factory=list,
+         description="List of previous conversation messages",
+     )
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     top_p: float = Field(
+         default=1.0, description="Top p sampling parameter", ge=0, le=1
+     )
+     top_k: int = Field(default=40, description="Top k sampling parameter", ge=0)
+     presence_penalty: float = Field(
+         default=0.0, description="Presence penalty", ge=-2.0, le=2.0
+     )
+     frequency_penalty: float = Field(
+         default=0.0, description="Frequency penalty", ge=-2.0, le=2.0
+     )
+     stream: bool = Field(
+         default=False,
+         description="Whether to stream the response",
+     )
+
+
+ class FireworksRequestOutput(OutputSchema):
+     """Schema for Fireworks AI chat output"""
+
+     response: str
+     used_model: str
+     usage: Dict[str, int]
+
+
+ class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
+     """Skill for interacting with Fireworks AI models using requests"""
+
+     input_schema = FireworksRequestInput
+     output_schema = FireworksRequestOutput
+     BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.headers = {
+             "Accept": "application/json",
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+         }
+         self.stream_headers = {
+             "Accept": "text/event-stream",
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+         }
+
+     def _build_messages(
+         self, input_data: FireworksRequestInput
+     ) -> List[Dict[str, str]]:
+         """Build messages list from input data including conversation history."""
+         messages = [{"role": "system", "content": input_data.system_prompt}]
+
+         if input_data.conversation_history:
+             messages.extend(input_data.conversation_history)
+
+         messages.append({"role": "user", "content": input_data.user_input})
+         return messages
+
+     def _build_payload(self, input_data: FireworksRequestInput) -> Dict[str, Any]:
+         """Build the request payload."""
+         return {
+             "model": input_data.model,
+             "messages": self._build_messages(input_data),
+             "temperature": input_data.temperature,
+             "max_tokens": input_data.max_tokens,
+             "top_p": input_data.top_p,
+             "top_k": input_data.top_k,
+             "presence_penalty": input_data.presence_penalty,
+             "frequency_penalty": input_data.frequency_penalty,
+             "stream": input_data.stream,
+         }
+
+     def process_stream(
+         self, input_data: FireworksRequestInput
+     ) -> Generator[str, None, None]:
+         """Process the input and stream the response."""
+         try:
+             payload = self._build_payload(input_data)
+             response = requests.post(
+                 self.BASE_URL,
+                 headers=self.headers,
+                 data=json.dumps(payload),
+                 stream=True,
+             )
+             response.raise_for_status()
+
+             for line in response.iter_lines():
+                 if line:
+                     try:
+                         data = json.loads(line.decode("utf-8").removeprefix("data: "))
+                         if data["choices"][0]["delta"].get("content"):
+                             yield data["choices"][0]["delta"]["content"]
+                     except json.JSONDecodeError:
+                         continue
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks streaming request failed: {str(e)}")
+
+     def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
+         """Process the input and return the complete response."""
+         try:
+             if input_data.stream:
+                 # For streaming, collect the entire response
+                 response_chunks = []
+                 for chunk in self.process_stream(input_data):
+                     response_chunks.append(chunk)
+                 response_text = "".join(response_chunks)
+                 usage = {}  # Usage stats not available in streaming mode
+             else:
+                 # For non-streaming, use regular request
+                 payload = self._build_payload(input_data)
+                 response = requests.post(
+                     self.BASE_URL, headers=self.headers, data=json.dumps(payload)
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 response_text = data["choices"][0]["message"]["content"]
+                 usage = data["usage"]
+
+             return FireworksRequestOutput(
+                 response=response_text, used_model=input_data.model, usage=usage
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks request failed: {str(e)}")
+
+     async def process_async(
+         self, input_data: FireworksRequestInput
+     ) -> FireworksRequestOutput:
+         """Async version of process method using aiohttp"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 payload = self._build_payload(input_data)
+                 async with session.post(
+                     self.BASE_URL, headers=self.headers, json=payload
+                 ) as response:
+                     response.raise_for_status()
+                     data = await response.json()
+
+                     return FireworksRequestOutput(
+                         response=data["choices"][0]["message"]["content"],
+                         used_model=input_data.model,
+                         usage=data.get("usage", {}),
+                     )
+
+         except Exception as e:
+             raise ProcessingError(f"Async Fireworks request failed: {str(e)}")
+
+     async def process_stream_async(
+         self, input_data: FireworksRequestInput
+     ) -> AsyncGenerator[str, None]:
+         """Async version of stream processor using aiohttp"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 payload = self._build_payload(input_data)
+                 async with session.post(
+                     self.BASE_URL, headers=self.stream_headers, json=payload
+                 ) as response:
+                     response.raise_for_status()
+
+                     async for line in response.content:
+                         if line.startswith(b"data: "):
+                             chunk = json.loads(line[6:].strip())
+                             if "choices" in chunk:
+                                 content = (
+                                     chunk["choices"][0]
+                                     .get("delta", {})
+                                     .get("content", "")
+                                 )
+                                 if content:
+                                     yield content
+
+         except Exception as e:
+             raise ProcessingError(f"Async Fireworks streaming failed: {str(e)}")
airtrain/integrations/fireworks/skills.py (new file)
@@ -0,0 +1,181 @@
+ from typing import List, Optional, Dict, Any, Generator, Union
+ from pydantic import Field
+ from openai import OpenAI
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+
+
+ class FireworksInput(InputSchema):
+     """Schema for Fireworks AI chat input"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt to guide the model's behavior",
+     )
+     conversation_history: List[Dict[str, Any]] = Field(
+         default_factory=list,
+         description="List of previous conversation messages",
+     )
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     max_tokens: Optional[int] = Field(
+         default=131072, description="Maximum tokens in response"
+     )
+     context_length_exceeded_behavior: str = Field(
+         default="truncate", description="Behavior when context length is exceeded"
+     )
+     stream: bool = Field(
+         default=False,
+         description="Whether to stream the response token by token",
+     )
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "A list of tools the model may use. "
+             "Currently only functions supported."
+         ),
+     )
+     tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Controls which tool is called by the model. "
+             "'none', 'auto', or specific tool."
+         ),
+     )
+
+
+ class FireworksOutput(OutputSchema):
+     """Schema for Fireworks AI chat output"""
+
+     response: str = Field(..., description="Model's response text")
+     used_model: str = Field(..., description="Model used for generation")
+     usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None, description="Tool calls generated by the model"
+     )
+
+
+ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
+     """Skill for interacting with Fireworks AI models"""
+
+     input_schema = FireworksInput
+     output_schema = FireworksOutput
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.client = OpenAI(
+             base_url="https://api.fireworks.ai/inference/v1",
+             api_key=self.credentials.fireworks_api_key.get_secret_value(),
+         )
+
+     def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, Any]]:
+         """Build messages list from input data including conversation history."""
+         messages = [{"role": "system", "content": input_data.system_prompt}]
+
+         if input_data.conversation_history:
+             messages.extend(input_data.conversation_history)
+
+         messages.append({"role": "user", "content": input_data.user_input})
+         return messages
+
+     def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
+         """Process the input and stream the response token by token."""
+         try:
+             messages = self._build_messages(input_data)
+
+             stream = self.client.chat.completions.create(
+                 model=input_data.model,
+                 messages=messages,
+                 temperature=input_data.temperature,
+                 max_tokens=input_data.max_tokens,
+                 stream=True,
+             )
+
+             for chunk in stream:
+                 if chunk.choices[0].delta.content is not None:
+                     yield chunk.choices[0].delta.content
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks streaming failed: {str(e)}")
+
+     def process(self, input_data: FireworksInput) -> FireworksOutput:
+         """Process the input and return the complete response."""
+         try:
+             if input_data.stream:
+                 # For streaming, collect the entire response
+                 response_chunks = []
+                 for chunk in self.process_stream(input_data):
+                     response_chunks.append(chunk)
+                 response = "".join(response_chunks)
+
+                 # Create completion object for usage stats
+                 messages = self._build_messages(input_data)
+                 completion = self.client.chat.completions.create(
+                     model=input_data.model,
+                     messages=messages,
+                     temperature=input_data.temperature,
+                     max_tokens=input_data.max_tokens,
+                     stream=False,
+                 )
+             else:
+                 # For non-streaming, use regular completion
+                 messages = self._build_messages(input_data)
+
+                 # Prepare API call parameters
+                 api_params = {
+                     "model": input_data.model,
+                     "messages": messages,
+                     "temperature": input_data.temperature,
+                     "max_tokens": input_data.max_tokens,
+                     "stream": False,
+                 }
+
+                 # Add tools and tool_choice if provided
+                 if input_data.tools:
+                     api_params["tools"] = input_data.tools
+
+                 if input_data.tool_choice:
+                     api_params["tool_choice"] = input_data.tool_choice
+
+                 completion = self.client.chat.completions.create(**api_params)
+                 response = completion.choices[0].message.content or ""
+
+             # Check for tool calls in the response
+             tool_calls = None
+             if (hasattr(completion.choices[0].message, "tool_calls") and
+                     completion.choices[0].message.tool_calls):
+                 tool_calls = [
+                     {
+                         "id": tool_call.id,
+                         "type": tool_call.type,
+                         "function": {
+                             "name": tool_call.function.name,
+                             "arguments": tool_call.function.arguments
+                         }
+                     }
+                     for tool_call in completion.choices[0].message.tool_calls
+                 ]
+
+             return FireworksOutput(
+                 response=response,
+                 used_model=input_data.model,
+                 usage={
+                     "total_tokens": completion.usage.total_tokens,
+                     "prompt_tokens": completion.usage.prompt_tokens,
+                     "completion_tokens": completion.usage.completion_tokens,
+                 },
+                 tool_calls=tool_calls
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks chat failed: {str(e)}")