airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
|
|
1
|
+
from typing import Optional, Dict, Any, List
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
from groq import Groq
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from airtrain.integrations.fireworks.completion_skills import (
|
9
|
+
FireworksCompletionSkill,
|
10
|
+
FireworksCompletionInput,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
class GroqFireworksInput(InputSchema):
    """Request schema for the combined Groq + Fireworks skill.

    One instance carries the user text plus the model names and sampling
    settings used for both provider calls.
    """

    # Text sent to Groq and embedded in the Fireworks prompt.
    user_input: str = Field(..., description="User's input text")

    # Model identifiers, one per provider.
    groq_model: str = Field(
        description="Groq model to use",
        default="mixtral-8x7b-32768",
    )
    fireworks_model: str = Field(
        description="Fireworks model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )

    # Sampling settings shared by both provider calls.
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
    )
    max_tokens: int = Field(
        description="Maximum tokens in response",
        default=131072,
    )
|
29
|
+
|
30
|
+
|
31
|
+
class GroqFireworksOutput(OutputSchema):
    """Result schema for the combined Groq + Fireworks skill."""

    # Final text assembled from the user input and both provider answers.
    combined_response: str
    # Raw text returned by each provider.
    groq_response: str
    fireworks_response: str
    # Provider name -> model identifier that was used.
    used_models: Dict[str, str]
    # Provider name -> usage counters reported by that provider.
    usage: Dict[str, Dict[str, int]]
|
39
|
+
|
40
|
+
|
41
|
+
class GroqFireworksSkill(Skill[GroqFireworksInput, GroqFireworksOutput]):
    """Skill that chains a Groq chat completion into a Fireworks completion.

    ``process`` sends the user input to Groq, places that answer after the
    user text in a ``<USER>``/``<ASSISTANT>`` formatted prompt, and hands the
    prompt to the Fireworks completion skill. Both texts, the model names and
    the usage counters are returned in a ``GroqFireworksOutput``.
    """

    input_schema = GroqFireworksInput
    output_schema = GroqFireworksOutput

    def __init__(
        self,
        groq_api_key: Optional[str] = None,
        fireworks_skill: Optional[FireworksCompletionSkill] = None,
    ):
        """Create the Groq client and the Fireworks completion skill.

        Args:
            groq_api_key: API key passed to the ``Groq`` client constructor;
                ``None`` is forwarded unchanged.
            fireworks_skill: Fireworks skill to reuse. When omitted, a default
                ``FireworksCompletionSkill()`` is constructed.
        """
        super().__init__()
        self.groq_client = Groq(api_key=groq_api_key)
        self.fireworks_skill = fireworks_skill or FireworksCompletionSkill()

    @staticmethod
    def _integer_usage(raw_usage: Optional[Dict[str, Any]]) -> Dict[str, int]:
        """Return only the integer counters of a provider usage mapping.

        ``GroqFireworksOutput.usage`` is declared as ``Dict[str, Dict[str, int]]``,
        so non-integer entries (timing floats, ``None``, nested objects) must
        be dropped before the output schema is built.
        """
        if not raw_usage:
            return {}
        return {
            key: value
            for key, value in raw_usage.items()
            # bool is a subclass of int but is not a counter.
            if isinstance(value, int) and not isinstance(value, bool)
        }

    def _get_groq_response(self, input_data: GroqFireworksInput) -> Dict[str, Any]:
        """Run the Groq chat completion for the user input.

        Returns:
            A dict with ``"response"`` (text of the first choice) and
            ``"usage"`` (integer usage counters).

        Raises:
            ProcessingError: If the Groq call or response handling fails.
        """
        try:
            completion = self.groq_client.chat.completions.create(
                model=input_data.groq_model,
                messages=[{"role": "user", "content": input_data.user_input}],
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            usage = completion.usage
            return {
                "response": completion.choices[0].message.content,
                # The usage object may be absent and may carry non-integer
                # fields; keep only what the output schema accepts.
                "usage": self._integer_usage(
                    usage.model_dump() if usage is not None else None
                ),
            }
        except Exception as e:
            raise ProcessingError(f"Groq request failed: {str(e)}") from e

    def _get_fireworks_response(
        self, groq_response: str, input_data: GroqFireworksInput
    ) -> Dict[str, Any]:
        """Ask Fireworks to continue from the Groq answer.

        Args:
            groq_response: Text returned by Groq for the same user input.
            input_data: Original request; supplies model and sampling settings.

        Returns:
            A dict with ``"response"`` and ``"usage"`` taken from the
            Fireworks skill result.

        Raises:
            ProcessingError: If building the input or the Fireworks call fails.
        """
        try:
            # The assistant tag is left open so the completion model continues
            # directly after the Groq text.
            formatted_prompt = (
                f"<USER>{input_data.user_input}</USER>\n<ASSISTANT>{groq_response}"
            )

            fireworks_input = FireworksCompletionInput(
                prompt=formatted_prompt,
                model=input_data.fireworks_model,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            result = self.fireworks_skill.process(fireworks_input)
            return {"response": result.response, "usage": result.usage}
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    def process(self, input_data: GroqFireworksInput) -> GroqFireworksOutput:
        """Run Groq, then Fireworks, and assemble the combined output.

        Raises:
            ProcessingError: If either provider call or the output
                construction fails; the original exception is chained.
        """
        try:
            # Stage 1: Groq answers the user input.
            groq_result = self._get_groq_response(input_data)

            # Stage 2: Fireworks continues from the Groq answer.
            fireworks_result = self._get_fireworks_response(
                groq_result["response"], input_data
            )

            # Combined text: user input, then both answers separated by a space.
            combined_response = (
                f"<USER>{input_data.user_input}</USER>\n"
                f"<ASSISTANT>{groq_result['response']} {fireworks_result['response']}"
            )

            return GroqFireworksOutput(
                combined_response=combined_response,
                groq_response=groq_result["response"],
                fireworks_response=fireworks_result["response"],
                used_models={
                    "groq": input_data.groq_model,
                    "fireworks": input_data.fireworks_model,
                },
                usage={
                    "groq": groq_result["usage"],
                    "fireworks": fireworks_result["usage"],
                },
            )

        except Exception as e:
            raise ProcessingError(f"Combined processing failed: {str(e)}") from e
|
@@ -0,0 +1,210 @@
|
|
1
|
+
from typing import Optional, Dict, Any, List
|
2
|
+
from pydantic import Field
|
3
|
+
|
4
|
+
from airtrain.core.skills import Skill, ProcessingError
|
5
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
6
|
+
from airtrain.core.credentials import BaseCredentials
|
7
|
+
|
8
|
+
# Import existing list models skills
|
9
|
+
from airtrain.integrations.openai.list_models import OpenAIListModelsSkill
|
10
|
+
from airtrain.integrations.anthropic.list_models import AnthropicListModelsSkill
|
11
|
+
from airtrain.integrations.together.list_models import TogetherListModelsSkill
|
12
|
+
from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
|
13
|
+
|
14
|
+
# Import credentials
|
15
|
+
from airtrain.integrations.groq.credentials import GroqCredentials
|
16
|
+
from airtrain.integrations.cerebras.credentials import CerebrasCredentials
|
17
|
+
from airtrain.integrations.sambanova.credentials import SambanovaCredentials
|
18
|
+
from airtrain.integrations.perplexity.credentials import PerplexityCredentials
|
19
|
+
|
20
|
+
# Import Perplexity list models
|
21
|
+
from airtrain.integrations.perplexity.list_models import PerplexityListModelsSkill
|
22
|
+
|
23
|
+
|
24
|
+
# Generic list models input schema
|
25
|
+
class GenericListModelsInput(InputSchema):
    """Provider-agnostic request schema for listing models.

    Extra fields are accepted (``extra = "allow"``).
    """

    api_models_only: bool = Field(
        description=(
            "If True, fetch models from the API only. If False, use local config."
        ),
        default=False,
    )

    class Config:
        extra = "allow"
        arbitrary_types_allowed = True
|
38
|
+
|
39
|
+
|
40
|
+
# Generic list models output schema
|
41
|
+
class GenericListModelsOutput(OutputSchema):
    """Provider-agnostic response schema for a model listing."""

    # One dict per model; defaults to an empty list.
    models: List[Dict[str, Any]] = Field(
        description="List of models",
        default_factory=list,
    )
    # Name of the provider the listing belongs to (required).
    provider: str = Field(..., description="Provider name")
|
48
|
+
|
49
|
+
|
50
|
+
# Base class for stub implementations
|
51
|
+
class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
    """Common base for list-models skills backed by a ``get_models`` method.

    Subclasses supply the provider name and override ``get_models``;
    ``process`` wraps the result in a ``GenericListModelsOutput``.
    """

    input_schema = GenericListModelsInput
    output_schema = GenericListModelsOutput

    def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
        """Store the provider name and the optional credentials.

        Args:
            provider: Provider name reported in outputs and error messages.
            credentials: Optional credentials object; stored as given.
        """
        super().__init__()
        self.credentials = credentials
        self.provider = provider

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the provider's models; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement get_models()")

    def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
        """Build the listing output from ``get_models``.

        ``input_data`` is accepted for interface compatibility and is not read.

        Raises:
            ProcessingError: If ``get_models`` or output construction fails.
        """
        try:
            available = self.get_models()
            listing = GenericListModelsOutput(models=available, provider=self.provider)
        except Exception as e:
            raise ProcessingError(f"Failed to list {self.provider} models: {e!s}")
        return listing
|
74
|
+
|
75
|
+
|
76
|
+
# Groq implementation
|
77
|
+
class GroqListModelsSkill(BaseListModelsSkill):
    """List-models skill that returns a fixed catalog of Groq models."""

    def __init__(self, credentials: Optional[GroqCredentials] = None):
        """Register the skill under the ``groq`` provider name.

        Args:
            credentials: Optional Groq credentials, forwarded to the base class.
        """
        super().__init__(provider="groq", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Groq model list as ``id``/``display_name`` dicts."""
        # Default Groq models from trmx_agent config, as (id, display name) pairs.
        catalog = (
            ("llama-3.3-70b-versatile", "Llama 3.3 70B Versatile (Tool Use)"),
            ("llama-3.1-8b-instant", "Llama 3.1 8B Instant (Tool Use)"),
            ("mixtral-8x7b-32768", "Mixtral 8x7B (32K) (Tool Use)"),
            ("gemma2-9b-it", "Gemma 2 9B IT (Tool Use)"),
            ("qwen-qwq-32b", "Qwen QWQ 32B (Tool Use)"),
            ("qwen-2.5-coder-32b", "Qwen 2.5 Coder 32B (Tool Use)"),
            ("qwen-2.5-32b", "Qwen 2.5 32B (Tool Use)"),
            ("deepseek-r1-distill-qwen-32b", "DeepSeek R1 Distill Qwen 32B (Tool Use)"),
            ("deepseek-r1-distill-llama-70b", "DeepSeek R1 Distill Llama 70B (Tool Use)"),
        )
        return [
            {"id": model_id, "display_name": label} for model_id, label in catalog
        ]
|
117
|
+
|
118
|
+
|
119
|
+
# Cerebras implementation
|
120
|
+
class CerebrasListModelsSkill(BaseListModelsSkill):
    """List-models skill that returns a fixed catalog of Cerebras models."""

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Register the skill under the ``cerebras`` provider name.

        Args:
            credentials: Optional Cerebras credentials, forwarded to the base class.
        """
        super().__init__(provider="cerebras", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Cerebras model list as ``id``/``display_name`` dicts."""
        # Default Cerebras models from trmx_agent config, as (id, display name) pairs.
        catalog = (
            ("cerebras/Cerebras-GPT-13B-v0.1", "Cerebras GPT 13B v0.1"),
            ("cerebras/Cerebras-GPT-111M-v0.9", "Cerebras GPT 111M v0.9"),
            ("cerebras/Cerebras-GPT-590M-v0.7", "Cerebras GPT 590M v0.7"),
        )
        return [
            {"id": model_id, "display_name": label} for model_id, label in catalog
        ]
|
145
|
+
|
146
|
+
|
147
|
+
# Sambanova implementation
|
148
|
+
class SambanovaListModelsSkill(BaseListModelsSkill):
    """List-models skill that returns a fixed catalog of Sambanova models."""

    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
        """Register the skill under the ``sambanova`` provider name.

        Args:
            credentials: Optional Sambanova credentials, forwarded to the base class.
        """
        super().__init__(provider="sambanova", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Sambanova model list as ``id``/``display_name`` dicts."""
        # Limited Sambanova model information, as (id, display name) pairs.
        catalog = (
            ("sambanova/samba-1", "Samba-1"),
            ("sambanova/samba-2", "Samba-2"),
        )
        return [
            {"id": model_id, "display_name": label} for model_id, label in catalog
        ]
|
163
|
+
|
164
|
+
|
165
|
+
# Factory class
|
166
|
+
class ListModelsSkillFactory:
    """Creates the list-models skill that matches a provider name."""

    # Lower-case provider name -> skill class that lists its models.
    _PROVIDER_MAP = {
        "openai": OpenAIListModelsSkill,
        "anthropic": AnthropicListModelsSkill,
        "together": TogetherListModelsSkill,
        "fireworks": FireworksListModelsSkill,
        "groq": GroqListModelsSkill,
        "cerebras": CerebrasListModelsSkill,
        "sambanova": SambanovaListModelsSkill,
        "perplexity": PerplexityListModelsSkill,
    }

    @classmethod
    def get_skill(cls, provider: str, credentials=None):
        """Instantiate the list-models skill registered for ``provider``.

        Args:
            provider (str): The provider name; matched case-insensitively.
            credentials: Optional credentials, passed to the skill constructor
                as the ``credentials`` keyword argument.

        Returns:
            A new instance of the skill class mapped to the provider.

        Raises:
            ValueError: If the provider is not in the provider map.
        """
        normalized = provider.lower()
        skill_class = cls._PROVIDER_MAP.get(normalized)

        if skill_class is None:
            supported = ", ".join(cls.get_supported_providers())
            raise ValueError(
                f"Unsupported provider: {normalized}. "
                f"Supported providers are: {supported}"
            )

        return skill_class(credentials=credentials)

    @classmethod
    def get_supported_providers(cls):
        """Return the registered provider names as a new list."""
        return list(cls._PROVIDER_MAP)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""Fireworks AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import FireworksCredentials
|
4
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
5
|
+
from .list_models import (
|
6
|
+
FireworksListModelsSkill,
|
7
|
+
FireworksListModelsInput,
|
8
|
+
FireworksListModelsOutput,
|
9
|
+
)
|
10
|
+
from .models import FireworksModel
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"FireworksCredentials",
|
14
|
+
"FireworksChatSkill",
|
15
|
+
"FireworksInput",
|
16
|
+
"FireworksOutput",
|
17
|
+
"FireworksListModelsSkill",
|
18
|
+
"FireworksListModelsInput",
|
19
|
+
"FireworksListModelsOutput",
|
20
|
+
"FireworksModel",
|
21
|
+
]
|
@@ -0,0 +1,147 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import FireworksCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksCompletionInput(InputSchema):
    """Request schema for the Fireworks AI text-completion skill.

    Field values are copied into the JSON payload sent to the completions
    endpoint; numeric bounds are enforced by the field constraints below.
    """

    prompt: str = Field(..., description="Input prompt for completion")
    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=131072)

    # Sampling controls.
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
    top_p: float = Field(
        description="Top p sampling parameter",
        default=1.0,
        ge=0,
        le=1,
    )
    top_k: int = Field(description="Top k sampling parameter", default=50, ge=0)

    # Penalties.
    presence_penalty: float = Field(
        description="Presence penalty",
        default=0.0,
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        description="Frequency penalty",
        default=0.0,
        ge=-2.0,
        le=2.0,
    )
    repetition_penalty: float = Field(
        description="Repetition penalty",
        default=1.0,
        ge=0.0,
    )

    # Output shaping.
    stop: Optional[Union[str, List[str]]] = Field(
        description="Stop sequences",
        default=None,
    )
    echo: bool = Field(description="Echo the prompt in the response", default=False)
    stream: bool = Field(description="Whether to stream the response", default=False)
|
42
|
+
|
43
|
+
|
44
|
+
class FireworksCompletionOutput(OutputSchema):
    """Result schema for the Fireworks AI text-completion skill."""

    # Completion text.
    response: str
    # Model identifier the request was made with.
    used_model: str
    # Usage counters reported by the API.
    usage: Dict[str, int]
|
50
|
+
|
51
|
+
|
52
|
+
class FireworksCompletionSkill(
    Skill[FireworksCompletionInput, FireworksCompletionOutput]
):
    """Skill for text completion using the Fireworks AI completions endpoint.

    Requests are sent with ``requests.post`` to ``BASE_URL`` using a bearer
    token taken from the supplied (or environment-derived) credentials.
    """

    input_schema = FireworksCompletionInput
    output_schema = FireworksCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and pre-build the request headers.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
        """Translate the input schema into the JSON request payload.

        ``stop`` is only included when it is set to a non-empty value.
        """
        payload = {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "max_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "repetition_penalty": input_data.repetition_penalty,
            "echo": input_data.echo,
            "stream": input_data.stream,
        }

        if input_data.stop:
            payload["stop"] = input_data.stop

        return payload

    def process_stream(
        self, input_data: FireworksCompletionInput
    ) -> Generator[str, None, None]:
        """Send the request and yield completion text chunk by chunk.

        Each non-empty response line has an optional ``data: `` prefix removed
        and is parsed as JSON; lines that are not valid JSON are skipped.

        Yields:
            The ``text`` of the first choice of every parsed chunk that has one.

        Raises:
            ProcessingError: If the HTTP request or stream handling fails.
        """
        try:
            payload = self._build_payload(input_data)
            # The context manager closes the streamed response (and releases
            # the connection) even if the consumer stops iterating early or
            # an error occurs mid-stream.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if line:
                        try:
                            data = json.loads(
                                line.decode("utf-8").removeprefix("data: ")
                            )
                            if data.get("choices") and data["choices"][0].get("text"):
                                yield data["choices"][0]["text"]
                        except json.JSONDecodeError:
                            # Lines that are not valid JSON are ignored.
                            continue

        except Exception as e:
            raise ProcessingError(
                f"Fireworks completion streaming failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksCompletionInput
    ) -> FireworksCompletionOutput:
        """Run a completion and return the full response text.

        When ``input_data.stream`` is true the streamed chunks are collected
        and joined, and ``usage`` is left empty; otherwise a single request is
        made and ``usage`` is taken from the response body.

        Raises:
            ProcessingError: If the request, response parsing or output
                construction fails; the original exception is chained.
        """
        try:
            if input_data.stream:
                # For streaming, collect the entire response.
                response_text = "".join(self.process_stream(input_data))
                usage = {}  # Usage stats not available in streaming mode
            else:
                # For non-streaming, use a regular request.
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["text"]
                usage = data["usage"]

            return FireworksCompletionOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks completion failed: {str(e)}") from e
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from typing import List, Dict, Optional
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
4
|
+
|
5
|
+
# TODO: Add tests for the conversation manager defined in this module.
|
6
|
+
|
7
|
+
|
8
|
+
class ConversationState(BaseModel):
    """Serializable snapshot of a conversation and its generation settings."""

    # Chat history as a list of message dicts; starts empty.
    messages: List[Dict[str, str]] = Field(
        description="List of conversation messages",
        default_factory=list,
    )
    system_prompt: str = Field(
        description="System prompt for the conversation",
        default="You are a helpful assistant.",
    )
    model: str = Field(
        description="Model being used for the conversation",
        default="accounts/fireworks/models/deepseek-r1",
    )
    temperature: float = Field(description="Temperature setting", default=0.7)
    max_tokens: Optional[int] = Field(description="Max tokens setting", default=131072)
|
24
|
+
|
25
|
+
|
26
|
+
class FireworksConversationManager:
    """Keeps conversation state across calls to a Fireworks chat skill.

    Every ``send_message`` call passes the stored history and settings to the
    skill and then appends the new user/assistant exchange to the history.
    """

    def __init__(
        self,
        skill: Optional[FireworksChatSkill] = None,
        system_prompt: str = "You are a helpful assistant.",
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize conversation manager.

        Args:
            skill: FireworksChatSkill instance (creates new one if None)
            system_prompt: Initial system prompt
            model: Model to use
            temperature: Temperature setting
            max_tokens: Max tokens setting; the given value (including None)
                is stored in the state as-is
        """
        self.skill = skill or FireworksChatSkill()
        self.state = ConversationState(
            system_prompt=system_prompt,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def send_message(self, user_input: str) -> FireworksOutput:
        """
        Send a message and get response while maintaining conversation history.

        The history is only updated after the skill call returns, so a failed
        call leaves the stored conversation unchanged.

        Args:
            user_input: User's message

        Returns:
            FireworksOutput: Model's response
        """
        # Create input with current conversation state
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.state.system_prompt,
            conversation_history=self.state.messages,
            model=self.state.model,
            temperature=self.state.temperature,
            max_tokens=self.state.max_tokens,
        )

        # Get response
        result = self.skill.process(input_data)

        # Update conversation history
        self.state.messages.extend(
            [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": result.response},
            ]
        )

        return result

    def reset_conversation(self) -> None:
        """Reset the conversation history while maintaining other settings"""
        self.state.messages = []

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return a shallow copy of the current conversation history"""
        return self.state.messages.copy()

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future messages"""
        self.state.system_prompt = new_prompt

    def save_state(self, file_path: str) -> None:
        """Save conversation state to a file as indented JSON.

        The file is written as UTF-8 so that non-ASCII message text is stored
        the same way on every platform instead of depending on the locale's
        default encoding.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.state.model_dump_json(indent=2))

    def load_state(self, file_path: str) -> None:
        """Load conversation state from a UTF-8 JSON file, replacing the current state.

        Raises whatever ``open`` or ``ConversationState.model_validate_json``
        raise for a missing file or invalid content.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.state = ConversationState.model_validate_json(data)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from pydantic import SecretStr, BaseModel, Field
|
2
|
+
from typing import Optional
|
3
|
+
import os
|
4
|
+
|
5
|
+
|
6
|
+
class FireworksCredentials(BaseModel):
    """Holds the API key used to authenticate against Fireworks AI."""

    # Required, non-empty secret; masked in repr/str below.
    fireworks_api_key: SecretStr = Field(..., min_length=1)

    def __repr__(self) -> str:
        """Return a fixed, masked representation that never contains the key."""
        return "FireworksCredentials(fireworks_api_key=SecretStr('**********'))"

    def __str__(self) -> str:
        """Return the same masked text as ``__repr__``."""
        return self.__repr__()

    @classmethod
    def from_env(cls) -> "FireworksCredentials":
        """Build credentials from the ``FIREWORKS_API_KEY`` environment variable.

        Raises:
            ValueError: If the variable is unset or empty.
        """
        api_key = os.getenv("FIREWORKS_API_KEY")
        if api_key:
            return cls(fireworks_api_key=api_key)

        raise ValueError("FIREWORKS_API_KEY environment variable not set")
|