airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ """
2
+ Search integrations for AirTrain.
3
+
4
+ This package provides integrations with various search providers.
5
+ """
6
+
7
+ # Import specific search integrations as needed
8
+ from .exa import (
9
+ ExaCredentials,
10
+ ExaSearchInputSchema,
11
+ ExaSearchOutputSchema,
12
+ ExaSearchSkill,
13
+ )
14
+
15
# Public API of the search integrations package.
__all__ = [
    # Exa Search: credentials, request/response schemas and the search skill.
    "ExaCredentials",
    "ExaSearchInputSchema",
    "ExaSearchOutputSchema",
    "ExaSearchSkill",
]
@@ -0,0 +1,23 @@
1
+ """
2
+ Exa Search API integration.
3
+
4
+ This module provides integration with the Exa search API for web searching capabilities.
5
+ """
6
+
7
+ from .credentials import ExaCredentials
8
+ from .schemas import (
9
+ ExaSearchInputSchema,
10
+ ExaSearchOutputSchema,
11
+ ExaContentConfig,
12
+ ExaSearchResult,
13
+ )
14
+ from .skills import ExaSearchSkill
15
+
16
# Public API of the Exa integration.
__all__ = [
    # Credentials
    "ExaCredentials",
    # Request / response schemas
    "ExaSearchInputSchema",
    "ExaSearchOutputSchema",
    "ExaContentConfig",
    "ExaSearchResult",
    # Skill
    "ExaSearchSkill",
]
@@ -0,0 +1,30 @@
1
+ """
2
+ Credentials for Exa Search API.
3
+
4
+ This module provides credential management for the Exa search API.
5
+ """
6
+
7
+ from typing import Optional
8
+ from pydantic import Field, SecretStr
9
+
10
+ from airtrain.core.credentials import BaseCredentials
11
+
12
+
13
class ExaCredentials(BaseCredentials):
    """Credential container for the Exa search API.

    The only credential is the API key, stored as a ``SecretStr`` so its value
    is masked when the model is printed.
    """

    exa_api_key: SecretStr = Field(description="Exa search API key")

    _required_credentials = {"exa_api_key"}

    async def validate_credentials(self) -> bool:
        """Check that the required credentials are present.

        The presence check is delegated to the base class; no request is sent
        to Exa, so an unusable key is not detected here.

        Returns:
            True once the base-class check has completed.
        """
        await super().validate_credentials()
        # TODO: optionally issue a cheap Exa request to confirm the key works.
        return True
@@ -0,0 +1,114 @@
1
+ """
2
+ Schemas for Exa Search API.
3
+
4
+ This module defines the input and output schemas for the Exa search API.
5
+ """
6
+
7
# ``bool`` is a builtin, not a ``typing`` export; importing it from ``typing``
# raised ImportError and made the whole schemas module unimportable.
from typing import Dict, List, Optional, Any, Union
8
+ from pydantic import BaseModel, Field, HttpUrl
9
+
10
+ from airtrain.core.schemas import InputSchema, OutputSchema
11
+
12
+
13
class ExaContentConfig(BaseModel):
    """Selects which kinds of content Exa should return for each result.

    ``text`` is on by default. Every other flag defaults to ``None``, so unset
    flags are omitted when the model is dumped with ``exclude_none=True``.
    """

    text: bool = Field(True, description="Whether to return text content.")
    extractedText: Optional[bool] = Field(
        None, description="Whether to return extracted text content."
    )
    embedded: Optional[bool] = Field(
        None, description="Whether to return embedded content."
    )
    links: Optional[bool] = Field(
        None, description="Whether to return links from the content."
    )
    screenshot: Optional[bool] = Field(
        None, description="Whether to return screenshots of the content."
    )
    highlighted: Optional[bool] = Field(
        None, description="Whether to return highlighted text."
    )
