airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
|
|
1
|
+
"""
|
2
|
+
Skills for Exa Search API.
|
3
|
+
|
4
|
+
This module provides skills for using the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import json
|
8
|
+
import logging
|
9
|
+
import httpx
|
10
|
+
from typing import Optional, Dict, Any, List, cast
|
11
|
+
|
12
|
+
from pydantic import ValidationError
|
13
|
+
|
14
|
+
from airtrain.core.skills import Skill
|
15
|
+
from airtrain.core.errors import ProcessingError
|
16
|
+
from .credentials import ExaCredentials
|
17
|
+
from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
|
18
|
+
|
19
|
+
|
20
|
+
# Module-level logger; ExaSearchSkill.process() logs each failure here before
# raising it as a ProcessingError.
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API.

    Each call to ``process`` POSTs the input schema as JSON to
    ``EXA_API_ENDPOINT`` and maps the JSON reply onto ``ExaSearchOutputSchema``.
    """

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for API requests in seconds
            max_retries: Maximum number of retries for failed requests.
                NOTE(review): stored on the instance but not read by
                ``process``, which currently makes a single attempt.
            **kwargs: Passed through to the ``Skill`` base initializer.
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: If the request times out, the API returns a
                non-200 status, the response does not fit
                ``ExaSearchOutputSchema``, or any other error occurs.
        """
        # NOTE(review): this module imports ProcessingError from
        # airtrain.core.errors, which does not appear in this release's file
        # list; other modules import it from airtrain.core.skills. Verify the
        # module-level import resolves.
        try:
            # None-valued fields are omitted so they are not sent to the API.
            payload = input_data.model_dump(exclude_none=True)

            headers = {
                "content-type": "application/json",
                "Authorization": f"Bearer {self.credentials.api_key.get_secret_value()}",
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.EXA_API_ENDPOINT,
                    headers=headers,
                    json=payload,
                    timeout=self.timeout,
                )

                if response.status_code != 200:
                    error_message = f"Exa API returned status code {response.status_code}: {response.text}"
                    logger.error(error_message)
                    raise ProcessingError(error_message)

                result_data = response.json()

                return ExaSearchOutputSchema(
                    results=result_data.get("results", []),
                    query=input_data.query,
                    autopromptString=result_data.get("autopromptString"),
                    costDollars=result_data.get("costDollars"),
                )

        except ProcessingError:
            # Already logged with a specific message; re-raise unchanged so the
            # generic handler below does not log it again and re-wrap it as
            # an "Unexpected error".
            raise

        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except ValidationError as e:
            # Raised when the API response does not fit ExaSearchOutputSchema.
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
|
@@ -0,0 +1,33 @@
|
|
1
|
+
"""Together AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import TogetherAICredentials
|
4
|
+
from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
|
5
|
+
from .models_config import (
|
6
|
+
get_model_config_with_capabilities,
|
7
|
+
get_max_completion_tokens,
|
8
|
+
supports_tool_use,
|
9
|
+
supports_json_mode,
|
10
|
+
TOGETHER_MODELS_CONFIG,
|
11
|
+
)
|
12
|
+
from .list_models import (
|
13
|
+
TogetherListModelsSkill,
|
14
|
+
TogetherListModelsInput,
|
15
|
+
TogetherListModelsOutput,
|
16
|
+
)
|
17
|
+
from .models import TogetherModel
|
18
|
+
|
19
|
+
# Names exported by ``from airtrain.integrations.together import *``; each one
# is imported above from the submodule named in the group comment.
# NOTE(review): the image, rerank, embedding, audio and vision modules of this
# package are not re-exported here; import them from their submodules.
__all__ = [
    # Credentials (.credentials)
    "TogetherAICredentials",
    # Chat skill and its input/output schemas (.skills)
    "TogetherAIChatSkill",
    "TogetherAIInput",
    "TogetherAIOutput",
    # Model-listing skill and schemas (.list_models) and model record (.models)
    "TogetherListModelsSkill",
    "TogetherListModelsInput",
    "TogetherListModelsOutput",
    "TogetherModel",
    # Model configuration helpers and table (.models_config)
    "get_model_config_with_capabilities",
    "get_max_completion_tokens",
    "supports_tool_use",
    "supports_json_mode",
    "TOGETHER_MODELS_CONFIG",
]
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
class AudioModelConfig(NamedTuple):
    """Static metadata describing one Together AI audio model."""

    # Organization publishing the model; matched case-insensitively by
    # list_audio_models_by_organization().
    organization: str
    # Human-readable name of the model.
    display_name: str
