airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/combined/groq_fireworks_skills.py
@@ -0,0 +1,126 @@
+ from typing import Optional, Dict, Any, List
+ from pydantic import Field
+ import requests
+ from groq import Groq
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from airtrain.integrations.fireworks.completion_skills import (
+     FireworksCompletionSkill,
+     FireworksCompletionInput,
+ )
+
+
+ class GroqFireworksInput(InputSchema):
+     """Schema for combined Groq and Fireworks input"""
+
+     user_input: str = Field(..., description="User's input text")
+     groq_model: str = Field(
+         default="mixtral-8x7b-32768", description="Groq model to use"
+     )
+     fireworks_model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation"
+     )
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+
+
+ class GroqFireworksOutput(OutputSchema):
+     """Schema for combined Groq and Fireworks output"""
+
+     combined_response: str
+     groq_response: str
+     fireworks_response: str
+     used_models: Dict[str, str]
+     usage: Dict[str, Dict[str, int]]
+
+
+ class GroqFireworksSkill(Skill[GroqFireworksInput, GroqFireworksOutput]):
+     """Skill combining Groq and Fireworks responses"""
+
+     input_schema = GroqFireworksInput
+     output_schema = GroqFireworksOutput
+
+     def __init__(
+         self,
+         groq_api_key: Optional[str] = None,
+         fireworks_skill: Optional[FireworksCompletionSkill] = None,
+     ):
+         """Initialize the skill with optional API keys"""
+         super().__init__()
+         self.groq_client = Groq(api_key=groq_api_key)
+         self.fireworks_skill = fireworks_skill or FireworksCompletionSkill()
+
+     def _get_groq_response(self, input_data: GroqFireworksInput) -> Dict[str, Any]:
+         """Get response from Groq"""
+         try:
+             completion = self.groq_client.chat.completions.create(
+                 model=input_data.groq_model,
+                 messages=[{"role": "user", "content": input_data.user_input}],
+                 temperature=input_data.temperature,
+                 max_tokens=input_data.max_tokens,
+             )
+             return {
+                 "response": completion.choices[0].message.content,
+                 "usage": completion.usage.model_dump(),
+             }
+         except Exception as e:
+             raise ProcessingError(f"Groq request failed: {str(e)}")
+
+     def _get_fireworks_response(
+         self, groq_response: str, input_data: GroqFireworksInput
+     ) -> Dict[str, Any]:
+         """Get response from Fireworks"""
+         try:
+             formatted_prompt = (
+                 f"<USER>{input_data.user_input}</USER>\n<ASSISTANT>{groq_response}"
+             )
+
+             fireworks_input = FireworksCompletionInput(
+                 prompt=formatted_prompt,
+                 model=input_data.fireworks_model,
+                 temperature=input_data.temperature,
+                 max_tokens=input_data.max_tokens,
+             )
+
+             result = self.fireworks_skill.process(fireworks_input)
+             return {"response": result.response, "usage": result.usage}
+         except Exception as e:
+             raise ProcessingError(f"Fireworks request failed: {str(e)}")
+
+     def process(self, input_data: GroqFireworksInput) -> GroqFireworksOutput:
+         """Process the input using both Groq and Fireworks"""
+         try:
+             # Get Groq response
+             groq_result = self._get_groq_response(input_data)
+
+             # Get Fireworks response
+             fireworks_result = self._get_fireworks_response(
+                 groq_result["response"], input_data
+             )
+
+             # Combine responses in the required format
+             combined_response = (
+                 f"<USER>{input_data.user_input}</USER>\n"
+                 f"<ASSISTANT>{groq_result['response']} {fireworks_result['response']}"
+             )
+
+             return GroqFireworksOutput(
+                 combined_response=combined_response,
+                 groq_response=groq_result["response"],
+                 fireworks_response=fireworks_result["response"],
+                 used_models={
+                     "groq": input_data.groq_model,
+                     "fireworks": input_data.fireworks_model,
+                 },
+                 usage={
+                     "groq": groq_result["usage"],
+                     "fireworks": fireworks_result["usage"],
+                 },
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Combined processing failed: {str(e)}")
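For context, a minimal usage sketch of the new combined skill (illustrative only, not part of the published files; it assumes GROQ_API_KEY and FIREWORKS_API_KEY are exported so both clients can authenticate, and the small max_tokens value is just to keep the call cheap):

# Hypothetical usage of GroqFireworksSkill; GROQ_API_KEY and FIREWORKS_API_KEY
# are assumed to be set in the environment.
import os

from airtrain.integrations.combined.groq_fireworks_skills import (
    GroqFireworksInput,
    GroqFireworksSkill,
)

skill = GroqFireworksSkill(groq_api_key=os.environ["GROQ_API_KEY"])
result = skill.process(
    GroqFireworksInput(user_input="Summarize this package diff.", max_tokens=1024)
)
print(result.used_models)        # {"groq": ..., "fireworks": ...}
print(result.combined_response)  # <USER>...</USER>\n<ASSISTANT>...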
airtrain/integrations/combined/list_models_factory.py
@@ -0,0 +1,210 @@
+ from typing import Optional, Dict, Any, List
+ from pydantic import Field
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from airtrain.core.credentials import BaseCredentials
+
+ # Import existing list models skills
+ from airtrain.integrations.openai.list_models import OpenAIListModelsSkill
+ from airtrain.integrations.anthropic.list_models import AnthropicListModelsSkill
+ from airtrain.integrations.together.list_models import TogetherListModelsSkill
+ from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
+
+ # Import credentials
+ from airtrain.integrations.groq.credentials import GroqCredentials
+ from airtrain.integrations.cerebras.credentials import CerebrasCredentials
+ from airtrain.integrations.sambanova.credentials import SambanovaCredentials
+ from airtrain.integrations.perplexity.credentials import PerplexityCredentials
+
+ # Import Perplexity list models
+ from airtrain.integrations.perplexity.list_models import PerplexityListModelsSkill
+
+
+ # Generic list models input schema
+ class GenericListModelsInput(InputSchema):
+     """Generic schema for listing models from any provider"""
+
+     api_models_only: bool = Field(
+         default=False,
+         description=(
+             "If True, fetch models from the API only. If False, use local config."
+         ),
+     )
+
+     class Config:
+         arbitrary_types_allowed = True
+         extra = "allow"
+
+
+ # Generic list models output schema
+ class GenericListModelsOutput(OutputSchema):
+     """Generic schema for list models output from any provider"""
+
+     models: List[Dict[str, Any]] = Field(
+         default_factory=list, description="List of models"
+     )
+     provider: str = Field(..., description="Provider name")
+
+
+ # Base class for stub implementations
+ class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
+     """Base skill for listing models"""
+
+     input_schema = GenericListModelsInput
+     output_schema = GenericListModelsOutput
+
+     def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
+         """Initialize the skill with provider name and optional credentials"""
+         super().__init__()
+         self.provider = provider
+         self.credentials = credentials
+
+     def get_models(self) -> List[Dict[str, Any]]:
+         """Return list of models. To be implemented by subclasses."""
+         raise NotImplementedError("Subclasses must implement get_models()")
+
+     def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
+         """Process the input and return a list of models."""
+         try:
+             models = self.get_models()
+             return GenericListModelsOutput(models=models, provider=self.provider)
+         except Exception as e:
+             raise ProcessingError(f"Failed to list {self.provider} models: {str(e)}")
+
+
+ # Groq implementation
+ class GroqListModelsSkill(BaseListModelsSkill):
+     """Skill for listing Groq models"""
+
+     def __init__(self, credentials: Optional[GroqCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__(provider="groq", credentials=credentials)
+
+     def get_models(self) -> List[Dict[str, Any]]:
+         """Return list of Groq models."""
+         # Default Groq models from trmx_agent config
+         models = [
+             {
+                 "id": "llama-3.3-70b-versatile",
+                 "display_name": "Llama 3.3 70B Versatile (Tool Use)",
+             },
+             {
+                 "id": "llama-3.1-8b-instant",
+                 "display_name": "Llama 3.1 8B Instant (Tool Use)",
+             },
+             {
+                 "id": "mixtral-8x7b-32768",
+                 "display_name": "Mixtral 8x7B (32K) (Tool Use)",
+             },
+             {"id": "gemma2-9b-it", "display_name": "Gemma 2 9B IT (Tool Use)"},
+             {"id": "qwen-qwq-32b", "display_name": "Qwen QWQ 32B (Tool Use)"},
+             {
+                 "id": "qwen-2.5-coder-32b",
+                 "display_name": "Qwen 2.5 Coder 32B (Tool Use)",
+             },
+             {"id": "qwen-2.5-32b", "display_name": "Qwen 2.5 32B (Tool Use)"},
+             {
+                 "id": "deepseek-r1-distill-qwen-32b",
+                 "display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)",
+             },
+             {
+                 "id": "deepseek-r1-distill-llama-70b",
+                 "display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)",
+             },
+         ]
+         return models
+
+
+ # Cerebras implementation
+ class CerebrasListModelsSkill(BaseListModelsSkill):
+     """Skill for listing Cerebras models"""
+
+     def __init__(self, credentials: Optional[CerebrasCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__(provider="cerebras", credentials=credentials)
+
+     def get_models(self) -> List[Dict[str, Any]]:
+         """Return list of Cerebras models."""
+         # Default Cerebras models from trmx_agent config
+         models = [
+             {
+                 "id": "cerebras/Cerebras-GPT-13B-v0.1",
+                 "display_name": "Cerebras GPT 13B v0.1",
+             },
+             {
+                 "id": "cerebras/Cerebras-GPT-111M-v0.9",
+                 "display_name": "Cerebras GPT 111M v0.9",
+             },
+             {
+                 "id": "cerebras/Cerebras-GPT-590M-v0.7",
+                 "display_name": "Cerebras GPT 590M v0.7",
+             },
+         ]
+         return models
+
+
+ # Sambanova implementation
+ class SambanovaListModelsSkill(BaseListModelsSkill):
+     """Skill for listing Sambanova models"""
+
+     def __init__(self, credentials: Optional[SambanovaCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__(provider="sambanova", credentials=credentials)
+
+     def get_models(self) -> List[Dict[str, Any]]:
+         """Return list of Sambanova models."""
+         # Limited Sambanova model information
+         models = [
+             {"id": "sambanova/samba-1", "display_name": "Samba-1"},
+             {"id": "sambanova/samba-2", "display_name": "Samba-2"},
+         ]
+         return models
+
+
+ # Factory class
+ class ListModelsSkillFactory:
+     """Factory for creating list models skills for different providers"""
+
+     # Map provider names to their corresponding list models skills
+     _PROVIDER_MAP = {
+         "openai": OpenAIListModelsSkill,
+         "anthropic": AnthropicListModelsSkill,
+         "together": TogetherListModelsSkill,
+         "fireworks": FireworksListModelsSkill,
+         "groq": GroqListModelsSkill,
+         "cerebras": CerebrasListModelsSkill,
+         "sambanova": SambanovaListModelsSkill,
+         "perplexity": PerplexityListModelsSkill,
+     }
+
+     @classmethod
+     def get_skill(cls, provider: str, credentials=None):
+         """Return a list models skill for the specified provider
+
+         Args:
+             provider (str): The provider name (case-insensitive)
+             credentials: Optional credentials for the provider
+
+         Returns:
+             A ListModelsSkill instance for the specified provider
+
+         Raises:
+             ValueError: If the provider is not supported
+         """
+         provider = provider.lower()
+
+         if provider not in cls._PROVIDER_MAP:
+             supported = ", ".join(cls.get_supported_providers())
+             raise ValueError(
+                 f"Unsupported provider: {provider}. "
+                 f"Supported providers are: {supported}"
+             )
+
+         skill_class = cls._PROVIDER_MAP[provider]
+         return skill_class(credentials=credentials)
+
+     @classmethod
+     def get_supported_providers(cls):
+         """Return a list of supported provider names"""
+         return list(cls._PROVIDER_MAP.keys())
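A minimal sketch of how the factory could be used (illustrative only, not part of the published files; the Groq entry is one of the stub implementations above and needs no credentials):

# Hypothetical usage of ListModelsSkillFactory with the generic schemas above.
from airtrain.integrations.combined.list_models_factory import (
    GenericListModelsInput,
    ListModelsSkillFactory,
)

print(ListModelsSkillFactory.get_supported_providers())
skill = ListModelsSkillFactory.get_skill("groq")  # stub list, no credentials required
output = skill.process(GenericListModelsInput())
for model in output.models:
    print(model["id"], "-", model["display_name"])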
airtrain/integrations/fireworks/__init__.py
@@ -0,0 +1,21 @@
+ """Fireworks AI integration module"""
+
+ from .credentials import FireworksCredentials
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
+ from .list_models import (
+     FireworksListModelsSkill,
+     FireworksListModelsInput,
+     FireworksListModelsOutput,
+ )
+ from .models import FireworksModel
+
+ __all__ = [
+     "FireworksCredentials",
+     "FireworksChatSkill",
+     "FireworksInput",
+     "FireworksOutput",
+     "FireworksListModelsSkill",
+     "FireworksListModelsInput",
+     "FireworksListModelsOutput",
+     "FireworksModel",
+ ]
airtrain/integrations/fireworks/completion_skills.py
@@ -0,0 +1,147 @@
+ from typing import List, Optional, Dict, Any, Generator, Union
+ from pydantic import Field
+ import requests
+ import json
+ from loguru import logger
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+
+
+ class FireworksCompletionInput(InputSchema):
+     """Schema for Fireworks AI completion input using requests"""
+
+     prompt: str = Field(..., description="Input prompt for completion")
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     top_p: float = Field(
+         default=1.0, description="Top p sampling parameter", ge=0, le=1
+     )
+     top_k: int = Field(default=50, description="Top k sampling parameter", ge=0)
+     presence_penalty: float = Field(
+         default=0.0, description="Presence penalty", ge=-2.0, le=2.0
+     )
+     frequency_penalty: float = Field(
+         default=0.0, description="Frequency penalty", ge=-2.0, le=2.0
+     )
+     repetition_penalty: float = Field(
+         default=1.0, description="Repetition penalty", ge=0.0
+     )
+     stop: Optional[Union[str, List[str]]] = Field(
+         default=None, description="Stop sequences"
+     )
+     echo: bool = Field(default=False, description="Echo the prompt in the response")
+     stream: bool = Field(default=False, description="Whether to stream the response")
+
+
+ class FireworksCompletionOutput(OutputSchema):
+     """Schema for Fireworks AI completion output"""
+
+     response: str
+     used_model: str
+     usage: Dict[str, int]
+
+
+ class FireworksCompletionSkill(
+     Skill[FireworksCompletionInput, FireworksCompletionOutput]
+ ):
+     """Skill for text completion using Fireworks AI"""
+
+     input_schema = FireworksCompletionInput
+     output_schema = FireworksCompletionOutput
+     BASE_URL = "https://api.fireworks.ai/inference/v1/completions"
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.headers = {
+             "Accept": "application/json",
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+         }
+
+     def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
+         """Build the request payload."""
+         payload = {
+             "model": input_data.model,
+             "prompt": input_data.prompt,
+             "max_tokens": input_data.max_tokens,
+             "temperature": input_data.temperature,
+             "top_p": input_data.top_p,
+             "top_k": input_data.top_k,
+             "presence_penalty": input_data.presence_penalty,
+             "frequency_penalty": input_data.frequency_penalty,
+             "repetition_penalty": input_data.repetition_penalty,
+             "echo": input_data.echo,
+             "stream": input_data.stream,
+         }
+
+         if input_data.stop:
+             payload["stop"] = input_data.stop
+
+         return payload
+
+     def process_stream(
+         self, input_data: FireworksCompletionInput
+     ) -> Generator[str, None, None]:
+         """Process the input and stream the response."""
+         try:
+             payload = self._build_payload(input_data)
+             response = requests.post(
+                 self.BASE_URL,
+                 headers=self.headers,
+                 data=json.dumps(payload),
+                 stream=True,
+             )
+             response.raise_for_status()
+
+             for line in response.iter_lines():
+                 if line:
+                     try:
+                         data = json.loads(line.decode("utf-8").removeprefix("data: "))
+                         if data.get("choices") and data["choices"][0].get("text"):
+                             yield data["choices"][0]["text"]
+                     except json.JSONDecodeError:
+                         continue
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks completion streaming failed: {str(e)}")
+
+     def process(
+         self, input_data: FireworksCompletionInput
+     ) -> FireworksCompletionOutput:
+         """Process the input and return completion response."""
+         try:
+             if input_data.stream:
+                 # For streaming, collect the entire response
+                 response_chunks = []
+                 for chunk in self.process_stream(input_data):
+                     response_chunks.append(chunk)
+                 response_text = "".join(response_chunks)
+                 usage = {}  # Usage stats not available in streaming mode
+             else:
+                 # For non-streaming, use regular request
+                 payload = self._build_payload(input_data)
+                 response = requests.post(
+                     self.BASE_URL, headers=self.headers, data=json.dumps(payload)
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 response_text = data["choices"][0]["text"]
+                 usage = data["usage"]
+
+             return FireworksCompletionOutput(
+                 response=response_text, used_model=input_data.model, usage=usage
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Fireworks completion failed: {str(e)}")
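A minimal sketch of calling the completion skill directly (illustrative only, not part of the published files; FIREWORKS_API_KEY must be set so FireworksCredentials.from_env() succeeds, and the small max_tokens value is just to keep the example cheap):

# Hypothetical usage of FireworksCompletionSkill against the completions endpoint.
from airtrain.integrations.fireworks.completion_skills import (
    FireworksCompletionInput,
    FireworksCompletionSkill,
)

skill = FireworksCompletionSkill()  # reads FIREWORKS_API_KEY from the environment
output = skill.process(
    FireworksCompletionInput(prompt="Q: What does airtrain do?\nA:", max_tokens=256)
)
print(output.used_model)
print(output.response)
print(output.usage)  # prompt/completion/total token counts for non-streaming calls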
airtrain/integrations/fireworks/conversation_manager.py
@@ -0,0 +1,109 @@
+ from typing import List, Dict, Optional
+ from pydantic import BaseModel, Field
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
+
+ # TODO: Test this thing.
+
+
+ class ConversationState(BaseModel):
+     """Model to track conversation state"""
+
+     messages: List[Dict[str, str]] = Field(
+         default_factory=list, description="List of conversation messages"
+     )
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt for the conversation",
+     )
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Model being used for the conversation",
+     )
+     temperature: float = Field(default=0.7, description="Temperature setting")
+     max_tokens: Optional[int] = Field(default=131072, description="Max tokens setting")
+
+
+ class FireworksConversationManager:
+     """Manager for handling conversation state with Fireworks AI"""
+
+     def __init__(
+         self,
+         skill: Optional[FireworksChatSkill] = None,
+         system_prompt: str = "You are a helpful assistant.",
+         model: str = "accounts/fireworks/models/deepseek-r1",
+         temperature: float = 0.7,
+         max_tokens: Optional[int] = None,
+     ):
+         """
+         Initialize conversation manager.
+
+         Args:
+             skill: FireworksChatSkill instance (creates new one if None)
+             system_prompt: Initial system prompt
+             model: Model to use
+             temperature: Temperature setting
+             max_tokens: Max tokens setting
+         """
+         self.skill = skill or FireworksChatSkill()
+         self.state = ConversationState(
+             system_prompt=system_prompt,
+             model=model,
+             temperature=temperature,
+             max_tokens=max_tokens,
+         )
+
+     def send_message(self, user_input: str) -> FireworksOutput:
+         """
+         Send a message and get response while maintaining conversation history.
+
+         Args:
+             user_input: User's message
+
+         Returns:
+             FireworksOutput: Model's response
+         """
+         # Create input with current conversation state
+         input_data = FireworksInput(
+             user_input=user_input,
+             system_prompt=self.state.system_prompt,
+             conversation_history=self.state.messages,
+             model=self.state.model,
+             temperature=self.state.temperature,
+             max_tokens=self.state.max_tokens,
+         )
+
+         # Get response
+         result = self.skill.process(input_data)
+
+         # Update conversation history
+         self.state.messages.extend(
+             [
+                 {"role": "user", "content": user_input},
+                 {"role": "assistant", "content": result.response},
+             ]
+         )
+
+         return result
+
+     def reset_conversation(self) -> None:
+         """Reset the conversation history while maintaining other settings"""
+         self.state.messages = []
+
+     def get_conversation_history(self) -> List[Dict[str, str]]:
+         """Get the current conversation history"""
+         return self.state.messages.copy()
+
+     def update_system_prompt(self, new_prompt: str) -> None:
+         """Update the system prompt for future messages"""
+         self.state.system_prompt = new_prompt
+
+     def save_state(self, file_path: str) -> None:
+         """Save conversation state to a file"""
+         with open(file_path, "w") as f:
+             f.write(self.state.model_dump_json(indent=2))
+
+     def load_state(self, file_path: str) -> None:
+         """Load conversation state from a file"""
+         with open(file_path, "r") as f:
+             data = f.read()
+             self.state = ConversationState.model_validate_json(data)
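A minimal sketch of a multi-turn exchange with the conversation manager (illustrative only, not part of the published files; the save path is hypothetical and the underlying FireworksChatSkill needs FIREWORKS_API_KEY):

# Hypothetical round-trip with FireworksConversationManager.
from airtrain.integrations.fireworks.conversation_manager import (
    FireworksConversationManager,
)

manager = FireworksConversationManager(system_prompt="You are a terse assistant.")
manager.send_message("List three Fireworks models.")
manager.send_message("Which of those is best suited to code generation?")
print(len(manager.get_conversation_history()))  # 4: two user turns, two assistant turns
manager.save_state("conversation_state.json")   # hypothetical path
manager.reset_conversation()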
airtrain/integrations/fireworks/credentials.py
@@ -0,0 +1,26 @@
+ from pydantic import SecretStr, BaseModel, Field
+ from typing import Optional
+ import os
+
+
+ class FireworksCredentials(BaseModel):
+     """Credentials for Fireworks AI API"""
+
+     fireworks_api_key: SecretStr = Field(..., min_length=1)
+
+     def __repr__(self) -> str:
+         """Return a string representation of the credentials."""
+         return f"FireworksCredentials(fireworks_api_key=SecretStr('**********'))"
+
+     def __str__(self) -> str:
+         """Return a string representation of the credentials."""
+         return self.__repr__()
+
+     @classmethod
+     def from_env(cls) -> "FireworksCredentials":
+         """Create credentials from environment variables"""
+         api_key = os.getenv("FIREWORKS_API_KEY")
+         if not api_key:
+             raise ValueError("FIREWORKS_API_KEY environment variable not set")
+
+         return cls(fireworks_api_key=api_key)
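And a minimal sketch of the credentials helper (illustrative only, not part of the published files):

# Hypothetical: build credentials from the environment and unwrap the secret.
from airtrain.integrations.fireworks.credentials import FireworksCredentials

creds = FireworksCredentials.from_env()  # raises ValueError if FIREWORKS_API_KEY is unset
print(creds)  # masked: FireworksCredentials(fireworks_api_key=SecretStr('**********'))
token = creds.fireworks_api_key.get_secret_value()  # raw key, e.g. for an Authorization header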