32
+
33
+
34
class ExaSearchInputSchema(InputSchema):
    """Request parameters for an Exa search call.

    Field names are camelCase so the dumped model can be used directly as the
    JSON request body. Optional fields default to ``None`` and are omitted
    when dumped with ``exclude_none=True``.
    """

    query: str = Field(..., description="The search query to execute.")
    numResults: Optional[int] = Field(
        None, description="Number of results to return."
    )
    # A fresh ExaContentConfig per instance, so text content is requested by default.
    contents: Optional[ExaContentConfig] = Field(
        default_factory=ExaContentConfig,
        description="Configuration for the content to be returned.",
    )
    highlights: Optional[dict] = Field(
        None, description="Highlighting configuration for search results."
    )
    useAutoprompt: Optional[bool] = Field(
        None, description="Whether to use autoprompt for the search."
    )
    # Named ``type`` to match the request key, even though it shadows the builtin name.
    type: Optional[str] = Field(None, description="Type of search to perform.")
    includeDomains: Optional[List[str]] = Field(
        None, description="List of domains to include in the search."
    )
    excludeDomains: Optional[List[str]] = Field(
        None, description="List of domains to exclude from the search."
    )
58
+
59
+
60
class ExaModerationConfig(BaseModel):
    """Moderation flags attached to an Exa search result.

    Every flag is optional and stays ``None`` when absent from the response.
    """

    # NOTE(review): the llamaguard* names look like Llama Guard category flags
    # reported by Exa; confirm their meanings against Exa's docs before use.
    llamaguardS1: Optional[bool] = Field(default=None)
    llamaguardS3: Optional[bool] = Field(default=None)
    llamaguardS4: Optional[bool] = Field(default=None)
    llamaguardS12: Optional[bool] = Field(default=None)
    # Domain-level blacklist indicators.
    domainBlacklisted: Optional[bool] = Field(default=None)
    domainBlacklistedMedia: Optional[bool] = Field(default=None)
69
+
70
+
71
class ExaHighlight(BaseModel):
    """A highlighted snippet of a search result together with its score."""

    text: str = Field(...)  # the highlighted snippet (required)
    score: float = Field(...)  # score attached to the snippet (required)
76
+
77
+
78
class ExaSearchResult(BaseModel):
    """Individual search result from Exa.

    Only ``id`` and ``url`` are required. Every other field is optional,
    because (per the Exa API reference) the fields present depend on the
    search type and the contents requested.
    """

    id: str
    url: str
    title: Optional[str] = None
    text: Optional[str] = None
    extractedText: Optional[str] = None
    embedded: Optional[Dict[str, Any]] = None
    # Optional: Exa does not return a relevance score for every result. When
    # this was a required float, such responses failed validation entirely.
    score: Optional[float] = None
    published: Optional[str] = None
    author: Optional[str] = None
    # Exa returns highlights as plain strings, with their scores in
    # ``highlightScores``. The structured ExaHighlight form is still accepted
    # so existing callers keep working.
    highlights: Optional[List[Union[ExaHighlight, str]]] = None
    robotsAllowed: Optional[bool] = None
    moderationConfig: Optional[ExaModerationConfig] = None
    urls: Optional[List[str]] = None
    # Additional fields present in Exa search responses; previously dropped on parse.
    publishedDate: Optional[str] = None
    highlightScores: Optional[List[float]] = None
94
+
95
+
96
class ExaCostDetails(BaseModel):
    """Cost details for an Exa search request.

    ``search`` and ``contents`` default to empty mappings. A cost object that
    reports only ``total`` (or uses a different breakdown layout) therefore
    still parses, instead of failing the whole search with a validation error.
    """

    # Total cost of the request in US dollars.
    total: float
    # Per-category cost breakdowns; empty when the API does not provide them.
    search: Dict[str, float] = Field(default_factory=dict)
    contents: Dict[str, float] = Field(default_factory=dict)
102
+
103
+
104
class ExaSearchOutputSchema(OutputSchema):
    """Parsed result of an Exa search request.

    Pairs the returned results with the query that produced them, plus the
    optional autoprompt string and cost metadata.
    """

    results: List[ExaSearchResult] = Field(..., description="List of search results.")
    query: str = Field(..., description="The original search query.")
    autopromptString: Optional[str] = Field(
        None, description="Autoprompt string used for the search if enabled."
    )
    costDollars: Optional[ExaCostDetails] = Field(
        None, description="Cost details for the search request."
    )
