airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +146 -6
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +62 -44
- airtrain/core/skills.py +102 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.3.dist-info/METADATA +0 -106
- airtrain-0.1.3.dist-info/RECORD +0 -9
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, Union
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
import json
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import FireworksCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksCompletionInput(InputSchema):
    """Input parameters for a Fireworks AI text-completion request.

    Numeric bounds declared with ``ge``/``le`` are enforced by pydantic
    when the model is constructed.
    """

    # Required prompt text.
    prompt: str = Field(default=..., description="Input prompt for completion")

    # Model selection and output length.
    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    max_tokens: int = Field(131072, description="Maximum tokens in response")

    # Sampling controls.
    temperature: float = Field(
        0.7, ge=0, le=1, description="Temperature for response generation"
    )
    top_p: float = Field(1.0, ge=0, le=1, description="Top p sampling parameter")
    top_k: int = Field(50, ge=0, description="Top k sampling parameter")

    # Penalties.
    presence_penalty: float = Field(
        0.0, ge=-2.0, le=2.0, description="Presence penalty"
    )
    frequency_penalty: float = Field(
        0.0, ge=-2.0, le=2.0, description="Frequency penalty"
    )
    repetition_penalty: float = Field(
        1.0, ge=0.0, description="Repetition penalty"
    )

    # Output handling: either a single stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = Field(
        None, description="Stop sequences"
    )
    echo: bool = Field(False, description="Echo the prompt in the response")
    stream: bool = Field(False, description="Whether to stream the response")
|
42
|
+
|
43
|
+
|
44
|
+
class FireworksCompletionOutput(OutputSchema):
    """Result of a Fireworks AI completion call."""

    # Completion text (chunks are concatenated when the call was streamed).
    response: str = Field(...)
    # Model identifier copied from the request input.
    used_model: str = Field(...)
    # Token usage as returned by the API; empty for streamed calls.
    usage: Dict[str, int] = Field(...)