|
7
|
+
|
8
|
+
|
9
|
+
# Registry of Together AI audio models known to this module, keyed by model ID.
# get_audio_model_config() and list_audio_models_by_organization() read it.
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
    "Cartesia/Sonic": AudioModelConfig(
        organization="Cartesia", display_name="Cartesia/Sonic"
    )
}
|
14
|
+
|
15
|
+
|
16
|
+
def get_audio_model_config(model_id: str) -> AudioModelConfig:
    """Return the audio model configuration registered under ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_AUDIO_MODELS``.
    """
    try:
        return TOGETHER_AUDIO_MODELS[model_id]
    except KeyError:
        raise ValueError(
            f"Model {model_id} not found in Together AI audio models"
        ) from None
|
21
|
+
|
22
|
+
|
23
|
+
def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
    """Return the audio models published by ``organization``.

    The organization names are compared after lower-casing both sides, so
    the match is case-insensitive. Insertion order of the registry is kept.
    """
    wanted = organization.lower()
    matches: Dict[str, AudioModelConfig] = {}
    for model_id, config in TOGETHER_AUDIO_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
30
|
+
|
31
|
+
|
32
|
+
def get_default_audio_model() -> str:
    """Return the ID of this module's default audio model."""
    default_model_id = "Cartesia/Sonic"
    return default_model_id
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import together
|
4
|
+
|
5
|
+
|
6
|
+
class TogetherAICredentials(BaseCredentials):
    """Together AI credentials"""

    together_api_key: SecretStr = Field(..., description="Together AI API key")

    _required_credentials = {"together_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate the API key by listing models through the Together API.

        Returns:
            True when the models endpoint can be listed with this key.

        Raises:
            CredentialValidationError: If the client cannot be created or the
                request fails for any reason.
        """
        try:
            # A dedicated async client avoids mutating the module-global
            # ``together.api_key`` and matches the v1 client API used by the
            # rest of this package (``from together import Together``).
            # ``AsyncTogether.models.list()`` is a coroutine, so it can be awaited.
            client = together.AsyncTogether(
                api_key=self.together_api_key.get_secret_value()
            )
            await client.models.list()
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Together AI credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from typing import Dict, NamedTuple
|
2
|
+
|
3
|
+
|
4
|
+
class EmbeddingModelConfig(NamedTuple):
    """Static metadata describing one Together AI embedding model."""

    # Organization publishing the model; matched case-insensitively by
    # list_embedding_models_by_organization().
    organization: str
    # Human-readable name of the model.
    display_name: str
    # Size label as a string, e.g. "80M".
    model_size: str
    # Length of the embedding vectors the model produces.
    embedding_dimension: int
    # Context window size recorded for the model.
    context_window: int
|
10
|
+
|
11
|
+
|
12
|
+
# Registry of Together AI embedding models known to this module, keyed by the
# model ID. get_embedding_model_config() and
# list_embedding_models_by_organization() read from it.
# NOTE(review): sizes, dimensions and context windows are hard-coded here and
# not fetched from the API; re-check them against Together's model list
# when updating.
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
    # Together M2-BERT retrieval models: same size/dimension, differing
    # only in context window.
    "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-2K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=2048,
    ),
    "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-8K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=8192,
    ),
    "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-32K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=32768,
    ),
    # Third-party models, all recorded with a 512 context window.
    "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
        organization="WhereIsAI",
        display_name="UAE-Large-v1",
        model_size="326M",
        embedding_dimension=1024,
        context_window=512,
    ),
    "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
        organization="BAAI",
        display_name="BGE-Large-EN-v1.5",
        model_size="326M",
        embedding_dimension=1024,
        context_window=512,
    ),
    "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
        organization="BAAI",
        display_name="BGE-Base-EN-v1.5",
        model_size="102M",
        embedding_dimension=768,
        context_window=512,
    ),
    "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
        organization="sentence-transformers",
        display_name="Sentence-BERT",
        model_size="110M",
        embedding_dimension=768,
        context_window=512,
    ),
    "bert-base-uncased": EmbeddingModelConfig(
        organization="Hugging Face",
        display_name="BERT",
        model_size="110M",
        embedding_dimension=768,
        context_window=512,
    ),
}
|
70
|
+
|
71
|
+
|
72
|
+
def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
    """Return the embedding model configuration registered under ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_EMBEDDING_MODELS``.
    """
    try:
        return TOGETHER_EMBEDDING_MODELS[model_id]
    except KeyError:
        raise ValueError(
            f"Model {model_id} not found in Together AI embedding models"
        ) from None
|
77
|
+
|
78
|
+
|
79
|
+
def list_embedding_models_by_organization(
    organization: str,
) -> Dict[str, EmbeddingModelConfig]:
    """Return the embedding models published by ``organization``.

    Both organization names are lower-cased before comparing, so the match
    is case-insensitive. Registry insertion order is preserved.
    """
    wanted = organization.lower()
    matches: Dict[str, EmbeddingModelConfig] = {}
    for model_id, config in TOGETHER_EMBEDDING_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
88
|
+
|
89
|
+
|
90
|
+
def get_default_embedding_model() -> str:
    """Return the ID of this module's default embedding model."""
    default_model_id = "togethercomputer/m2-bert-80M-32k-retrieval"
    return default_model_id
|
@@ -0,0 +1,69 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
|
3
|
+
|
4
|
+
class ImageModelConfig(NamedTuple):
    """Static metadata describing one Together AI image model."""

    # Organization publishing the model; matched case-insensitively by
    # list_image_models_by_organization().
    organization: str
    # Human-readable name of the model.
    display_name: str
    # Step count recorded for the model, or None when none is recorded.
    default_steps: Optional[int]
|
8
|
+
|
9
|
+
|
10
|
+
# Registry of Together AI image models known to this module, keyed by model ID.
# get_image_model_config() and list_image_models_by_organization() read it.
# default_steps=None means no step count is recorded for that model here.
# NOTE(review): TogetherAIImageSkill only uses this table to check that the
# model ID exists; it sends the request's own ``steps`` and does not read
# default_steps.
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
    "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [schnell] (free)",
        default_steps=None,
    ),
    "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [schnell] (Turbo)",
        default_steps=4,
    ),
    "black-forest-labs/FLUX.1-dev": ImageModelConfig(
        organization="Black Forest Labs", display_name="Flux.1 Dev", default_steps=28
    ),
    "black-forest-labs/FLUX.1-canny": ImageModelConfig(
        organization="Black Forest Labs", display_name="Flux.1 Canny", default_steps=28
    ),
    "black-forest-labs/FLUX.1-depth": ImageModelConfig(
        organization="Black Forest Labs", display_name="Flux.1 Depth", default_steps=28
    ),
    "black-forest-labs/FLUX.1-redux": ImageModelConfig(
        organization="Black Forest Labs", display_name="Flux.1 Redux", default_steps=28
    ),
    "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux1.1 [pro]",
        default_steps=None,
    ),
    "black-forest-labs/FLUX.1-pro": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [pro]",
        default_steps=None,
    ),
    "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
        organization="Stability AI",
        display_name="Stable Diffusion XL 1.0",
        default_steps=None,
    ),
}
|
49
|
+
|
50
|
+
|
51
|
+
def get_image_model_config(model_id: str) -> ImageModelConfig:
    """Return the image model configuration registered under ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_IMAGE_MODELS``.
    """
    try:
        return TOGETHER_IMAGE_MODELS[model_id]
    except KeyError:
        raise ValueError(
            f"Model {model_id} not found in Together AI image models"
        ) from None