@@ -0,0 +1,115 @@
1
+ """
2
+ Skills for Exa Search API.
3
+
4
+ This module provides skills for using the Exa search API.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import httpx
10
+ from typing import Optional, Dict, Any, List, cast
11
+
12
+ from pydantic import ValidationError
13
+
14
+ from airtrain.core.skills import Skill
15
+ from airtrain.core.errors import ProcessingError
16
+ from .credentials import ExaCredentials
17
+ from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
18
+
19
+
20
logger = logging.getLogger(__name__)


class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API.

    The input schema is POSTed as a JSON body to ``EXA_API_ENDPOINT``, and the
    response is wrapped in an :class:`ExaSearchOutputSchema`. Every failure is
    surfaced to callers as a :class:`ProcessingError`.
    """

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for API requests in seconds
            max_retries: Maximum number of retries for failed requests.
                NOTE: stored on the instance but not currently used;
                ``process()`` makes a single attempt.
            **kwargs: Forwarded to the ``Skill`` base class.
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: On a timeout, a non-200 response, a response that
                does not match the output schema, or any other unexpected
                error. The original exception is chained as ``__cause__``.
        """
        try:
            # Drop unset optional fields so Exa only receives explicit parameters.
            payload = input_data.model_dump(exclude_none=True)

            # The key lives on ``ExaCredentials.exa_api_key``. The former
            # ``credentials.api_key`` lookup raised AttributeError on every call.
            api_key = self.credentials.exa_api_key.get_secret_value()
            headers = {
                "content-type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.EXA_API_ENDPOINT,
                    headers=headers,
                    json=payload,
                    timeout=self.timeout,
                )

            if response.status_code != 200:
                error_message = (
                    f"Exa API returned status code {response.status_code}: "
                    f"{response.text}"
                )
                logger.error(error_message)
                raise ProcessingError(error_message)

            result_data = response.json()
            return ExaSearchOutputSchema(
                results=result_data.get("results", []),
                query=input_data.query,
                autopromptString=result_data.get("autopromptString"),
                costDollars=result_data.get("costDollars"),
            )

        except ProcessingError:
            # Already logged with a specific message. Re-raise it unchanged so
            # the catch-all below does not wrap it a second time.
            raise

        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except ValidationError as e:
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
@@ -0,0 +1,33 @@
1
+ """Together AI integration module"""
2
+
3
+ from .credentials import TogetherAICredentials
4
+ from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
5
+ from .models_config import (
6
+ get_model_config_with_capabilities,
7
+ get_max_completion_tokens,
8
+ supports_tool_use,
9
+ supports_json_mode,
10
+ TOGETHER_MODELS_CONFIG,
11
+ )
12
+ from .list_models import (
13
+ TogetherListModelsSkill,
14
+ TogetherListModelsInput,
15
+ TogetherListModelsOutput,
16
+ )
17
+ from .models import TogetherModel
18
+
19
# Public API of the Together AI integration.
__all__ = [
    # Credentials
    "TogetherAICredentials",
    # Chat completion skill and its schemas
    "TogetherAIChatSkill",
    "TogetherAIInput",
    "TogetherAIOutput",
    # Model listing
    "TogetherListModelsSkill",
    "TogetherListModelsInput",
    "TogetherListModelsOutput",
    "TogetherModel",
    # Model configuration helpers
    "get_model_config_with_capabilities",
    "get_max_completion_tokens",
    "supports_tool_use",
    "supports_json_mode",
    "TOGETHER_MODELS_CONFIG",
]
@@ -0,0 +1,34 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
4
class AudioModelConfig(NamedTuple):
    # Organisation that publishes the model (matched case-insensitively below).
    organization: str
    # Human-readable model name.
    display_name: str


# Catalogue of Together AI audio models, keyed by model ID.
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
    "Cartesia/Sonic": AudioModelConfig("Cartesia", "Cartesia/Sonic"),
}


def get_audio_model_config(model_id: str) -> AudioModelConfig:
    """Return the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_AUDIO_MODELS``.
    """
    config = TOGETHER_AUDIO_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI audio models")
    return config


def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
    """Return the audio models whose organisation matches, ignoring case.

    Catalogue order is preserved; an unknown organisation yields ``{}``.
    """
    wanted = organization.lower()
    matches: Dict[str, AudioModelConfig] = {}
    for model_id, config in TOGETHER_AUDIO_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_audio_model() -> str:
    """Return the ID of the audio model used when none is specified."""
    return "Cartesia/Sonic"
@@ -0,0 +1,22 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import together
4
+
5
+
6
class TogetherAICredentials(BaseCredentials):
    """Together AI credentials."""

    together_api_key: SecretStr = Field(..., description="Together AI API key")

    _required_credentials = {"together_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate Together AI credentials by listing the available models.

        Side effect: assigns the key to the module-level ``together.api_key``
        used by the SDK's ``together.Models`` interface.

        Returns:
            True if the model listing succeeds.

        Raises:
            CredentialValidationError: If the listing call fails for any
                reason. The original exception is chained as ``__cause__``.
        """
        import asyncio

        try:
            together.api_key = self.together_api_key.get_secret_value()
            # ``together.Models.list()`` is synchronous and returns a list.
            # Awaiting that list raised TypeError, so validation always failed.
            # Run the call in a worker thread so the event loop is not blocked.
            await asyncio.to_thread(together.Models.list)
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Together AI credentials: {str(e)}"
            ) from e