|
50
|
+
|
51
|
+
|
52
|
+
class FireworksCompletionSkill(
    Skill[FireworksCompletionInput, FireworksCompletionOutput]
):
    """Skill for text completion using Fireworks AI.

    Calls the raw ``/inference/v1/completions`` REST endpoint with
    ``requests`` and supports both streaming and single-shot requests.
    """

    input_schema = FireworksCompletionInput
    output_schema = FireworksCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
        """Build the JSON request payload from the validated input."""
        payload = {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "max_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "repetition_penalty": input_data.repetition_penalty,
            "echo": input_data.echo,
            "stream": input_data.stream,
        }

        # Stop sequences are only sent when provided (None/empty are omitted).
        if input_data.stop:
            payload["stop"] = input_data.stop

        return payload

    def process_stream(
        self, input_data: FireworksCompletionInput
    ) -> Generator[str, None, None]:
        """Stream completion text from Fireworks AI.

        The request is always sent with ``"stream": true``, regardless of
        ``input_data.stream``. The body is parsed line by line as
        ``data: {...}`` events, which only works for a streamed response.

        Args:
            input_data: Validated completion parameters.

        Yields:
            Each non-empty ``choices[0].text`` fragment, in arrival order.

        Raises:
            ProcessingError: If the request fails or the stream cannot be read.
        """
        try:
            payload = self._build_payload(input_data)
            payload["stream"] = True
            # The context manager releases the streamed connection even when
            # the consumer stops iterating before the stream is exhausted.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue  # blank separator lines between events
                    event = raw_line.decode("utf-8").removeprefix("data: ")
                    if event.strip() == "[DONE]":
                        break  # end-of-stream sentinel
                    try:
                        data = json.loads(event)
                    except json.JSONDecodeError:
                        continue  # ignore non-JSON lines
                    choices = data.get("choices")
                    if choices and choices[0].get("text"):
                        yield choices[0]["text"]

        except Exception as e:
            raise ProcessingError(
                f"Fireworks completion streaming failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksCompletionInput
    ) -> FireworksCompletionOutput:
        """Run a completion and return the full response text.

        When ``input_data.stream`` is true, the streamed chunks are joined and
        ``usage`` is empty because usage stats are not collected from the
        stream. Otherwise a single JSON request is made, and its ``usage``
        block is returned (or ``{}`` if the API omits it).

        Args:
            input_data: Validated completion parameters.

        Returns:
            FireworksCompletionOutput with the text, model and usage.

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage = {}  # Usage stats not available in streaming mode
            else:
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["text"]
                # Tolerate a missing/null usage block, as in streaming mode.
                usage = data.get("usage") or {}

            return FireworksCompletionOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already wrapped by process_stream; avoid a doubled message prefix.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks completion failed: {str(e)}") from e
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from typing import List, Dict, Optional
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
4
|
+
|
5
|
+
# TODO: Test this thing.
|
6
|
+
|
7
|
+
|
8
|
+
class ConversationState(BaseModel):
    """Serializable snapshot of a chat conversation and its settings."""

    # Chat turns stored as {"role": ..., "content": ...} dicts.
    messages: List[Dict[str, str]] = Field(
        description="List of conversation messages", default_factory=list
    )
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt for the conversation",
    )
    model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Model being used for the conversation",
    )
    temperature: float = Field(0.7, description="Temperature setting")
    max_tokens: Optional[int] = Field(131072, description="Max tokens setting")
|
24
|
+
|
25
|
+
|
26
|
+
class FireworksConversationManager:
    """Manager for handling conversation state with Fireworks AI.

    Keeps the running message history and chat settings in a
    ``ConversationState``, which can be saved to and loaded from a JSON file.
    """

    def __init__(
        self,
        skill: Optional[FireworksChatSkill] = None,
        system_prompt: str = "You are a helpful assistant.",
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize conversation manager.

        Args:
            skill: FireworksChatSkill instance (creates new one if None)
            system_prompt: Initial system prompt
            model: Model to use
            temperature: Temperature setting
            max_tokens: Max tokens setting. Note that None is passed through
                explicitly, overriding ConversationState's own default.
                NOTE(review): confirm FireworksInput accepts None here.
        """
        self.skill = skill or FireworksChatSkill()
        self.state = ConversationState(
            system_prompt=system_prompt,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def send_message(self, user_input: str) -> FireworksOutput:
        """
        Send a message and get response while maintaining conversation history.

        The user/assistant turn is appended to the history only after the
        skill returns, so a failed request leaves the history unchanged.

        Args:
            user_input: User's message

        Returns:
            FireworksOutput: Model's response
        """
        # Create input with current conversation state. The history is copied
        # so the skill cannot mutate the stored messages.
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.state.system_prompt,
            conversation_history=[dict(m) for m in self.state.messages],
            model=self.state.model,
            temperature=self.state.temperature,
            max_tokens=self.state.max_tokens,
        )

        # Get response
        result = self.skill.process(input_data)

        # Update conversation history
        self.state.messages.extend(
            [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": result.response},
            ]
        )

        return result

    def reset_conversation(self) -> None:
        """Reset the conversation history while maintaining other settings"""
        self.state.messages = []

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Get a copy of the current conversation history.

        Each message dict is copied too, so callers can modify the result
        without affecting the stored state.
        """
        return [dict(message) for message in self.state.messages]

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future messages"""
        self.state.system_prompt = new_prompt

    def save_state(self, file_path: str) -> None:
        """Save conversation state to a file as UTF-8 JSON.

        UTF-8 is explicit because conversation text may be non-ASCII and the
        platform default encoding may not be able to represent it.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.state.model_dump_json(indent=2))

    def load_state(self, file_path: str) -> None:
        """Load conversation state from a UTF-8 JSON file written by save_state."""
        with open(file_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.state = ConversationState.model_validate_json(data)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from pydantic import SecretStr, BaseModel, Field
|
2
|
+
from typing import Optional
|
3
|
+
import os
|
4
|
+
|
5
|
+
|
6
|
+
class FireworksCredentials(BaseModel):
    """API-key credentials for Fireworks AI.

    The key is held as a ``SecretStr`` and is masked in both ``repr`` and
    ``str`` output.
    """

    fireworks_api_key: SecretStr = Field(default=..., min_length=1)

    def __repr__(self) -> str:
        """Return a masked representation that never exposes the key."""
        return "FireworksCredentials(fireworks_api_key=SecretStr('**********'))"

    def __str__(self) -> str:
        """Return the same masked text as ``repr``."""
        return repr(self)

    @classmethod
    def from_env(cls) -> "FireworksCredentials":
        """Build credentials from the ``FIREWORKS_API_KEY`` environment variable.

        Raises:
            ValueError: If the variable is unset or empty.
        """
        api_key = os.environ.get("FIREWORKS_API_KEY")
        if api_key:
            return cls(fireworks_api_key=api_key)
        raise ValueError("FIREWORKS_API_KEY environment variable not set")
|
@@ -0,0 +1,128 @@
|
|
1
|
+
from typing import Optional, List
|
2
|
+
import requests
|
3
|
+
from pydantic import Field
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import FireworksCredentials
|
8
|
+
from .models import FireworksModel
|
9
|
+
|
10
|
+
|
11
|
+
class FireworksListModelsInput(InputSchema):
    """Query parameters for listing the models of a Fireworks account."""

    account_id: str = Field(default=..., description="The Account Id")
    # NOTE(review): the description says values above 200 are coerced, but
    # ``le=200`` makes pydantic reject them before any request is sent.
    page_size: Optional[int] = Field(
        50,
        le=200,
        description=(
            "The maximum number of models to return. The maximum page_size is 200, "
            "values above 200 will be coerced to 200."
        ),
    )
    page_token: Optional[str] = Field(
        None,
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
    )
    filter: Optional[str] = Field(
        None,
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
    )
    order_by: Optional[str] = Field(
        None,
        description=(
            'A comma-separated list of fields to order by. e.g. "foo,bar" '
            'The default sort order is ascending. To specify a descending order for a '
            'field, append a " desc" suffix. e.g. "foo desc,bar" '
            'Subfields are specified with a "." character. e.g. "foo.bar" '
            'If not specified, the default order is by "name".'
        ),
    )
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksListModelsOutput(OutputSchema):
    """One page of results from the Fireworks list-models endpoint."""

    models: List[FireworksModel] = Field(
        description="List of Fireworks models", default_factory=list
    )
    # Can be passed back as ``page_token`` to request the following page.
    next_page_token: Optional[str] = Field(
        None, description="Token for retrieving the next page of results"
    )
    total_size: Optional[int] = Field(
        None, description="Total number of models available"
    )
|
65
|
+
|
66
|
+
|
67
|
+
class FireworksListModelsSkill(
    Skill[FireworksListModelsInput, FireworksListModelsOutput]
):
    """Skill for listing Fireworks AI models of an account."""

    input_schema = FireworksListModelsInput
    output_schema = FireworksListModelsOutput

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Fireworks credentials. When omitted they are loaded
                with ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.base_url = "https://api.fireworks.ai/v1"

    @staticmethod
    def _to_snake_case(key: str) -> str:
        """Convert a camelCase key (e.g. ``displayName``) to snake_case.

        Keys that are already snake_case come back unchanged.
        """
        return "".join(
            f"_{char.lower()}" if char.isupper() and index else char.lower()
            for index, char in enumerate(key)
        )

    @classmethod
    def _normalize_model_data(cls, model_data: dict) -> dict:
        """Rename an API model entry's top-level keys to FireworksModel fields.

        The response envelope is read with camelCase keys (``nextPageToken``,
        ``totalSize``), while FireworksModel declares snake_case fields with
        no aliases. Without this mapping, camelCase entry keys would not
        populate the model's fields. Nested values are left untouched.
        """
        return {cls._to_snake_case(key): value for key, value in model_data.items()}

    def process(
        self, input_data: FireworksListModelsInput
    ) -> FireworksListModelsOutput:
        """List the models of ``input_data.account_id``.

        Args:
            input_data: Account id plus optional paging, filter and ordering.

        Returns:
            FireworksListModelsOutput with the page of models and the
            pagination fields read from the response.

        Raises:
            ProcessingError: If the request fails or the response cannot be
                parsed into the output schema.
        """
        try:
            # Build the URL
            url = f"{self.base_url}/accounts/{input_data.account_id}/models"

            # Query parameters use the API's camelCase names; unset/falsy
            # values are omitted.
            params = {}
            if input_data.page_size:
                params["pageSize"] = input_data.page_size
            if input_data.page_token:
                params["pageToken"] = input_data.page_token
            if input_data.filter:
                params["filter"] = input_data.filter
            if input_data.order_by:
                params["orderBy"] = input_data.order_by

            headers = {
                "Authorization": (
                    f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
                )
            }

            response = requests.get(url, headers=headers, params=params)
            response.raise_for_status()
            result = response.json()

            models = [
                FireworksModel(**self._normalize_model_data(model_data))
                for model_data in result.get("models", [])
            ]

            return FireworksListModelsOutput(
                models=models,
                next_page_token=result.get("nextPageToken"),
                total_size=result.get("totalSize"),
            )

        except requests.RequestException as e:
            raise ProcessingError(f"Failed to list Fireworks models: {str(e)}") from e
        except Exception as e:
            raise ProcessingError(f"Error listing Fireworks models: {str(e)}") from e
|
@@ -0,0 +1,139 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class FireworksMessage(BaseModel):
    """One chat message sent to or received from Fireworks."""

    content: str = Field(...)
    # Validated against a regex that admits only the three chat roles.
    role: str = Field(default=..., pattern="^(system|user|assistant)$")
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksUsage(BaseModel):
    """Token counts reported by the Fireworks API for a request."""

    prompt_tokens: int = Field(...)
    completion_tokens: int = Field(...)
    total_tokens: int = Field(...)
|
18
|
+
|
19
|
+
|
20
|
+
class FireworksResponse(BaseModel):
    """Top-level body of a Fireworks API response."""

    id: str = Field(...)
    # Choice objects are kept as raw dicts; their contents are not validated.
    choices: List[Dict[str, Any]] = Field(...)
    created: int = Field(...)
    model: str = Field(...)
    usage: FireworksUsage = Field(...)
|
28
|
+
|
29
|
+
|
30
|
+
class FireworksModelStatus(BaseModel):
    """Placeholder schema for a Fireworks model's status.

    It declares no fields yet. FireworksModel stores ``status`` as a plain
    dict rather than using this class.
    """

    # TODO: add fields mirroring the API's status object.
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksModelBaseDetails(BaseModel):
    """Placeholder schema for a Fireworks model's base-model details.

    It declares no fields yet. FireworksModel stores ``base_model_details``
    as a plain dict rather than using this class.
    """

    # TODO: add fields mirroring the API's base-model details object.
|
38
|
+
|
39
|
+
|
40
|
+
class FireworksPeftDetails(BaseModel):
    """Placeholder schema for a Fireworks model's PEFT details.

    It declares no fields yet. FireworksModel stores ``peft_details`` as a
    plain dict rather than using this class.
    """

    # TODO: add fields mirroring the API's PEFT details object.
|
43
|
+
|
44
|
+
|
45
|
+
class FireworksConversationConfig(BaseModel):
    """Placeholder schema for a Fireworks model's conversation configuration.

    It declares no fields yet. FireworksModel stores ``conversation_config``
    as a plain dict rather than using this class.
    """

    # TODO: add fields mirroring the API's conversation config object.
|
48
|
+
|
49
|
+
|
50
|
+
class FireworksModelDeployedRef(BaseModel):
    """Placeholder schema for a reference to a deployed Fireworks model.

    It declares no fields yet. FireworksModel stores ``deployed_model_refs``
    as a list of plain dicts rather than using this class.
    """

    # TODO: add fields mirroring the API's deployed model reference object.
|
53
|
+
|
54
|
+
|
55
|
+
class FireworksDeprecationDate(BaseModel):
    """Placeholder schema for a Fireworks model's deprecation date.

    It declares no fields yet. FireworksModel stores ``deprecation_date`` as
    a plain dict rather than using this class.
    """

    # TODO: add fields mirroring the API's deprecation date object.
|
58
|
+
|
59
|
+
|
60
|
+
class FireworksModel(BaseModel):
    """Metadata for one model returned by the Fireworks models API.

    Only ``name`` is required; every other attribute defaults to ``None``.
    Nested API objects are kept as raw dicts instead of typed sub-models.
    """

    name: str = Field(...)
    display_name: Optional[str] = Field(default=None)
    description: Optional[str] = Field(default=None)
    create_time: Optional[str] = Field(default=None)
    created_by: Optional[str] = Field(default=None)
    state: Optional[str] = Field(default=None)
    status: Optional[Dict[str, Any]] = Field(default=None)
    kind: Optional[str] = Field(default=None)
    github_url: Optional[str] = Field(default=None)
    hugging_face_url: Optional[str] = Field(default=None)
    # Raw detail objects (not validated beyond being dicts).
    base_model_details: Optional[Dict[str, Any]] = Field(default=None)
    peft_details: Optional[Dict[str, Any]] = Field(default=None)
    teft_details: Optional[Dict[str, Any]] = Field(default=None)
    public: Optional[bool] = Field(default=None)
    conversation_config: Optional[Dict[str, Any]] = Field(default=None)
    # Capability flags and limits.
    context_length: Optional[int] = Field(default=None)
    supports_image_input: Optional[bool] = Field(default=None)
    supports_tools: Optional[bool] = Field(default=None)
    imported_from: Optional[str] = Field(default=None)
    fine_tuning_job: Optional[str] = Field(default=None)
    default_draft_model: Optional[str] = Field(default=None)
    default_draft_token_count: Optional[int] = Field(default=None)
    precisions: Optional[List[str]] = Field(default=None)
    deployed_model_refs: Optional[List[Dict[str, Any]]] = Field(default=None)
    cluster: Optional[str] = Field(default=None)
    deprecation_date: Optional[Dict[str, Any]] = Field(default=None)
    calibrated: Optional[bool] = Field(default=None)
    tunable: Optional[bool] = Field(default=None)
    supports_lora: Optional[bool] = Field(default=None)
    use_hf_apply_chat_template: Optional[bool] = Field(default=None)
|
93
|
+
|
94
|
+
|
95
|
+
class ListModelsInput(BaseModel):
    """Query parameters for listing the models of a Fireworks account."""

    account_id: str = Field(default=..., description="The Account Id")
    # NOTE(review): the description says values above 200 are coerced, but
    # ``le=200`` makes pydantic reject them at validation time.
    page_size: Optional[int] = Field(
        50,
        le=200,
        description=(
            "The maximum number of models to return. The maximum page_size is 200, "
            "values above 200 will be coerced to 200."
        ),
    )
    page_token: Optional[str] = Field(
        None,
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
    )
    filter: Optional[str] = Field(
        None,
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
    )
    order_by: Optional[str] = Field(
        None,
        description=(
            'A comma-separated list of fields to order by. e.g. "foo,bar" '
            'The default sort order is ascending. To specify a descending order for a '
            'field, append a " desc" suffix. e.g. "foo desc,bar" '
            'Subfields are specified with a "." character. e.g. "foo.bar" '
            'If not specified, the default order is by "name".'
        ),
    )
|
132
|
+
|
133
|
+
|
134
|
+
class ListModelsOutput(BaseModel):
    """A page of Fireworks models plus pagination metadata."""

    models: List[FireworksModel] = Field(...)
    next_page_token: Optional[str] = Field(default=None)
    total_size: Optional[int] = Field(default=None)
|