|
56
|
+
|
57
|
+
|
58
|
+
def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
    """Return the image models published by ``organization``.

    Both organization names are lower-cased before comparing, so the match
    is case-insensitive. Registry insertion order is preserved.
    """
    wanted = organization.lower()
    matches: Dict[str, ImageModelConfig] = {}
    for model_id, config in TOGETHER_IMAGE_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
|
65
|
+
|
66
|
+
|
67
|
+
def get_default_image_model() -> str:
    """Return the ID of this module's default image model."""
    default_model_id = "black-forest-labs/FLUX.1-schnell-Free"
    return default_model_id
|
@@ -0,0 +1,143 @@
|
|
1
|
+
from typing import Optional, List
|
2
|
+
from pathlib import Path
|
3
|
+
from pydantic import Field
|
4
|
+
from together import Together
|
5
|
+
import base64
|
6
|
+
import time
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import TogetherAICredentials
|
11
|
+
from .image_models_config import get_image_model_config, get_default_image_model
|
12
|
+
|
13
|
+
|
14
|
+
class TogetherAIImageInput(InputSchema):
    """Request parameters for generating images with Together AI."""

    # Text prompt describing the desired image.
    prompt: str = Field(..., description="Text prompt for image generation")
    # Model ID; defaults to the image_models_config default model.
    model: str = Field(
        get_default_image_model(), description="Together AI image model to use"
    )
    # Inference steps, bounded to 1..50.
    steps: int = Field(10, ge=1, le=50, description="Number of inference steps")
    # Images per request, bounded to 1..4.
    n: int = Field(1, ge=1, le=4, description="Number of images to generate")
    # Output size as "WIDTHxHEIGHT" (not validated in this schema).
    size: str = Field(
        "1024x1024", description="Image size in format WIDTHxHEIGHT"
    )
    negative_prompt: Optional[str] = Field(
        None, description="Things to exclude from the generation"
    )
    seed: Optional[int] = Field(
        None, description="Random seed for reproducibility"
    )
