airtrain 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/contrib/travel/__init__.py +5 -5
- airtrain/integrations/fireworks/__init__.py +11 -0
- airtrain/integrations/fireworks/credentials.py +18 -0
- airtrain/integrations/fireworks/models.py +27 -0
- airtrain/integrations/fireworks/skills.py +107 -0
- airtrain/integrations/openai/models_config.py +119 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +3 -3
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +171 -0
- airtrain/integrations/together/models.py +56 -0
- airtrain/integrations/together/models_config.py +277 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +87 -1
- airtrain/integrations/together/vision_models_config.py +49 -0
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/METADATA +1 -1
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/RECORD +23 -8
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/WHEEL +0 -0
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -4,9 +4,9 @@ from .agents import (
|
|
4
4
|
TravelAgentBase,
|
5
5
|
ClothingAgent,
|
6
6
|
HikingAgent,
|
7
|
-
|
8
|
-
|
9
|
-
|
7
|
+
InternetAgent,
|
8
|
+
FoodAgent,
|
9
|
+
PersonalizedAgent,
|
10
10
|
)
|
11
11
|
from .models import (
|
12
12
|
ClothingRecommendation,
|
@@ -14,8 +14,8 @@ from .models import (
|
|
14
14
|
InternetAvailability,
|
15
15
|
FoodOption,
|
16
16
|
)
|
17
|
-
from .
|
18
|
-
from .
|
17
|
+
from .agentlib.verification_agent import UserVerificationAgent
|
18
|
+
from .modellib.verification import UserTravelInfo, TravelCompanion, HealthCondition
|
19
19
|
|
20
20
|
__all__ = [
|
21
21
|
"TravelAgentBase",
|
@@ -0,0 +1,11 @@
|
|
1
|
+
"""Fireworks AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import FireworksCredentials
|
4
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
5
|
+
|
6
|
+
__all__ = [
|
7
|
+
"FireworksCredentials",
|
8
|
+
"FireworksChatSkill",
|
9
|
+
"FireworksInput",
|
10
|
+
"FireworksOutput",
|
11
|
+
]
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from pydantic import SecretStr, BaseModel
|
2
|
+
from typing import Optional
|
3
|
+
import os
|
4
|
+
|
5
|
+
|
6
|
+
class FireworksCredentials(BaseModel):
|
7
|
+
"""Credentials for Fireworks AI API"""
|
8
|
+
|
9
|
+
fireworks_api_key: SecretStr
|
10
|
+
|
11
|
+
@classmethod
|
12
|
+
def from_env(cls) -> "FireworksCredentials":
|
13
|
+
"""Create credentials from environment variables"""
|
14
|
+
api_key = os.getenv("FIREWORKS_API_KEY")
|
15
|
+
if not api_key:
|
16
|
+
raise ValueError("FIREWORKS_API_KEY environment variable not set")
|
17
|
+
|
18
|
+
return cls(fireworks_api_key=api_key)
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class FireworksMessage(BaseModel):
|
6
|
+
"""Schema for Fireworks chat message"""
|
7
|
+
|
8
|
+
content: str
|
9
|
+
role: str = Field(..., pattern="^(system|user|assistant)$")
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksUsage(BaseModel):
|
13
|
+
"""Schema for Fireworks API usage statistics"""
|
14
|
+
|
15
|
+
prompt_tokens: int
|
16
|
+
completion_tokens: int
|
17
|
+
total_tokens: int
|
18
|
+
|
19
|
+
|
20
|
+
class FireworksResponse(BaseModel):
|
21
|
+
"""Schema for Fireworks API response"""
|
22
|
+
|
23
|
+
id: str
|
24
|
+
choices: List[Dict[str, Any]]
|
25
|
+
created: int
|
26
|
+
model: str
|
27
|
+
usage: FireworksUsage
|
@@ -0,0 +1,107 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
import requests
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import FireworksCredentials
|
9
|
+
from .models import FireworksMessage, FireworksResponse
|
10
|
+
|
11
|
+
|
12
|
+
class FireworksInput(InputSchema):
|
13
|
+
"""Schema for Fireworks AI chat input"""
|
14
|
+
|
15
|
+
user_input: str = Field(..., description="User's input text")
|
16
|
+
system_prompt: str = Field(
|
17
|
+
default="You are a helpful assistant.",
|
18
|
+
description="System prompt to guide the model's behavior",
|
19
|
+
)
|
20
|
+
model: str = Field(
|
21
|
+
default="accounts/fireworks/models/deepseek-r1",
|
22
|
+
description="Fireworks AI model to use",
|
23
|
+
)
|
24
|
+
temperature: float = Field(
|
25
|
+
default=0.7, description="Temperature for response generation", ge=0, le=1
|
26
|
+
)
|
27
|
+
max_tokens: Optional[int] = Field(
|
28
|
+
default=None, description="Maximum tokens in response"
|
29
|
+
)
|
30
|
+
context_length_exceeded_behavior: str = Field(
|
31
|
+
default="truncate", description="Behavior when context length is exceeded"
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
class FireworksOutput(OutputSchema):
|
36
|
+
"""Schema for Fireworks AI output"""
|
37
|
+
|
38
|
+
response: str = Field(..., description="Model's response text")
|
39
|
+
used_model: str = Field(..., description="Model used for generation")
|
40
|
+
usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
|
41
|
+
|
42
|
+
|
43
|
+
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
44
|
+
"""Skill for interacting with Fireworks AI models"""
|
45
|
+
|
46
|
+
input_schema = FireworksInput
|
47
|
+
output_schema = FireworksOutput
|
48
|
+
|
49
|
+
def __init__(self, credentials: Optional[FireworksCredentials] = None):
|
50
|
+
"""Initialize the skill with optional credentials"""
|
51
|
+
super().__init__()
|
52
|
+
self.credentials = credentials or FireworksCredentials.from_env()
|
53
|
+
self.base_url = "https://api.fireworks.ai/inference/v1"
|
54
|
+
|
55
|
+
def process(self, input_data: FireworksInput) -> FireworksOutput:
|
56
|
+
"""Process the input using Fireworks AI API"""
|
57
|
+
try:
|
58
|
+
logger.info(f"Processing request with model {input_data.model}")
|
59
|
+
|
60
|
+
# Prepare messages
|
61
|
+
messages = [
|
62
|
+
{"role": "system", "content": input_data.system_prompt},
|
63
|
+
{"role": "user", "content": input_data.user_input},
|
64
|
+
]
|
65
|
+
|
66
|
+
# Prepare request payload
|
67
|
+
payload = {
|
68
|
+
"messages": messages,
|
69
|
+
"model": input_data.model,
|
70
|
+
"context_length_exceeded_behavior": input_data.context_length_exceeded_behavior,
|
71
|
+
"temperature": input_data.temperature,
|
72
|
+
"n": 1,
|
73
|
+
"response_format": {"type": "text"},
|
74
|
+
"stream": False,
|
75
|
+
}
|
76
|
+
|
77
|
+
if input_data.max_tokens:
|
78
|
+
payload["max_tokens"] = input_data.max_tokens
|
79
|
+
|
80
|
+
# Make API request
|
81
|
+
response = requests.post(
|
82
|
+
f"{self.base_url}/chat/completions",
|
83
|
+
json=payload,
|
84
|
+
headers={
|
85
|
+
"Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
|
86
|
+
"Content-Type": "application/json",
|
87
|
+
},
|
88
|
+
)
|
89
|
+
|
90
|
+
response.raise_for_status()
|
91
|
+
response_data = FireworksResponse(**response.json())
|
92
|
+
|
93
|
+
logger.success("Successfully processed Fireworks AI request")
|
94
|
+
|
95
|
+
return FireworksOutput(
|
96
|
+
response=response_data.choices[0]["message"]["content"],
|
97
|
+
used_model=response_data.model,
|
98
|
+
usage={
|
99
|
+
"prompt_tokens": response_data.usage.prompt_tokens,
|
100
|
+
"completion_tokens": response_data.usage.completion_tokens,
|
101
|
+
"total_tokens": response_data.usage.total_tokens,
|
102
|
+
},
|
103
|
+
)
|
104
|
+
|
105
|
+
except Exception as e:
|
106
|
+
logger.exception(f"Fireworks AI processing failed: {str(e)}")
|
107
|
+
raise ProcessingError(f"Fireworks AI processing failed: {str(e)}")
|
@@ -0,0 +1,119 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
from decimal import Decimal
|
3
|
+
|
4
|
+
|
5
|
+
class OpenAIModelConfig(NamedTuple):
|
6
|
+
display_name: str
|
7
|
+
base_model: str
|
8
|
+
input_price: Decimal
|
9
|
+
cached_input_price: Optional[Decimal]
|
10
|
+
output_price: Decimal
|
11
|
+
|
12
|
+
|
13
|
+
OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
|
14
|
+
"gpt-4o": OpenAIModelConfig(
|
15
|
+
display_name="GPT-4 Optimized",
|
16
|
+
base_model="gpt-4o",
|
17
|
+
input_price=Decimal("2.50"),
|
18
|
+
cached_input_price=Decimal("1.25"),
|
19
|
+
output_price=Decimal("10.00"),
|
20
|
+
),
|
21
|
+
"gpt-4o-2024-08-06": OpenAIModelConfig(
|
22
|
+
display_name="GPT-4 Optimized (2024-08-06)",
|
23
|
+
base_model="gpt-4o",
|
24
|
+
input_price=Decimal("2.50"),
|
25
|
+
cached_input_price=Decimal("1.25"),
|
26
|
+
output_price=Decimal("10.00"),
|
27
|
+
),
|
28
|
+
"gpt-4o-2024-05-13": OpenAIModelConfig(
|
29
|
+
display_name="GPT-4 Optimized (2024-05-13)",
|
30
|
+
base_model="gpt-4o",
|
31
|
+
input_price=Decimal("5.00"),
|
32
|
+
cached_input_price=None,
|
33
|
+
output_price=Decimal("15.00"),
|
34
|
+
),
|
35
|
+
"gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
|
36
|
+
display_name="GPT-4 Optimized Audio Preview",
|
37
|
+
base_model="gpt-4o-audio-preview",
|
38
|
+
input_price=Decimal("2.50"),
|
39
|
+
cached_input_price=None,
|
40
|
+
output_price=Decimal("10.00"),
|
41
|
+
),
|
42
|
+
"gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
|
43
|
+
display_name="GPT-4 Optimized Realtime Preview",
|
44
|
+
base_model="gpt-4o-realtime-preview",
|
45
|
+
input_price=Decimal("5.00"),
|
46
|
+
cached_input_price=Decimal("2.50"),
|
47
|
+
output_price=Decimal("20.00"),
|
48
|
+
),
|
49
|
+
"gpt-4o-mini-2024-07-18": OpenAIModelConfig(
|
50
|
+
display_name="GPT-4 Optimized Mini",
|
51
|
+
base_model="gpt-4o-mini",
|
52
|
+
input_price=Decimal("0.15"),
|
53
|
+
cached_input_price=Decimal("0.075"),
|
54
|
+
output_price=Decimal("0.60"),
|
55
|
+
),
|
56
|
+
"gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
|
57
|
+
display_name="GPT-4 Optimized Mini Audio Preview",
|
58
|
+
base_model="gpt-4o-mini-audio-preview",
|
59
|
+
input_price=Decimal("0.15"),
|
60
|
+
cached_input_price=None,
|
61
|
+
output_price=Decimal("0.60"),
|
62
|
+
),
|
63
|
+
"gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
|
64
|
+
display_name="GPT-4 Optimized Mini Realtime Preview",
|
65
|
+
base_model="gpt-4o-mini-realtime-preview",
|
66
|
+
input_price=Decimal("0.60"),
|
67
|
+
cached_input_price=Decimal("0.30"),
|
68
|
+
output_price=Decimal("2.40"),
|
69
|
+
),
|
70
|
+
"o1-2024-12-17": OpenAIModelConfig(
|
71
|
+
display_name="O1",
|
72
|
+
base_model="o1",
|
73
|
+
input_price=Decimal("15.00"),
|
74
|
+
cached_input_price=Decimal("7.50"),
|
75
|
+
output_price=Decimal("60.00"),
|
76
|
+
),
|
77
|
+
"o3-mini-2025-01-31": OpenAIModelConfig(
|
78
|
+
display_name="O3 Mini",
|
79
|
+
base_model="o3-mini",
|
80
|
+
input_price=Decimal("1.10"),
|
81
|
+
cached_input_price=Decimal("0.55"),
|
82
|
+
output_price=Decimal("4.40"),
|
83
|
+
),
|
84
|
+
"o1-mini-2024-09-12": OpenAIModelConfig(
|
85
|
+
display_name="O1 Mini",
|
86
|
+
base_model="o1-mini",
|
87
|
+
input_price=Decimal("1.10"),
|
88
|
+
cached_input_price=Decimal("0.55"),
|
89
|
+
output_price=Decimal("4.40"),
|
90
|
+
),
|
91
|
+
}
|
92
|
+
|
93
|
+
|
94
|
+
def get_model_config(model_id: str) -> OpenAIModelConfig:
|
95
|
+
"""Get model configuration by model ID"""
|
96
|
+
if model_id not in OPENAI_MODELS:
|
97
|
+
raise ValueError(f"Model {model_id} not found in OpenAI models")
|
98
|
+
return OPENAI_MODELS[model_id]
|
99
|
+
|
100
|
+
|
101
|
+
def get_default_model() -> str:
|
102
|
+
"""Get the default model ID"""
|
103
|
+
return "gpt-4o"
|
104
|
+
|
105
|
+
|
106
|
+
def calculate_cost(
|
107
|
+
model_id: str, input_tokens: int, output_tokens: int, use_cached: bool = False
|
108
|
+
) -> Decimal:
|
109
|
+
"""Calculate cost for token usage"""
|
110
|
+
config = get_model_config(model_id)
|
111
|
+
input_price = (
|
112
|
+
config.cached_input_price
|
113
|
+
if (use_cached and config.cached_input_price is not None)
|
114
|
+
else config.input_price
|
115
|
+
)
|
116
|
+
return (
|
117
|
+
input_price * Decimal(str(input_tokens))
|
118
|
+
+ config.output_price * Decimal(str(output_tokens))
|
119
|
+
) / Decimal("1000")
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
class AudioModelConfig(NamedTuple):
|
5
|
+
organization: str
|
6
|
+
display_name: str
|
7
|
+
|
8
|
+
|
9
|
+
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
|
10
|
+
"Cartesia/Sonic": AudioModelConfig(
|
11
|
+
organization="Cartesia", display_name="Cartesia/Sonic"
|
12
|
+
)
|
13
|
+
}
|
14
|
+
|
15
|
+
|
16
|
+
def get_audio_model_config(model_id: str) -> AudioModelConfig:
|
17
|
+
"""Get audio model configuration by model ID"""
|
18
|
+
if model_id not in TOGETHER_AUDIO_MODELS:
|
19
|
+
raise ValueError(f"Model {model_id} not found in Together AI audio models")
|
20
|
+
return TOGETHER_AUDIO_MODELS[model_id]
|
21
|
+
|
22
|
+
|
23
|
+
def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
|
24
|
+
"""Get all audio models for a specific organization"""
|
25
|
+
return {
|
26
|
+
model_id: config
|
27
|
+
for model_id, config in TOGETHER_AUDIO_MODELS.items()
|
28
|
+
if config.organization.lower() == organization.lower()
|
29
|
+
}
|
30
|
+
|
31
|
+
|
32
|
+
def get_default_audio_model() -> str:
|
33
|
+
"""Get the default audio model ID"""
|
34
|
+
return "Cartesia/Sonic"
|
@@ -6,14 +6,14 @@ import together
|
|
6
6
|
class TogetherAICredentials(BaseCredentials):
|
7
7
|
"""Together AI credentials"""
|
8
8
|
|
9
|
-
|
9
|
+
together_api_key: SecretStr = Field(..., description="Together AI API key")
|
10
10
|
|
11
|
-
_required_credentials = {"
|
11
|
+
_required_credentials = {"together_api_key"}
|
12
12
|
|
13
13
|
async def validate_credentials(self) -> bool:
|
14
14
|
"""Validate Together AI credentials"""
|
15
15
|
try:
|
16
|
-
together.api_key = self.
|
16
|
+
together.api_key = self.together_api_key.get_secret_value()
|
17
17
|
await together.Models.list()
|
18
18
|
return True
|
19
19
|
except Exception as e:
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
class EmbeddingModelConfig(NamedTuple):
|
5
|
+
organization: str
|
6
|
+
display_name: str
|
7
|
+
model_size: str
|
8
|
+
embedding_dimension: int
|
9
|
+
context_window: int
|
10
|
+
|
11
|
+
|
12
|
+
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
|
13
|
+
"togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
|
14
|
+
organization="Together",
|
15
|
+
display_name="M2-BERT-80M-2K-Retrieval",
|
16
|
+
model_size="80M",
|
17
|
+
embedding_dimension=768,
|
18
|
+
context_window=2048,
|
19
|
+
),
|
20
|
+
"togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
|
21
|
+
organization="Together",
|
22
|
+
display_name="M2-BERT-80M-8K-Retrieval",
|
23
|
+
model_size="80M",
|
24
|
+
embedding_dimension=768,
|
25
|
+
context_window=8192,
|
26
|
+
),
|
27
|
+
"togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
|
28
|
+
organization="Together",
|
29
|
+
display_name="M2-BERT-80M-32K-Retrieval",
|
30
|
+
model_size="80M",
|
31
|
+
embedding_dimension=768,
|
32
|
+
context_window=32768,
|
33
|
+
),
|
34
|
+
"WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
|
35
|
+
organization="WhereIsAI",
|
36
|
+
display_name="UAE-Large-v1",
|
37
|
+
model_size="326M",
|
38
|
+
embedding_dimension=1024,
|
39
|
+
context_window=512,
|
40
|
+
),
|
41
|
+
"BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
|
42
|
+
organization="BAAI",
|
43
|
+
display_name="BGE-Large-EN-v1.5",
|
44
|
+
model_size="326M",
|
45
|
+
embedding_dimension=1024,
|
46
|
+
context_window=512,
|
47
|
+
),
|
48
|
+
"BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
|
49
|
+
organization="BAAI",
|
50
|
+
display_name="BGE-Base-EN-v1.5",
|
51
|
+
model_size="102M",
|
52
|
+
embedding_dimension=768,
|
53
|
+
context_window=512,
|
54
|
+
),
|
55
|
+
"sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
|
56
|
+
organization="sentence-transformers",
|
57
|
+
display_name="Sentence-BERT",
|
58
|
+
model_size="110M",
|
59
|
+
embedding_dimension=768,
|
60
|
+
context_window=512,
|
61
|
+
),
|
62
|
+
"bert-base-uncased": EmbeddingModelConfig(
|
63
|
+
organization="Hugging Face",
|
64
|
+
display_name="BERT",
|
65
|
+
model_size="110M",
|
66
|
+
embedding_dimension=768,
|
67
|
+
context_window=512,
|
68
|
+
),
|
69
|
+
}
|
70
|
+
|
71
|
+
|
72
|
+
def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
|
73
|
+
"""Get embedding model configuration by model ID"""
|
74
|
+
if model_id not in TOGETHER_EMBEDDING_MODELS:
|
75
|
+
raise ValueError(f"Model {model_id} not found in Together AI embedding models")
|
76
|
+
return TOGETHER_EMBEDDING_MODELS[model_id]
|
77
|
+
|
78
|
+
|
79
|
+
def list_embedding_models_by_organization(
|
80
|
+
organization: str,
|
81
|
+
) -> Dict[str, EmbeddingModelConfig]:
|
82
|
+
"""Get all embedding models for a specific organization"""
|
83
|
+
return {
|
84
|
+
model_id: config
|
85
|
+
for model_id, config in TOGETHER_EMBEDDING_MODELS.items()
|
86
|
+
if config.organization.lower() == organization.lower()
|
87
|
+
}
|
88
|
+
|
89
|
+
|
90
|
+
def get_default_embedding_model() -> str:
|
91
|
+
"""Get the default embedding model ID"""
|
92
|
+
return "togethercomputer/m2-bert-80M-32k-retrieval"
|
@@ -0,0 +1,69 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
|
3
|
+
|
4
|
+
class ImageModelConfig(NamedTuple):
|
5
|
+
organization: str
|
6
|
+
display_name: str
|
7
|
+
default_steps: Optional[int]
|
8
|
+
|
9
|
+
|
10
|
+
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
|
11
|
+
"black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
|
12
|
+
organization="Black Forest Labs",
|
13
|
+
display_name="Flux.1 [schnell] (free)",
|
14
|
+
default_steps=None,
|
15
|
+
),
|
16
|
+
"black-forest-labs/FLUX.1-schnell": ImageModelConfig(
|
17
|
+
organization="Black Forest Labs",
|
18
|
+
display_name="Flux.1 [schnell] (Turbo)",
|
19
|
+
default_steps=4,
|
20
|
+
),
|
21
|
+
"black-forest-labs/FLUX.1-dev": ImageModelConfig(
|
22
|
+
organization="Black Forest Labs", display_name="Flux.1 Dev", default_steps=28
|
23
|
+
),
|
24
|
+
"black-forest-labs/FLUX.1-canny": ImageModelConfig(
|
25
|
+
organization="Black Forest Labs", display_name="Flux.1 Canny", default_steps=28
|
26
|
+
),
|
27
|
+
"black-forest-labs/FLUX.1-depth": ImageModelConfig(
|
28
|
+
organization="Black Forest Labs", display_name="Flux.1 Depth", default_steps=28
|
29
|
+
),
|
30
|
+
"black-forest-labs/FLUX.1-redux": ImageModelConfig(
|
31
|
+
organization="Black Forest Labs", display_name="Flux.1 Redux", default_steps=28
|
32
|
+
),
|
33
|
+
"black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
|
34
|
+
organization="Black Forest Labs",
|
35
|
+
display_name="Flux1.1 [pro]",
|
36
|
+
default_steps=None,
|
37
|
+
),
|
38
|
+
"black-forest-labs/FLUX.1-pro": ImageModelConfig(
|
39
|
+
organization="Black Forest Labs",
|
40
|
+
display_name="Flux.1 [pro]",
|
41
|
+
default_steps=None,
|
42
|
+
),
|
43
|
+
"stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
|
44
|
+
organization="Stability AI",
|
45
|
+
display_name="Stable Diffusion XL 1.0",
|
46
|
+
default_steps=None,
|
47
|
+
),
|
48
|
+
}
|
49
|
+
|
50
|
+
|
51
|
+
def get_image_model_config(model_id: str) -> ImageModelConfig:
|
52
|
+
"""Get image model configuration by model ID"""
|
53
|
+
if model_id not in TOGETHER_IMAGE_MODELS:
|
54
|
+
raise ValueError(f"Model {model_id} not found in Together AI image models")
|
55
|
+
return TOGETHER_IMAGE_MODELS[model_id]
|
56
|
+
|
57
|
+
|
58
|
+
def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
|
59
|
+
"""Get all image models for a specific organization"""
|
60
|
+
return {
|
61
|
+
model_id: config
|
62
|
+
for model_id, config in TOGETHER_IMAGE_MODELS.items()
|
63
|
+
if config.organization.lower() == organization.lower()
|
64
|
+
}
|
65
|
+
|
66
|
+
|
67
|
+
def get_default_image_model() -> str:
|
68
|
+
"""Get the default image model ID"""
|
69
|
+
return "black-forest-labs/FLUX.1-schnell-Free"
|