airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110):
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
1
+ """
2
+ Skills for Exa Search API.
3
+
4
+ This module provides skills for using the Exa search API.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import httpx
10
+ from typing import Optional, Dict, Any, List, cast
11
+
12
+ from pydantic import ValidationError
13
+
14
+ from airtrain.core.skills import Skill
15
+ from airtrain.core.errors import ProcessingError
16
+ from .credentials import ExaCredentials
17
+ from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
18
+
19
+
logger = logging.getLogger(__name__)


class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API."""

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    # Status codes that signal a transient condition (rate limiting or a
    # server-side failure); other non-200 codes fail immediately.
    RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    # Upper bound, in seconds, for the exponential back-off between retries.
    MAX_BACKOFF_SECONDS = 8.0

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for API requests in seconds
            max_retries: Maximum number of retries for failed requests
                (timeouts, transport errors and retryable HTTP statuses).
                The request is attempted at most ``max_retries + 1`` times;
                negative values are treated as 0.
            **kwargs: Forwarded unchanged to the base ``Skill`` initializer.
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Transient failures (timeouts, connection errors and HTTP 429, 500,
        502, 503, 504) are retried up to ``max_retries`` times with
        exponential back-off.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: If the request ultimately fails, the API returns
                a non-200 status, or the response does not match the output
                schema.
        """
        import asyncio

        try:
            # Prepare request payload (unset optional fields are omitted)
            payload = input_data.model_dump(exclude_none=True)

            # Build request headers
            headers = {
                "content-type": "application/json",
                "Authorization": f"Bearer {self.credentials.api_key.get_secret_value()}",
            }

            attempts = max(0, self.max_retries) + 1

            async with httpx.AsyncClient() as client:
                for attempt in range(attempts):
                    is_last_attempt = attempt == attempts - 1
                    try:
                        response = await client.post(
                            self.EXA_API_ENDPOINT,
                            headers=headers,
                            json=payload,
                            timeout=self.timeout,
                        )
                    except httpx.TransportError as e:
                        # Timeouts and connection failures are transient; only
                        # give up once the retry budget is spent. The outer
                        # handlers turn the final failure into ProcessingError.
                        if is_last_attempt:
                            raise
                        logger.warning(
                            "Exa API request failed on attempt %d/%d (%s); retrying",
                            attempt + 1,
                            attempts,
                            e,
                        )
                    else:
                        if response.status_code == 200:
                            result_data = response.json()

                            # Construct the output schema
                            return ExaSearchOutputSchema(
                                results=result_data.get("results", []),
                                query=input_data.query,
                                autopromptString=result_data.get("autopromptString"),
                                costDollars=result_data.get("costDollars"),
                            )

                        if (
                            is_last_attempt
                            or response.status_code not in self.RETRYABLE_STATUS_CODES
                        ):
                            error_message = f"Exa API returned status code {response.status_code}: {response.text}"
                            logger.error(error_message)
                            raise ProcessingError(error_message)

                        logger.warning(
                            "Exa API returned status code %d on attempt %d/%d; retrying",
                            response.status_code,
                            attempt + 1,
                            attempts,
                        )

                    # Exponential back-off; the exponent is clamped so a huge
                    # max_retries cannot overflow the float conversion.
                    delay = min(self.MAX_BACKOFF_SECONDS, 0.5 * (2 ** min(attempt, 6)))
                    await asyncio.sleep(delay)

            # Defensive: the final attempt always returns or raises above.
            raise ProcessingError("Exa API request failed after all retries")

        except ProcessingError:
            # Already descriptive (and logged where relevant); do not re-wrap
            # it as an "Unexpected error".
            raise

        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except ValidationError as e:
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
@@ -0,0 +1,33 @@
1
+ """Together AI integration module"""
2
+
3
+ from .credentials import TogetherAICredentials
4
+ from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
5
+ from .models_config import (
6
+ get_model_config_with_capabilities,
7
+ get_max_completion_tokens,
8
+ supports_tool_use,
9
+ supports_json_mode,
10
+ TOGETHER_MODELS_CONFIG,
11
+ )
12
+ from .list_models import (
13
+ TogetherListModelsSkill,
14
+ TogetherListModelsInput,
15
+ TogetherListModelsOutput,
16
+ )
17
+ from .models import TogetherModel
18
+
# Names re-exported as the public API of the Together AI integration package.
__all__ = [
    # Credentials and the chat skill with its schemas.
    "TogetherAICredentials",
    "TogetherAIChatSkill",
    "TogetherAIInput",
    "TogetherAIOutput",
    # Model-listing skill, its schemas and the model record type.
    "TogetherListModelsSkill",
    "TogetherListModelsInput",
    "TogetherListModelsOutput",
    "TogetherModel",
    # Model-configuration helpers and the raw configuration table.
    "get_model_config_with_capabilities",
    "get_max_completion_tokens",
    "supports_tool_use",
    "supports_json_mode",
    "TOGETHER_MODELS_CONFIG",
]
@@ -0,0 +1,34 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
class AudioModelConfig(NamedTuple):
    """Static metadata describing one Together AI audio model."""

    # Organization that publishes the model.
    organization: str
    # Human-readable model name.
    display_name: str


# Registry of known audio models, keyed by Together AI model ID.
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
    "Cartesia/Sonic": AudioModelConfig(
        organization="Cartesia",
        display_name="Cartesia/Sonic",
    )
}


def get_audio_model_config(model_id: str) -> AudioModelConfig:
    """Return the registered configuration for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_AUDIO_MODELS``.
    """
    config = TOGETHER_AUDIO_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI audio models")
    return config


def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
    """Return every audio model whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, AudioModelConfig] = {}
    for model_id, config in TOGETHER_AUDIO_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_audio_model() -> str:
    """Return the ID of the audio model used when none is specified."""
    default_model_id = "Cartesia/Sonic"
    return default_model_id