|
32
|
+
|
33
|
+
|
34
|
+
class GeneratedImage(OutputSchema):
    """One image returned by a Together AI image-generation request."""

    # Inline base64 image bytes, when available.
    b64_json: Optional[str] = Field(
        default=None, description="Base64 encoded image data"
    )
    # Location of the generated image (required).
    url: str = Field(..., description="URL of the generated image")
    seed: Optional[int] = Field(default=None, description="Seed used for this image")
    finish_reason: Optional[str] = Field(
        default=None, description="Reason for finishing generation"
    )
|
43
|
+
|
44
|
+
|
45
|
+
class TogetherAIImageOutput(OutputSchema):
    """Result of a Together AI image-generation request."""

    images: List[GeneratedImage] = Field(
        ..., description="List of generated images"
    )
    # Echo of the model ID and prompt used for the request.
    model: str = Field(..., description="Model used for generation")
    prompt: str = Field(..., description="Original prompt used")
    # Duration of the generation call in seconds.
    total_time: float = Field(
        ..., description="Time taken for generation in seconds"
    )
    # Usage data from the API response; empty dict when absent.
    usage: dict = Field(description="Usage statistics", default_factory=dict)
|
53
|
+
|
54
|
+
|
55
|
+
class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
    """Skill for generating images using Together AI"""

    input_schema = TogetherAIImageInput
    output_schema = TogetherAIImageOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Together AI credentials. When omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    def process(self, input_data: TogetherAIImageInput) -> TogetherAIImageOutput:
        """Generate images for ``input_data`` with the Together images API.

        Args:
            input_data: Prompt, model and generation parameters.

        Returns:
            The generated images (URL plus inline base64 data when the API
            supplies it), the model and prompt used, the elapsed time and
            usage data.

        Raises:
            ProcessingError: If the model is not in the local image-model
                config, the API call fails, or the response cannot be
                converted to the output schema.
        """
        try:
            # Fail fast (ValueError, wrapped below) on model IDs that are
            # missing from the local config.
            get_image_model_config(input_data.model)

            # perf_counter is monotonic, so wall-clock adjustments cannot
            # skew the measured duration.
            start_time = time.perf_counter()

            response = self.client.images.generate(
                prompt=input_data.prompt,
                model=input_data.model,
                steps=input_data.steps,
                n=input_data.n,
                size=input_data.size,
                negative_prompt=input_data.negative_prompt,
                seed=input_data.seed,
            )

            total_time = time.perf_counter() - start_time

            generated_images = []
            for img in response.data:
                if not hasattr(img, "url"):
                    raise ProcessingError(
                        f"No URL found in API response. Response structure: {dir(img)}"
                    )

                generated_images.append(
                    GeneratedImage(
                        url=img.url,
                        # Keep inline data when provided so save_images() can
                        # write it without downloading.
                        b64_json=getattr(img, "b64_json", None),
                        seed=getattr(img, "seed", None),
                        finish_reason=getattr(img, "finish_reason", None),
                    )
                )

            return TogetherAIImageOutput(
                images=generated_images,
                model=input_data.model,
                prompt=input_data.prompt,
                total_time=total_time,
                usage=self._usage_to_dict(getattr(response, "usage", None)),
            )

        except Exception as e:
            raise ProcessingError(
                f"Together AI image generation failed: {str(e)}"
            ) from e

    @staticmethod
    def _usage_to_dict(usage) -> dict:
        """Normalize a response ``usage`` value (None, dict, or model) to a dict."""
        if usage is None:
            return {}
        if isinstance(usage, dict):
            return usage
        if hasattr(usage, "model_dump"):
            return usage.model_dump()
        return dict(usage)

    @staticmethod
    def _image_bytes(img: GeneratedImage, timeout: float = 60.0) -> bytes:
        """Return raw image bytes from inline base64 data or, failing that, the URL."""
        import urllib.parse
        import urllib.request

        if img.b64_json:
            return base64.b64decode(img.b64_json)
        if not img.url:
            raise ProcessingError("Generated image has neither base64 data nor a URL")
        # Only fetch over HTTP(S); refuse other schemes such as file://.
        if urllib.parse.urlsplit(img.url).scheme not in ("http", "https"):
            raise ProcessingError(f"Refusing to download image from non-HTTP URL: {img.url}")
        try:
            with urllib.request.urlopen(img.url, timeout=timeout) as resp:
                return resp.read()
        except (OSError, ValueError) as e:
            raise ProcessingError(
                f"Failed to download image from {img.url}: {str(e)}"
            ) from e

    def save_images(
        self, output: TogetherAIImageOutput, output_dir: Path
    ) -> List[Path]:
        """
        Save generated images to disk as ``image_{i}.png``.

        Inline base64 data is used when present; otherwise each image is
        downloaded from its URL (``process`` always records a URL).

        Args:
            output (TogetherAIImageOutput): Generation output containing images
            output_dir (Path): Directory to save images; created if missing

        Returns:
            List[Path]: List of paths to saved images

        Raises:
            ProcessingError: If an image has no data source or cannot be
                downloaded.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        saved_paths = []
        for i, img in enumerate(output.images):
            output_path = output_dir / f"image_{i}.png"
            output_path.write_bytes(self._image_bytes(img))
            saved_paths.append(output_path)

        return saved_paths
|
@@ -0,0 +1,76 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
import requests
|
3
|
+
from pydantic import Field
|
4
|
+
|
5
|
+
from airtrain.core.skills import Skill, ProcessingError
|
6
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
7
|
+
from .credentials import TogetherAICredentials
|
8
|
+
from .models import TogetherModel
|
9
|
+
|
10
|
+
|
11
|
+
class TogetherListModelsInput(InputSchema):
    """Schema for Together AI list models input.

    Listing models takes no parameters, so the schema defines no fields.
    """
|
14
|
+
|
15
|
+
|
16
|
+
class TogetherListModelsOutput(OutputSchema):
    """Models returned by the Together AI list-models endpoint."""

    # Parsed model records; empty list by default.
    data: list[TogetherModel] = Field(
        description="List of Together AI models", default_factory=list
    )
    # Optional object-type tag; None by default.
    object: Optional[str] = Field(None, description="Object type")
|
27
|
+
|
28
|
+
|
29
|
+
class TogetherListModelsSkill(Skill[TogetherListModelsInput, TogetherListModelsOutput]):
    """Skill for listing Together AI models"""

    input_schema = TogetherListModelsInput
    output_schema = TogetherListModelsOutput

    def __init__(
        self,
        credentials: Optional[TogetherAICredentials] = None,
        timeout: float = 30.0,
    ):
        """Initialize the skill.

        Args:
            credentials: Together AI credentials. When omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
            timeout: Seconds to wait for the models endpoint before failing.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.base_url = "https://api.together.xyz/v1"
        self.timeout = timeout

    def process(
        self, input_data: TogetherListModelsInput
    ) -> TogetherListModelsOutput:
        """Fetch the model list from ``{base_url}/models``.

        Accepts either a bare JSON array of models or an envelope of the form
        ``{"object": ..., "data": [...]}``.

        Raises:
            ProcessingError: On HTTP/network errors (including timeouts) or
                when the response cannot be parsed into ``TogetherModel`` items.
        """
        try:
            url = f"{self.base_url}/models"
            headers = {
                "Authorization": (
                    f"Bearer {self.credentials.together_api_key.get_secret_value()}"
                ),
                "accept": "application/json"
            }

            # Without a timeout, requests can block forever on a stalled
            # connection.
            response = requests.get(url, headers=headers, timeout=self.timeout)
            response.raise_for_status()

            result = response.json()

            # Iterating a dict envelope would yield its keys, so unwrap it.
            if isinstance(result, dict):
                items = result.get("data") or []
                object_type = result.get("object")
            else:
                items = result
                object_type = None

            models = [TogetherModel(**model_data) for model_data in items]

            return TogetherListModelsOutput(data=models, object=object_type)

        except requests.RequestException as e:
            raise ProcessingError(
                f"Failed to list Together AI models: {str(e)}"
            ) from e
        except Exception as e:
            raise ProcessingError(
                f"Error listing Together AI models: {str(e)}"
            ) from e
