airtrain 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/integrations/fireworks/__init__.py +11 -0
- airtrain/integrations/fireworks/credentials.py +18 -0
- airtrain/integrations/fireworks/models.py +27 -0
- airtrain/integrations/fireworks/skills.py +107 -0
- {airtrain-0.1.13.dist-info → airtrain-0.1.14.dist-info}/METADATA +1 -1
- {airtrain-0.1.13.dist-info → airtrain-0.1.14.dist-info}/RECORD +9 -5
- {airtrain-0.1.13.dist-info → airtrain-0.1.14.dist-info}/WHEEL +0 -0
- {airtrain-0.1.13.dist-info → airtrain-0.1.14.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -0,0 +1,11 @@
|
|
1
|
+
"""Fireworks AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import FireworksCredentials
|
4
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
5
|
+
|
6
|
+
# Public API of the Fireworks integration package: the credentials model,
# the chat skill, and the skill's input/output schemas.
__all__ = ["FireworksCredentials", "FireworksChatSkill", "FireworksInput", "FireworksOutput"]
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from pydantic import SecretStr, BaseModel
|
2
|
+
from typing import Optional
|
3
|
+
import os
|
4
|
+
|
5
|
+
|
6
|
+
class FireworksCredentials(BaseModel):
    """Credentials for Fireworks AI API.

    The key is stored as a pydantic ``SecretStr``; callers obtain the raw
    string with ``fireworks_api_key.get_secret_value()``.
    """

    fireworks_api_key: SecretStr

    @classmethod
    def from_env(cls) -> "FireworksCredentials":
        """Build credentials from the ``FIREWORKS_API_KEY`` environment variable.

        Returns:
            A ``FireworksCredentials`` instance wrapping the key.

        Raises:
            ValueError: If the variable is unset or set to an empty string.
        """
        key_from_env = os.environ.get("FIREWORKS_API_KEY")
        # Any non-empty value is accepted; unset and "" both fall through.
        if key_from_env:
            return cls(fireworks_api_key=key_from_env)
        raise ValueError("FIREWORKS_API_KEY environment variable not set")
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class FireworksMessage(BaseModel):
    """One message of a Fireworks chat conversation.

    ``role`` is required and must match the anchored pattern exactly, so only
    ``system``, ``user`` and ``assistant`` are accepted.
    """

    # Message text.
    content: str
    # Speaker of the message; ``...`` marks the field as required.
    role: str = Field(..., pattern="^(system|user|assistant)$")
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksUsage(BaseModel):
    """Token accounting from the ``usage`` object of a Fireworks API response."""

    # Tokens counted for the input messages.
    prompt_tokens: int
    # Tokens generated for the reply.
    completion_tokens: int
    # Overall token count as reported by the API.
    total_tokens: int
|
18
|
+
|
19
|
+
|
20
|
+
class FireworksResponse(BaseModel):
    """The fields of a Fireworks chat completion response that are modelled here."""

    # Identifier of the completion as returned by the API.
    id: str
    # Raw choice objects; consumers read ``choices[i]["message"]["content"]``.
    choices: List[Dict[str, Any]]
    # Creation time — presumably a Unix timestamp in seconds; TODO confirm.
    created: int
    # Model name echoed back by the API.
    model: str
    # Token usage for the request.
    usage: FireworksUsage
|
@@ -0,0 +1,107 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import FireworksCredentials
|
9
|
+
from .models import FireworksMessage, FireworksResponse
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksInput(InputSchema):
    """Schema for Fireworks AI chat input.

    ``system_prompt`` and ``user_input`` become a two-message conversation;
    the remaining fields are sent as request parameters of the chat
    completions call.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    # The Fireworks chat completions API accepts temperatures in [0, 2]; an
    # upper bound of 1 would reject valid requests. Every value in [0, 1]
    # remains valid, so existing callers are unaffected.
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=2
    )
    # ``None`` means "do not send max_tokens", leaving the API default in place.
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum tokens in response"
    )
    # Forwarded verbatim; presumably the API accepts "truncate" or "error" —
    # verify against the Fireworks API reference before relying on others.
    context_length_exceeded_behavior: str = Field(
        default="truncate", description="Behavior when context length is exceeded"
    )
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksOutput(OutputSchema):
    """Schema for Fireworks AI output.

    Holds the generated text, the model name and token usage for one call.
    """

    # Generated text (FireworksChatSkill fills it from the first choice).
    response: str = Field(..., description="Model's response text")
    # Model identifier as reported by the API.
    used_model: str = Field(..., description="Model used for generation")
    # Token counts keyed by name; a fresh empty dict when not supplied.
    usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
