airtrain 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/contrib/travel/__init__.py +5 -5
- airtrain/integrations/fireworks/__init__.py +11 -0
- airtrain/integrations/fireworks/credentials.py +18 -0
- airtrain/integrations/fireworks/models.py +27 -0
- airtrain/integrations/fireworks/skills.py +107 -0
- airtrain/integrations/openai/models_config.py +119 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +3 -3
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +171 -0
- airtrain/integrations/together/models.py +56 -0
- airtrain/integrations/together/models_config.py +277 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +87 -1
- airtrain/integrations/together/vision_models_config.py +49 -0
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/METADATA +1 -1
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/RECORD +23 -8
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/WHEEL +0 -0
- {airtrain-0.1.12.dist-info → airtrain-0.1.14.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,13 @@
|
|
1
|
-
from typing import Optional, Dict, Any
|
1
|
+
from typing import Optional, Dict, Any, List
|
2
2
|
from pydantic import Field
|
3
3
|
from airtrain.core.skills import Skill, ProcessingError
|
4
4
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
5
|
from .credentials import TogetherAICredentials
|
6
|
+
from .models import TogetherAIImageInput, TogetherAIImageOutput, GeneratedImage
|
7
|
+
from pathlib import Path
|
8
|
+
import base64
|
9
|
+
import time
|
10
|
+
from together import Together
|
6
11
|
|
7
12
|
|
8
13
|
class TogetherAIInput(InputSchema):
|
@@ -41,3 +46,84 @@ class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
|
|
41
46
|
|
42
47
|
def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
    """Placeholder for Together AI chat processing.

    Raises:
        NotImplementedError: Unconditionally; chat support is not built yet.
    """
    message = "TogetherAIChatSkill is not implemented yet"
    raise NotImplementedError(message)
|
49
|
+
|
50
|
+
|
51
|
+
class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
    """Skill for Together AI image generation.

    Sends requests through the ``together`` SDK client's ``images.generate``
    call and converts the response into a ``TogetherAIImageOutput``.
    """

    input_schema = TogetherAIImageInput
    output_schema = TogetherAIImageOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """Create the skill and its Together API client.

        Args:
            credentials: Together AI credentials. When omitted (or falsy),
                they are loaded via ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    def process(self, input_data: TogetherAIImageInput) -> TogetherAIImageOutput:
        """Generate images for ``input_data``.

        Args:
            input_data: Prompt, model, and generation parameters
                (steps, n, size, negative_prompt, seed).

        Returns:
            TogetherAIImageOutput: The generated images (base64 payloads plus
            optional seed / finish_reason), the model and prompt used, the
            elapsed request time in seconds, and the response's ``usage``
            attribute (``{}`` when the response has none).

        Raises:
            ProcessingError: If the API call or the output conversion fails.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by wall-clock adjustments the way time.time() can.
            start_time = time.perf_counter()

            response = self.client.images.generate(
                prompt=input_data.prompt,
                model=input_data.model,
                steps=input_data.steps,
                n=input_data.n,
                size=input_data.size,
                negative_prompt=input_data.negative_prompt,
                seed=input_data.seed,
            )

            total_time = time.perf_counter() - start_time

            # seed / finish_reason are read with getattr so an item lacking
            # them yields None instead of raising AttributeError.
            generated_images = [
                GeneratedImage(
                    b64_json=img.b64_json,
                    seed=getattr(img, "seed", None),
                    finish_reason=getattr(img, "finish_reason", None),
                )
                for img in response.data
            ]

            return TogetherAIImageOutput(
                images=generated_images,
                model=input_data.model,
                prompt=input_data.prompt,
                total_time=total_time,
                usage=getattr(response, "usage", {}),
            )

        except Exception as e:
            # Chain the original error so its traceback is preserved for callers.
            raise ProcessingError(
                f"Together AI image generation failed: {str(e)}"
            ) from e

    def save_images(
        self, output: TogetherAIImageOutput, output_dir: Path
    ) -> List[Path]:
        """
        Save generated images to disk.

        Files are named ``image_0.png``, ``image_1.png``, ... in the order of
        ``output.images``; existing files with those names are overwritten.

        Args:
            output (TogetherAIImageOutput): Generation output containing images
                with base64-encoded ``b64_json`` payloads.
            output_dir (Path): Directory to save images. It is created
                (including parents) if missing. A plain ``str`` also works,
                since the value is wrapped in ``Path``.

        Returns:
            List[Path]: List of paths to saved images, in output order.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        saved_paths = []
        for i, img in enumerate(output.images):
            # NOTE(review): the .png extension is assumed regardless of the
            # actual encoded format returned by the API — confirm.
            output_path = output_dir / f"image_{i}.png"
            output_path.write_bytes(base64.b64decode(img.b64_json))
            saved_paths.append(output_path)

        return saved_paths
|
@@ -0,0 +1,49 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
VisionModelConfig = NamedTuple(
    "VisionModelConfig",
    [
        ("organization", str),  # publishing organization, e.g. "Meta"
        ("display_name", str),  # human-readable model name
        ("context_length", int),  # context window size (presumably tokens)
    ],
)
VisionModelConfig.__doc__ = (
    "Immutable description of a Together AI vision model: "
    "organization, display_name and context_length."
)
|
8
|
+
|
9
|
+
|
10
|
+
# Registry of Together AI vision models, keyed by the model ID sent to the API.
# Each value is VisionModelConfig(organization, display_name, context_length).
TOGETHER_VISION_MODELS: Dict[str, VisionModelConfig] = {
    "meta-llama/Llama-Vision-Free": VisionModelConfig(
        "Meta", "(Free) Llama 3.2 11B Vision Instruct Turbo", 131072
    ),
    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo": VisionModelConfig(
        "Meta", "Llama 3.2 11B Vision Instruct Turbo", 131072
    ),
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo": VisionModelConfig(
        "Meta", "Llama 3.2 90B Vision Instruct Turbo", 131072
    ),
}
|
27
|
+
|
28
|
+
|
29
|
+
def get_vision_model_config(model_id: str) -> VisionModelConfig:
    """Look up the configuration of a Together AI vision model.

    Args:
        model_id: Model ID as registered in ``TOGETHER_VISION_MODELS``.

    Returns:
        VisionModelConfig: The registered configuration for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a registered vision model.
    """
    config = TOGETHER_VISION_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI vision models")
    return config
|
34
|
+
|
35
|
+
|
36
|
+
def list_vision_models_by_organization(
    organization: str,
) -> Dict[str, VisionModelConfig]:
    """Return the registered vision models published by ``organization``.

    The organization name is compared case-insensitively (both sides are
    lower-cased).

    Args:
        organization: Organization name to filter by, e.g. ``"meta"``.

    Returns:
        Dict[str, VisionModelConfig]: Matching model IDs mapped to their
        configs, in registry order. Empty if nothing matches.
    """
    wanted = organization.lower()
    matches: Dict[str, VisionModelConfig] = {}
    for model_id, config in TOGETHER_VISION_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
45
|
+
|
46
|
+
|
47
|
+
def get_default_vision_model() -> str:
    """Return the default Together AI vision model ID.

    Returns:
        str: ``"meta-llama/Llama-Vision-Free"``.
    """
    default_model_id = "meta-llama/Llama-Vision-Free"
    return default_model_id
|
@@ -1,6 +1,6 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=FRI-wOLxVAu3ECtPJGQ8ZLjDSM1PgoiDGvn8ctNtS_8,2095
|
2
2
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
3
|
-
airtrain/contrib/travel/__init__.py,sha256=
|
3
|
+
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
4
4
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
5
5
|
airtrain/contrib/travel/models.py,sha256=E4Mtds5TINmXLXu65aTYqv6wwOh2KclJHZ2eXRrBqiY,1547
|
6
6
|
airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
|
@@ -17,6 +17,10 @@ airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGt
|
|
17
17
|
airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
|
18
18
|
airtrain/integrations/cerebras/credentials.py,sha256=IFkn8LxMAaOpvEWXDpb94VQGtqcDxQ7rZHKH-tX4Nuw,884
|
19
19
|
airtrain/integrations/cerebras/skills.py,sha256=O9vwFzvv_tUOwFOVE8CszAQEac711eVYVUj_8dVMTpc,1596
|
20
|
+
airtrain/integrations/fireworks/__init__.py,sha256=9pJvP0u1FJbNtB0oHa09mHVJLctELf_c27LOYyDk2ZI,271
|
21
|
+
airtrain/integrations/fireworks/credentials.py,sha256=UpcwR9V5Hbk5sJbjFDJDbHMRqc90IQSqAvrtJCOvwEo,524
|
22
|
+
airtrain/integrations/fireworks/models.py,sha256=F-MddbLCLAsTjwRr1l6IpJxOegyY4pD7jN9ySPiypSo,593
|
23
|
+
airtrain/integrations/fireworks/skills.py,sha256=ZykowW8lMbTcZVJ0GO2Ut6E-u2-keXvE4F-_j-3JI4k,4074
|
20
24
|
airtrain/integrations/google/__init__.py,sha256=INZFNOcNebz3m-Ggk07ZjmX0kNHIbTe_St9gBlZBki8,176
|
21
25
|
airtrain/integrations/google/credentials.py,sha256=Mm4jNWF02rIf0_GuHLcUUPyLHC4NMRdF_iTCoVTQ0Bs,1033
|
22
26
|
airtrain/integrations/google/skills.py,sha256=uwmgetl5Ien7fLOA5HIZdqoL6AZnexFDyzfsrGuJ1RU,1606
|
@@ -29,14 +33,25 @@ airtrain/integrations/ollama/skills.py,sha256=M_Un8D5VJ5XtPEq9IClzqV3jCPBoFTSm2v
|
|
29
33
|
airtrain/integrations/openai/__init__.py,sha256=K-NY2_T1T6SEOgkpbUA55cWvK2nr2NOJgLCqmmtaCno,371
|
30
34
|
airtrain/integrations/openai/chinese_assistant.py,sha256=MMhv4NBOoEQ0O22ZZtP255rd5ajHC9l6FPWIjpqxBOA,1581
|
31
35
|
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
36
|
+
airtrain/integrations/openai/models_config.py,sha256=bzosqqpDy2AJxu2vGdk2H4voqEGlv7LORR6fpJLhNic,3962
|
32
37
|
airtrain/integrations/openai/skills.py,sha256=Olg9-6f_p2XgkVwwcB9tvjAMApmM2EK81i8LP4qVVvs,7676
|
33
38
|
airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
|
34
39
|
airtrain/integrations/sambanova/credentials.py,sha256=U36RAEIPNuwo-vTrt3U9kkkj2GfdqSclA1ttOYHxS-w,784
|
35
40
|
airtrain/integrations/sambanova/skills.py,sha256=Po1ur_QFwzVIugbkk2mt73WdXDz_Gr9ASlUc9Y12Kok,1614
|
36
41
|
airtrain/integrations/together/__init__.py,sha256=we4KXn_pUs6Dxo3QcB-t40BSRraQFdKg2nXw7yi2FjM,185
|
37
|
-
airtrain/integrations/together/
|
38
|
-
airtrain/integrations/together/
|
39
|
-
airtrain
|
40
|
-
airtrain
|
41
|
-
airtrain
|
42
|
-
airtrain
|
42
|
+
airtrain/integrations/together/audio_models_config.py,sha256=GtqfmKR1vJ5x4B3kScvEO3x4exvzwNP78vcGVTk_fBE,1004
|
43
|
+
airtrain/integrations/together/credentials.py,sha256=cYNhyIwgsxm8LfiFfT-omBvgV3mUP6SZeRSukyzzDlI,747
|
44
|
+
airtrain/integrations/together/embedding_models_config.py,sha256=F0ISAXCG_Pcnf-ojkvZwIXacXD8LaU8hQmGHCFzmlds,2927
|
45
|
+
airtrain/integrations/together/image_models_config.py,sha256=JlCozrphI9zE4uYpGfj4DCWSN6GZGyr84Tb1HmjNQ28,2455
|
46
|
+
airtrain/integrations/together/image_skill.py,sha256=FMO9-TRwfucLRlpvij9VpeVsrTiP9wux_je_pFz6OXs,6508
|
47
|
+
airtrain/integrations/together/models.py,sha256=ZW5xfEN9fU18aeltb-sB2O-Bnu5sLkDPZqvUtxgoH-U,2112
|
48
|
+
airtrain/integrations/together/models_config.py,sha256=XMKp0Oq1nWWnMMdNAZxkFXmJaURwWrwLE18kFXsMsRw,8829
|
49
|
+
airtrain/integrations/together/rerank_models_config.py,sha256=coCg0IOG2tU4L2uc2uPtPdoBwGjSc_zQwxENwdDuwHE,1188
|
50
|
+
airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9gWzm80WgiJNjkA9eg,1894
|
51
|
+
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
52
|
+
airtrain/integrations/together/skills.py,sha256=UfLHnseZbA7R7q5dDco6mpV546Zfd3DTliZSrNkCL6Q,4518
|
53
|
+
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
54
|
+
airtrain-0.1.14.dist-info/METADATA,sha256=l4IPKLJ7Bf3gmZYSRPVEfz4oe1XGt_lWfvuZg68cNnE,4536
|
55
|
+
airtrain-0.1.14.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
56
|
+
airtrain-0.1.14.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
57
|
+
airtrain-0.1.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|