|
@@ -0,0 +1,95 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import BaseModel, Field, validator
|
3
|
+
|
4
|
+
|
5
|
+
class TogetherAIImageInput(BaseModel):
    """Request parameters for Together AI image generation, with size validation.

    NOTE(review): image_skill.py defines its own TogetherAIImageInput (on
    InputSchema, without size validation); consider consolidating.
    """

    prompt: str = Field(..., description="Text prompt for image generation")
    model: str = Field(
        "black-forest-labs/FLUX.1-schnell-Free",
        description="Together AI image model to use",
    )
    # Bounded to 1..50 inference steps and 1..4 images.
    steps: int = Field(10, ge=1, le=50, description="Number of inference steps")
    n: int = Field(1, ge=1, le=4, description="Number of images to generate")
    size: str = Field("1024x1024", description="Image size in format WIDTHxHEIGHT")
    negative_prompt: Optional[str] = Field(
        None, description="Things to exclude from the generation"
    )
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")

    @validator("size")
    def validate_size(cls, v):
        """Accept ``size`` only as two positive integers joined by ``x``."""
        message = "Size must be in format WIDTHxHEIGHT (e.g., 1024x1024)"
        try:
            dims = [int(part) for part in v.split("x")]
        except ValueError:
            raise ValueError(message)
        if len(dims) != 2 or dims[0] <= 0 or dims[1] <= 0:
            raise ValueError(message)
        return v