|
41
|
+
|
42
|
+
|
43
|
+
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
    """Skill for interacting with Fireworks AI models.

    Sends one system message and one user message to the Fireworks
    ``/chat/completions`` endpoint. It returns the first choice's text, the
    model name reported by the API, and the token usage.
    """

    input_schema = FireworksInput
    output_schema = FireworksOutput

    def __init__(
        self,
        credentials: Optional[FireworksCredentials] = None,
        timeout: Optional[float] = 300.0,
    ):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
            timeout: Seconds to wait for the HTTP request before failing.
                The default is generous because reasoning models (such as the
                default deepseek-r1) can be slow. ``None`` waits indefinitely.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.base_url = "https://api.fireworks.ai/inference/v1"
        self.timeout = timeout

    def _build_payload(self, input_data: FireworksInput) -> Dict[str, Any]:
        """Build the JSON body for a single non-streaming chat completion."""
        payload: Dict[str, Any] = {
            "messages": [
                {"role": "system", "content": input_data.system_prompt},
                {"role": "user", "content": input_data.user_input},
            ],
            "model": input_data.model,
            "context_length_exceeded_behavior": input_data.context_length_exceeded_behavior,
            "temperature": input_data.temperature,
            "n": 1,
            "response_format": {"type": "text"},
            "stream": False,
        }
        # Compare against None so an explicit 0 is forwarded. A truthiness
        # check would silently drop it, and the request would then run
        # without any token limit.
        if input_data.max_tokens is not None:
            payload["max_tokens"] = input_data.max_tokens
        return payload

    def _build_headers(self) -> Dict[str, str]:
        """Return the bearer-token authorization and JSON content-type headers."""
        api_key = self.credentials.fireworks_api_key.get_secret_value()
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _describe_error(error: Exception) -> str:
        """Return a message for ``error``, adding the HTTP response body if present."""
        message = str(error)
        # The API explains 4xx/5xx failures in the body; raise_for_status()
        # alone only reports the status line. Truncate in case a proxy
        # returns a large HTML page.
        if isinstance(error, requests.HTTPError) and error.response is not None:
            body = error.response.text
            if body:
                message = f"{message} - {body[:1000]}"
        return message

    def process(self, input_data: FireworksInput) -> FireworksOutput:
        """Process the input using Fireworks AI API.

        Args:
            input_data: Prompt, model and sampling settings for the request.

        Returns:
            A ``FireworksOutput`` with the first choice's message content, the
            model reported by the API, and prompt/completion/total token counts.

        Raises:
            ProcessingError: On any failure. This includes a network error or
                timeout, a non-2xx status (the response body is included), a
                response that does not match ``FireworksResponse``, or a
                response with no choices. The underlying exception is chained
                as ``__cause__``.
        """
        try:
            logger.info(f"Processing request with model {input_data.model}")

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=self._build_payload(input_data),
                headers=self._build_headers(),
                timeout=self.timeout,
            )
            response.raise_for_status()
            response_data = FireworksResponse(**response.json())

            # Fail with a clear message instead of an opaque IndexError below.
            if not response_data.choices:
                raise ValueError("Fireworks AI returned no choices")

            logger.success("Successfully processed Fireworks AI request")

            return FireworksOutput(
                response=response_data.choices[0]["message"]["content"],
                used_model=response_data.model,
                usage={
                    "prompt_tokens": response_data.usage.prompt_tokens,
                    "completion_tokens": response_data.usage.completion_tokens,
                    "total_tokens": response_data.usage.total_tokens,
                },
            )

        except Exception as e:
            detail = self._describe_error(e)
            logger.exception(f"Fireworks AI processing failed: {detail}")
            raise ProcessingError(f"Fireworks AI processing failed: {detail}") from e
|
@@ -1,4 +1,4 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=FRI-wOLxVAu3ECtPJGQ8ZLjDSM1PgoiDGvn8ctNtS_8,2095
|
2
2
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
3
3
|
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
4
4
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
@@ -17,6 +17,10 @@ airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGt
|
|
17
17
|
airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
|
18
18
|
airtrain/integrations/cerebras/credentials.py,sha256=IFkn8LxMAaOpvEWXDpb94VQGtqcDxQ7rZHKH-tX4Nuw,884
|
19
19
|
airtrain/integrations/cerebras/skills.py,sha256=O9vwFzvv_tUOwFOVE8CszAQEac711eVYVUj_8dVMTpc,1596
|
20
|
+
airtrain/integrations/fireworks/__init__.py,sha256=9pJvP0u1FJbNtB0oHa09mHVJLctELf_c27LOYyDk2ZI,271
|
21
|
+
airtrain/integrations/fireworks/credentials.py,sha256=UpcwR9V5Hbk5sJbjFDJDbHMRqc90IQSqAvrtJCOvwEo,524
|
22
|
+
airtrain/integrations/fireworks/models.py,sha256=F-MddbLCLAsTjwRr1l6IpJxOegyY4pD7jN9ySPiypSo,593
|
23
|
+
airtrain/integrations/fireworks/skills.py,sha256=ZykowW8lMbTcZVJ0GO2Ut6E-u2-keXvE4F-_j-3JI4k,4074
|
20
24
|
airtrain/integrations/google/__init__.py,sha256=INZFNOcNebz3m-Ggk07ZjmX0kNHIbTe_St9gBlZBki8,176
|
21
25
|
airtrain/integrations/google/credentials.py,sha256=Mm4jNWF02rIf0_GuHLcUUPyLHC4NMRdF_iTCoVTQ0Bs,1033
|
22
26
|
airtrain/integrations/google/skills.py,sha256=uwmgetl5Ien7fLOA5HIZdqoL6AZnexFDyzfsrGuJ1RU,1606
|
@@ -47,7 +51,7 @@ airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9
|
|
47
51
|
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
48
52
|
airtrain/integrations/together/skills.py,sha256=UfLHnseZbA7R7q5dDco6mpV546Zfd3DTliZSrNkCL6Q,4518
|
49
53
|
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
50
|
-
airtrain-0.1.
|
51
|
-
airtrain-0.1.
|
52
|
-
airtrain-0.1.
|
53
|
-
airtrain-0.1.
|
54
|
+
airtrain-0.1.14.dist-info/METADATA,sha256=l4IPKLJ7Bf3gmZYSRPVEfz4oe1XGt_lWfvuZg68cNnE,4536
|
55
|
+
airtrain-0.1.14.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
56
|
+
airtrain-0.1.14.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
57
|
+
airtrain-0.1.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|