@@ -0,0 +1,22 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import together
4
+
5
+
class TogetherAICredentials(BaseCredentials):
    """Together AI credentials"""

    together_api_key: SecretStr = Field(..., description="Together AI API key")

    _required_credentials = {"together_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate Together AI credentials by listing models with the key.

        Returns:
            True when the models endpoint accepts the key.

        Raises:
            CredentialValidationError: If the check fails for any reason
                (rejected key, network error, SDK error).
        """
        try:
            # Use a per-call async client instead of the legacy module-level
            # ``together.api_key`` / ``together.Models.list()`` API. That call
            # is not awaitable, so ``await`` on it always raised and every key
            # was reported invalid. It also mutated global SDK state.
            from together import AsyncTogether

            client = AsyncTogether(api_key=self.together_api_key.get_secret_value())
            await client.models.list()
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Together AI credentials: {str(e)}"
            ) from e
@@ -0,0 +1,92 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
class EmbeddingModelConfig(NamedTuple):
    """Static metadata describing one Together AI embedding model."""

    # Organization that publishes the model.
    organization: str
    # Human-readable model name.
    display_name: str
    # Parameter count as a display string (e.g. "80M").
    model_size: str
    # Length of the produced embedding vectors.
    embedding_dimension: int
    # Maximum input length accepted by the model.
    context_window: int


# Registry of known embedding models, keyed by Together AI model ID.
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
    "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-2K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=2048,
    ),
    "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-8K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=8192,
    ),
    "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
        organization="Together",
        display_name="M2-BERT-80M-32K-Retrieval",
        model_size="80M",
        embedding_dimension=768,
        context_window=32768,
    ),
    "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
        organization="WhereIsAI",
        display_name="UAE-Large-v1",
        model_size="326M",
        embedding_dimension=1024,
        context_window=512,
    ),
    "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
        organization="BAAI",
        display_name="BGE-Large-EN-v1.5",
        model_size="326M",
        embedding_dimension=1024,
        context_window=512,
    ),
    "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
        organization="BAAI",
        display_name="BGE-Base-EN-v1.5",
        model_size="102M",
        embedding_dimension=768,
        context_window=512,
    ),
    "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
        organization="sentence-transformers",
        display_name="Sentence-BERT",
        model_size="110M",
        embedding_dimension=768,
        context_window=512,
    ),
    "bert-base-uncased": EmbeddingModelConfig(
        organization="Hugging Face",
        display_name="BERT",
        model_size="110M",
        embedding_dimension=768,
        context_window=512,
    ),
}


def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
    """Return the registered configuration for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_EMBEDDING_MODELS``.
    """
    config = TOGETHER_EMBEDDING_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI embedding models")
    return config


def list_embedding_models_by_organization(
    organization: str,
) -> Dict[str, EmbeddingModelConfig]:
    """Return every embedding model whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, EmbeddingModelConfig] = {}
    for model_id, config in TOGETHER_EMBEDDING_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_embedding_model() -> str:
    """Return the ID of the embedding model used when none is specified."""
    default_model_id = "togethercomputer/m2-bert-80M-32k-retrieval"
    return default_model_id