|
34
|
+
|
35
|
+
|
36
|
+
class GeneratedImage(BaseModel):
    """One generated image, carried as required base64 data.

    NOTE(review): image_skill.py defines a different GeneratedImage (required
    ``url``, optional ``b64_json``); the two are not interchangeable.
    """

    b64_json: str = Field(..., description="Base64 encoded image data")
    seed: Optional[int] = Field(default=None, description="Seed used for this image")
    finish_reason: Optional[str] = Field(
        default=None, description="Reason for finishing generation"
    )
|
44
|
+
|
45
|
+
|
46
|
+
class TogetherAIImageOutput(BaseModel):
    """Result of a Together AI image-generation request."""

    images: List[GeneratedImage] = Field(..., description="List of generated images")
    model: str = Field(..., description="Model used for generation")
    prompt: str = Field(..., description="Original prompt used")
    # Duration of the generation call in seconds.
    total_time: float = Field(..., description="Time taken for generation in seconds")
    # Free-form usage dictionary; empty by default.
    usage: dict = Field(
        description="Usage statistics and billing information", default_factory=dict
    )
|
56
|
+
|
57
|
+
|
58
|
+
class TogetherModel(BaseModel):
    """Schema for one model record from the Together AI models endpoint.

    Only ``id`` is required; every other field is optional and defaults to None.
    """

    id: str = Field(..., description="Model ID")
    name: Optional[str] = Field(default=None, description="Model name")
    object: Optional[str] = Field(default=None, description="Object type")
    created: Optional[int] = Field(default=None, description="Creation timestamp")
    owned_by: Optional[str] = Field(default=None, description="Model owner")
    root: Optional[str] = Field(default=None, description="Root model identifier")
    parent: Optional[str] = Field(default=None, description="Parent model identifier")
    permission: Optional[List[Dict[str, Any]]] = Field(
        default=None, description="Permission details"
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional metadata for the model"
    )
    description: Optional[str] = Field(default=None, description="Model description")
    pricing: Optional[Dict[str, Any]] = Field(
        default=None, description="Pricing information"
    )
    context_length: Optional[int] = Field(
        default=None, description="Maximum context length supported by the model"
    )
    capabilities: Optional[List[str]] = Field(
        default=None, description="Model capabilities"
    )
|
82
|
+
|
83
|
+
|
84
|
+
class TogetherListModelsInput(BaseModel):
    """Schema for listing Together AI models input (no parameters).

    NOTE(review): list_models.py defines and uses its own
    TogetherListModelsInput (on InputSchema); this copy may be unused.
    """
|
87
|
+
|
88
|
+
|
89
|
+
class TogetherListModelsOutput(BaseModel):
    """Schema for listing Together AI models output; ``data`` is required.

    NOTE(review): list_models.py defines and uses its own
    TogetherListModelsOutput (on OutputSchema, ``data`` defaulting to an
    empty list); this copy may be unused.
    """

    data: List[TogetherModel] = Field(..., description="List of Together AI models")
    object: Optional[str] = Field(default=None, description="Object type")
|