@@ -0,0 +1,92 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
4
class EmbeddingModelConfig(NamedTuple):
    # Organisation that publishes the model (matched case-insensitively below).
    organization: str
    # Human-readable model name.
    display_name: str
    # Parameter count as a label, e.g. "80M".
    model_size: str
    # Length of the returned embedding vectors.
    embedding_dimension: int
    # Maximum input length accepted by the model.
    context_window: int


# Catalogue of Together AI embedding models, keyed by model ID.
# Positional fields: organization, display_name, model_size,
# embedding_dimension, context_window.
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
    "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-2K-Retrieval", "80M", 768, 2048
    ),
    "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-8K-Retrieval", "80M", 768, 8192
    ),
    "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-32K-Retrieval", "80M", 768, 32768
    ),
    "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
        "WhereIsAI", "UAE-Large-v1", "326M", 1024, 512
    ),
    "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Large-EN-v1.5", "326M", 1024, 512
    ),
    "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Base-EN-v1.5", "102M", 768, 512
    ),
    "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
        "sentence-transformers", "Sentence-BERT", "110M", 768, 512
    ),
    "bert-base-uncased": EmbeddingModelConfig(
        "Hugging Face", "BERT", "110M", 768, 512
    ),
}


def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
    """Return the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_EMBEDDING_MODELS``.
    """
    config = TOGETHER_EMBEDDING_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI embedding models")
    return config


def list_embedding_models_by_organization(
    organization: str,
) -> Dict[str, EmbeddingModelConfig]:
    """Return the embedding models whose organisation matches, ignoring case.

    Catalogue order is preserved; an unknown organisation yields ``{}``.
    """
    wanted = organization.lower()
    matches: Dict[str, EmbeddingModelConfig] = {}
    for model_id, config in TOGETHER_EMBEDDING_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_embedding_model() -> str:
    """Return the ID of the embedding model used when none is specified."""
    return "togethercomputer/m2-bert-80M-32k-retrieval"
@@ -0,0 +1,69 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+
3
+
4
class ImageModelConfig(NamedTuple):
    # Organisation that publishes the model (matched case-insensitively below).
    organization: str
    # Human-readable model name.
    display_name: str
    # Default number of inference steps, or None when no default is recorded.
    default_steps: Optional[int]


# Catalogue of Together AI image models, keyed by model ID.
# Positional fields: organization, display_name, default_steps.
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
    "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (free)", None
    ),
    "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (Turbo)", 4
    ),
    "black-forest-labs/FLUX.1-dev": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Dev", 28
    ),
    "black-forest-labs/FLUX.1-canny": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Canny", 28
    ),
    "black-forest-labs/FLUX.1-depth": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Depth", 28
    ),
    "black-forest-labs/FLUX.1-redux": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Redux", 28
    ),
    "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux1.1 [pro]", None
    ),
    "black-forest-labs/FLUX.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [pro]", None
    ),
    "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
        "Stability AI", "Stable Diffusion XL 1.0", None
    ),
}


def get_image_model_config(model_id: str) -> ImageModelConfig:
    """Return the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not in ``TOGETHER_IMAGE_MODELS``.
    """
    config = TOGETHER_IMAGE_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI image models")
    return config


def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
    """Return the image models whose organisation matches, ignoring case.

    Catalogue order is preserved; an unknown organisation yields ``{}``.
    """
    wanted = organization.lower()
    matches: Dict[str, ImageModelConfig] = {}
    for model_id, config in TOGETHER_IMAGE_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_image_model() -> str:
    """Return the ID of the image model used when none is specified."""
    return "black-forest-labs/FLUX.1-schnell-Free"
+ return "black-forest-labs/FLUX.1-schnell-Free"