@@ -0,0 +1,69 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+
3
+
class ImageModelConfig(NamedTuple):
    """Static metadata describing one Together AI image model."""

    # Organization that publishes the model.
    organization: str
    # Human-readable model name.
    display_name: str
    # Configured step count for the model; None where no value is configured.
    default_steps: Optional[int]


# Registry of known image models, keyed by Together AI model ID.
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
    "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [schnell] (free)",
        default_steps=None,
    ),
    "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [schnell] (Turbo)",
        default_steps=4,
    ),
    "black-forest-labs/FLUX.1-dev": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 Dev",
        default_steps=28,
    ),
    "black-forest-labs/FLUX.1-canny": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 Canny",
        default_steps=28,
    ),
    "black-forest-labs/FLUX.1-depth": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 Depth",
        default_steps=28,
    ),
    "black-forest-labs/FLUX.1-redux": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 Redux",
        default_steps=28,
    ),
    "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux1.1 [pro]",
        default_steps=None,
    ),
    "black-forest-labs/FLUX.1-pro": ImageModelConfig(
        organization="Black Forest Labs",
        display_name="Flux.1 [pro]",
        default_steps=None,
    ),
    "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
        organization="Stability AI",
        display_name="Stable Diffusion XL 1.0",
        default_steps=None,
    ),
}


def get_image_model_config(model_id: str) -> ImageModelConfig:
    """Return the registered configuration for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_IMAGE_MODELS``.
    """
    config = TOGETHER_IMAGE_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI image models")
    return config


def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
    """Return every image model whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, ImageModelConfig] = {}
    for model_id, config in TOGETHER_IMAGE_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_image_model() -> str:
    """Return the ID of the image model used when none is specified."""
    default_model_id = "black-forest-labs/FLUX.1-schnell-Free"
    return default_model_id
@@ -0,0 +1,143 @@
1
+ from typing import Optional, List
2
+ from pathlib import Path
3
+ from pydantic import Field
4
+ from together import Together
5
+ import base64
6
+ import time
7
+
8
+ from airtrain.core.skills import Skill, ProcessingError
9
+ from airtrain.core.schemas import InputSchema, OutputSchema
10
+ from .credentials import TogetherAICredentials
11
+ from .image_models_config import get_image_model_config, get_default_image_model
12
+
13
+
class TogetherAIImageInput(InputSchema):
    """Request parameters for a Together AI image-generation call."""

    prompt: str = Field(
        default=...,
        description="Text prompt for image generation",
    )
    # The default model ID is resolved once, when the class is created.
    model: str = Field(
        description="Together AI image model to use",
        default=get_default_image_model(),
    )
    steps: int = Field(
        10,
        ge=1,
        le=50,
        description="Number of inference steps",
    )
    n: int = Field(
        1,
        ge=1,
        le=4,
        description="Number of images to generate",
    )
    size: str = Field(
        description="Image size in format WIDTHxHEIGHT",
        default="1024x1024",
    )
    negative_prompt: Optional[str] = Field(
        description="Things to exclude from the generation",
        default=None,
    )
    seed: Optional[int] = Field(
        description="Random seed for reproducibility",
        default=None,
    )
32
+
33
+
class GeneratedImage(OutputSchema):
    """One image returned by the Together AI image endpoint."""

    b64_json: Optional[str] = Field(
        default=None,
        description="Base64 encoded image data",
    )
    url: str = Field(
        default=...,
        description="URL of the generated image",
    )
    seed: Optional[int] = Field(
        default=None,
        description="Seed used for this image",
    )
    finish_reason: Optional[str] = Field(
        default=None,
        description="Reason for finishing generation",
    )
43
+
44
+
class TogetherAIImageOutput(OutputSchema):
    """Result of a Together AI image-generation call."""

    images: List[GeneratedImage] = Field(
        default=...,
        description="List of generated images",
    )
    model: str = Field(
        default=...,
        description="Model used for generation",
    )
    prompt: str = Field(
        default=...,
        description="Original prompt used",
    )
    total_time: float = Field(
        default=...,
        description="Time taken for generation in seconds",
    )
    # A fresh dict per instance, so instances never share usage data.
    usage: dict = Field(
        description="Usage statistics",
        default_factory=dict,
    )
53
+
54
+
class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
    """Skill for generating images using Together AI"""

    input_schema = TogetherAIImageInput
    output_schema = TogetherAIImageOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Together AI credentials; when omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    @staticmethod
    def _parse_size(size: str) -> tuple[int, int]:
        """Split a ``WIDTHxHEIGHT`` string into positive integer dimensions.

        Raises:
            ProcessingError: If ``size`` is not two positive integers joined by
                ``x`` (case-insensitive).
        """
        parts = size.lower().split("x") if isinstance(size, str) else []
        if len(parts) == 2:
            try:
                width, height = int(parts[0]), int(parts[1])
            except ValueError:
                pass
            else:
                if width > 0 and height > 0:
                    return width, height
        raise ProcessingError(
            f"Invalid image size {size!r}; expected WIDTHxHEIGHT (e.g. 1024x1024)"
        )

    def process(self, input_data: TogetherAIImageInput) -> TogetherAIImageOutput:
        """Generate images for ``input_data.prompt`` with the requested model.

        Returns:
            The generated images (URL plus any base64 payload the API returned),
            together with the model, prompt, wall-clock duration and usage data.

        Raises:
            ProcessingError: If the model is unknown, ``size`` is malformed, the
                API call fails, or an image in the response has no URL.
        """
        try:
            # Validate the model exists in our config
            get_image_model_config(input_data.model)

            # The SDK's images.generate takes separate width/height arguments;
            # passing ``size=`` is not a recognised parameter, so the requested
            # size was silently ignored.
            width, height = self._parse_size(input_data.size)

            start_time = time.perf_counter()

            # Generate images
            response = self.client.images.generate(
                prompt=input_data.prompt,
                model=input_data.model,
                steps=input_data.steps,
                n=input_data.n,
                width=width,
                height=height,
                negative_prompt=input_data.negative_prompt,
                seed=input_data.seed,
            )

            total_time = time.perf_counter() - start_time

            # Convert response to our output format
            generated_images = []
            for img in response.data:
                url = getattr(img, "url", None)
                if not url:
                    raise ProcessingError(
                        f"No URL found in API response. Response structure: {dir(img)}"
                    )

                generated_images.append(
                    GeneratedImage(
                        url=url,
                        # Keep any inline payload so save_images can use it
                        # without a second download.
                        b64_json=getattr(img, "b64_json", None),
                        seed=getattr(img, "seed", None),
                        finish_reason=getattr(img, "finish_reason", None),
                    )
                )

            return TogetherAIImageOutput(
                images=generated_images,
                model=input_data.model,
                prompt=input_data.prompt,
                total_time=total_time,
                # The attribute may exist but be None, which the dict-typed
                # field would reject.
                usage=getattr(response, "usage", None) or {},
            )

        except ProcessingError:
            raise
        except Exception as e:
            raise ProcessingError(
                f"Together AI image generation failed: {str(e)}"
            ) from e

    def save_images(
        self, output: TogetherAIImageOutput, output_dir: Path
    ) -> List[Path]:
        """
        Save generated images to disk as ``image_<index>.png``.

        Uses each image's base64 payload when present; otherwise downloads the
        image from its ``http(s)`` URL.

        Args:
            output (TogetherAIImageOutput): Generation output containing images
            output_dir (Path): Directory to save images (created if missing)

        Returns:
            List[Path]: List of paths to saved images

        Raises:
            ProcessingError: If an image has neither base64 data nor an
                ``http(s)`` URL.
        """
        import urllib.request

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        saved_paths = []
        for i, img in enumerate(output.images):
            output_path = output_dir / f"image_{i}.png"

            if img.b64_json:
                image_data = base64.b64decode(img.b64_json)
            elif img.url and img.url.lower().startswith(("http://", "https://")):
                # Restrict to http(s) so a crafted URL cannot read local files.
                with urllib.request.urlopen(img.url, timeout=60) as resp:
                    image_data = resp.read()
            else:
                raise ProcessingError(
                    f"Image {i} has neither base64 data nor an http(s) URL"
                )

            output_path.write_bytes(image_data)
            saved_paths.append(output_path)

        return saved_paths
@@ -0,0 +1,76 @@
1
+ from typing import Optional
2
+ import requests
3
+ from pydantic import Field
4
+
5
+ from airtrain.core.skills import Skill, ProcessingError
6
+ from airtrain.core.schemas import InputSchema, OutputSchema
7
+ from .credentials import TogetherAICredentials
8
+ from .models import TogetherModel
9
+
10
+
class TogetherListModelsInput(InputSchema):
    """Schema for Together AI list models input (the request takes no parameters)."""
14
+
15
+
class TogetherListModelsOutput(OutputSchema):
    """Models returned by the Together AI model-listing endpoint."""

    # Defaults to a fresh empty list per instance.
    data: list[TogetherModel] = Field(
        description="List of Together AI models",
        default_factory=list,
    )
    object: Optional[str] = Field(
        description="Object type",
        default=None,
    )
27
+
28
+
class TogetherListModelsSkill(Skill[TogetherListModelsInput, TogetherListModelsOutput]):
    """Skill for listing Together AI models"""

    input_schema = TogetherListModelsInput
    output_schema = TogetherListModelsOutput

    def __init__(
        self,
        credentials: Optional[TogetherAICredentials] = None,
        timeout: float = 30.0,
    ):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Together AI credentials; when omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
            timeout: Seconds to wait for the models endpoint before failing.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.base_url = "https://api.together.xyz/v1"
        self.timeout = timeout

    def process(
        self, input_data: TogetherListModelsInput
    ) -> TogetherListModelsOutput:
        """Fetch ``{base_url}/models`` and return the models it lists.

        Raises:
            ProcessingError: If the HTTP request fails, returns an error
                status, or the body cannot be turned into ``TogetherModel``s.
        """
        try:
            url = f"{self.base_url}/models"
            headers = {
                "Authorization": (
                    f"Bearer {self.credentials.together_api_key.get_secret_value()}"
                ),
                "accept": "application/json",
            }

            # Without a timeout, requests can block forever on a stalled
            # connection.
            response = requests.get(url, headers=headers, timeout=self.timeout)
            response.raise_for_status()

            result = response.json()

            # Accept both a bare JSON array of models and an OpenAI-style
            # {"object": ..., "data": [...]} envelope. Iterating a dict would
            # otherwise yield its keys and fail on TogetherModel(**key).
            object_type = None
            if isinstance(result, dict):
                object_type = result.get("object")
                result = result.get("data") or []

            models = [TogetherModel(**model_data) for model_data in result]

            return TogetherListModelsOutput(data=models, object=object_type)

        except requests.RequestException as e:
            raise ProcessingError(f"Failed to list Together AI models: {str(e)}") from e
        except Exception as e:
            raise ProcessingError(f"Error listing Together AI models: {str(e)}") from e
@@ -0,0 +1,95 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import BaseModel, Field, validator
3
+
4
+
class TogetherAIImageInput(BaseModel):
    """Request parameters for a Together AI image-generation call."""

    prompt: str = Field(
        default=...,
        description="Text prompt for image generation",
    )
    model: str = Field(
        description="Together AI image model to use",
        default="black-forest-labs/FLUX.1-schnell-Free",
    )
    steps: int = Field(
        10,
        ge=1,
        le=50,
        description="Number of inference steps",
    )
    n: int = Field(
        1,
        ge=1,
        le=4,
        description="Number of images to generate",
    )
    size: str = Field(
        description="Image size in format WIDTHxHEIGHT",
        default="1024x1024",
    )
    negative_prompt: Optional[str] = Field(
        description="Things to exclude from the generation",
        default=None,
    )
    seed: Optional[int] = Field(
        description="Random seed for reproducibility",
        default=None,
    )

    @validator("size")
    def validate_size(cls, v):
        """Accept only two positive integers joined by a lowercase ``x``."""
        try:
            dims = [int(piece) for piece in v.split("x")]
        except ValueError:
            dims = []
        if len(dims) != 2 or min(dims) <= 0:
            raise ValueError("Size must be in format WIDTHxHEIGHT (e.g., 1024x1024)")
        return v
34
+
35
+
class GeneratedImage(BaseModel):
    """One base64-encoded image returned by the generation endpoint."""

    b64_json: str = Field(
        default=...,
        description="Base64 encoded image data",
    )
    seed: Optional[int] = Field(
        default=None,
        description="Seed used for this image",
    )
    finish_reason: Optional[str] = Field(
        default=None,
        description="Reason for finishing generation",
    )
44
+
45
+
class TogetherAIImageOutput(BaseModel):
    """Result of a Together AI image-generation call."""

    images: List[GeneratedImage] = Field(
        default=...,
        description="List of generated images",
    )
    model: str = Field(
        default=...,
        description="Model used for generation",
    )
    prompt: str = Field(
        default=...,
        description="Original prompt used",
    )
    total_time: float = Field(
        default=...,
        description="Time taken for generation in seconds",
    )
    # A fresh dict per instance, so instances never share usage data.
    usage: dict = Field(
        description="Usage statistics and billing information",
        default_factory=dict,
    )
56
+
57
+
class TogetherModel(BaseModel):
    """One model entry; only ``id`` is required, every other field is optional."""

    id: str = Field(default=..., description="Model ID")
    name: Optional[str] = Field(default=None, description="Model name")
    object: Optional[str] = Field(default=None, description="Object type")
    created: Optional[int] = Field(default=None, description="Creation timestamp")
    owned_by: Optional[str] = Field(default=None, description="Model owner")
    root: Optional[str] = Field(default=None, description="Root model identifier")
    parent: Optional[str] = Field(default=None, description="Parent model identifier")
    permission: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Permission details",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metadata for the model",
    )
    description: Optional[str] = Field(default=None, description="Model description")
    pricing: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Pricing information",
    )
    context_length: Optional[int] = Field(
        default=None,
        description="Maximum context length supported by the model",
    )
    capabilities: Optional[List[str]] = Field(
        default=None,
        description="Model capabilities",
    )
82
+
83
+
class TogetherListModelsInput(BaseModel):
    """Schema for listing Together AI models input (the request takes no parameters)."""
87
+
88
+
class TogetherListModelsOutput(BaseModel):
    """Models returned by the Together AI model-listing endpoint."""

    # Required: callers must always supply the list of models.
    data: List[TogetherModel] = Field(
        default=...,
        description="List of Together AI models",
    )
    object: Optional[str] = Field(
        default=None,
        description="Object type",
    )
+ object: Optional[str] = Field(None, description="Object type")