airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +146 -6
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +62 -44
- airtrain/core/skills.py +102 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.3.dist-info/METADATA +0 -106
- airtrain-0.1.3.dist-info/RECORD +0 -9
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
|
|
1
|
+
"""
|
2
|
+
Search integrations for AirTrain.
|
3
|
+
|
4
|
+
This package provides integrations with various search providers.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# Import specific search integrations as needed
|
8
|
+
from .exa import (
|
9
|
+
ExaCredentials,
|
10
|
+
ExaSearchInputSchema,
|
11
|
+
ExaSearchOutputSchema,
|
12
|
+
ExaSearchSkill,
|
13
|
+
)
|
14
|
+
|
15
|
+
# Public names of the search integrations package; currently only the Exa
# provider is re-exported here.
__all__ = [
    "ExaCredentials",
    "ExaSearchInputSchema",
    "ExaSearchOutputSchema",
    "ExaSearchSkill",
]
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""
|
2
|
+
Exa Search API integration.
|
3
|
+
|
4
|
+
This module provides integration with the Exa search API for web searching capabilities.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from .credentials import ExaCredentials
|
8
|
+
from .schemas import (
|
9
|
+
ExaSearchInputSchema,
|
10
|
+
ExaSearchOutputSchema,
|
11
|
+
ExaContentConfig,
|
12
|
+
ExaSearchResult,
|
13
|
+
)
|
14
|
+
from .skills import ExaSearchSkill
|
15
|
+
|
16
|
+
# Public names of the Exa integration: credentials, request/response schemas,
# and the search skill.
__all__ = [
    "ExaCredentials",
    "ExaSearchInputSchema",
    "ExaSearchOutputSchema",
    "ExaContentConfig",
    "ExaSearchResult",
    "ExaSearchSkill",
]
|
@@ -0,0 +1,30 @@
|
|
1
|
+
"""
|
2
|
+
Credentials for Exa Search API.
|
3
|
+
|
4
|
+
This module provides credential management for the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import Optional
|
8
|
+
from pydantic import Field, SecretStr
|
9
|
+
|
10
|
+
from airtrain.core.credentials import BaseCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class ExaCredentials(BaseCredentials):
    """Credentials for accessing the Exa search API.

    Validation only confirms the key is present; no request is sent to Exa to
    check that the key is actually accepted.
    """

    # Stored as a SecretStr so the raw key is masked in reprs; the field has
    # no default, so it must be supplied.
    exa_api_key: SecretStr = Field(..., description="Exa search API key")

    # Field names presumably checked for presence by BaseCredentials.
    _required_credentials = {"exa_api_key"}

    async def validate_credentials(self) -> bool:
        """Check that the required credential fields are set.

        The presence check is delegated to the base class; this override adds
        no remote verification.

        Returns:
            True once the base-class check has completed without raising.
        """
        await super().validate_credentials()
        # A live test call against the Exa API could be added here to verify
        # the key itself; for now presence is treated as sufficient.
        return True
|
@@ -0,0 +1,114 @@
|
|
1
|
+
"""
|
2
|
+
Schemas for Exa Search API.
|
3
|
+
|
4
|
+
This module defines the input and output schemas for the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# NOTE: `bool` is a builtin, not a `typing` export; importing it from typing
# raised ImportError and made this module unimportable.
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, HttpUrl

from airtrain.core.schemas import InputSchema, OutputSchema
|
11
|
+
|
12
|
+
|
13
|
+
class ExaContentConfig(BaseModel):
    """Selects which content fields Exa should attach to each search result.

    Only ``text`` has a non-None default (True). Every other flag defaults to
    None, so it is left out of a payload dumped with ``exclude_none=True``.
    """

    text: bool = Field(True, description="Whether to return text content.")
    extractedText: Optional[bool] = Field(
        None, description="Whether to return extracted text content."
    )
    embedded: Optional[bool] = Field(
        None, description="Whether to return embedded content."
    )
    links: Optional[bool] = Field(
        None, description="Whether to return links from the content."
    )
    screenshot: Optional[bool] = Field(
        None, description="Whether to return screenshots of the content."
    )
    highlighted: Optional[bool] = Field(
        None, description="Whether to return highlighted text."
    )
|
32
|
+
|
33
|
+
|
34
|
+
class ExaSearchInputSchema(InputSchema):
    """Request body for the Exa ``/search`` endpoint.

    Field names use Exa's camelCase spelling so the model can be dumped
    directly into the JSON payload. Only ``query`` is required.
    """

    query: str = Field(..., description="The search query to execute.")
    numResults: Optional[int] = Field(
        None, description="Number of results to return."
    )
    # A fresh ExaContentConfig (text=True) is created per instance.
    contents: Optional[ExaContentConfig] = Field(
        default_factory=ExaContentConfig,
        description="Configuration for the content to be returned.",
    )
    highlights: Optional[dict] = Field(
        None, description="Highlighting configuration for search results."
    )
    useAutoprompt: Optional[bool] = Field(
        None, description="Whether to use autoprompt for the search."
    )
    # Named `type` to match the API field, shadowing the builtin on purpose.
    type: Optional[str] = Field(None, description="Type of search to perform.")
    includeDomains: Optional[List[str]] = Field(
        None, description="List of domains to include in the search."
    )
    excludeDomains: Optional[List[str]] = Field(
        None, description="List of domains to exclude from the search."
    )
|
58
|
+
|
59
|
+
|
60
|
+
class ExaModerationConfig(BaseModel):
    """Moderation flags that may accompany a search result (all optional)."""

    # LlamaGuard category flags, as named in the API response.
    llamaguardS1: Optional[bool] = None
    llamaguardS3: Optional[bool] = None
    llamaguardS4: Optional[bool] = None
    llamaguardS12: Optional[bool] = None

    # Domain-level blacklist flags.
    domainBlacklisted: Optional[bool] = None
    domainBlacklistedMedia: Optional[bool] = None
|
69
|
+
|
70
|
+
|
71
|
+
class ExaHighlight(BaseModel):
    """A highlighted snippet of a search result together with its score."""

    text: str  # the highlighted passage
    score: float  # relevance score reported for the passage
|
76
|
+
|
77
|
+
|
78
|
+
class ExaSearchResult(BaseModel):
    """A single result returned by the Exa ``/search`` endpoint.

    Only ``id`` and ``url`` are required. Which other fields Exa populates
    depends on the search type and on the requested ``contents`` options, so
    they are all optional. A missing field must not make the whole response
    fail validation.
    """

    id: str
    url: str
    title: Optional[str] = None
    text: Optional[str] = None
    extractedText: Optional[str] = None
    embedded: Optional[Dict[str, Any]] = None
    # Relevance score. Optional because it is not guaranteed on every result
    # (NOTE(review): confirm per search type against the Exa API reference).
    score: Optional[float] = None
    published: Optional[str] = None
    # Exa's documented name for the publication date.
    publishedDate: Optional[str] = None
    author: Optional[str] = None
    # Exa returns highlights as plain strings, with scores in `highlightScores`.
    # ExaHighlight objects are still accepted for backward compatibility.
    highlights: Optional[List[Union[ExaHighlight, str]]] = None
    highlightScores: Optional[List[float]] = None
    robotsAllowed: Optional[bool] = None
    moderationConfig: Optional[ExaModerationConfig] = None
    urls: Optional[List[str]] = None
|
94
|
+
|
95
|
+
|
96
|
+
class ExaCostDetails(BaseModel):
    """Cost information (``costDollars``) reported by Exa for a search request.

    Only ``total`` is required. The per-category breakdowns are optional so
    that a ``costDollars`` object without them does not make the whole search
    response fail validation. Unknown keys are ignored (pydantic's default).
    """

    total: float
    # Per-category cost breakdowns; absent in some response shapes.
    search: Optional[Dict[str, float]] = None
    contents: Optional[Dict[str, float]] = None
|
102
|
+
|
103
|
+
|
104
|
+
class ExaSearchOutputSchema(OutputSchema):
    """Parsed Exa search response, together with the query that produced it."""

    results: List[ExaSearchResult] = Field(
        ..., description="List of search results."
    )
    query: str = Field(..., description="The original search query.")
    autopromptString: Optional[str] = Field(
        None, description="Autoprompt string used for the search if enabled."
    )
    costDollars: Optional[ExaCostDetails] = Field(
        None, description="Cost details for the search request."
    )
|
@@ -0,0 +1,115 @@
|
|
1
|
+
"""
|
2
|
+
Skills for Exa Search API.
|
3
|
+
|
4
|
+
This module provides skills for using the Exa search API.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import json
|
8
|
+
import logging
|
9
|
+
import httpx
|
10
|
+
from typing import Optional, Dict, Any, List, cast
|
11
|
+
|
12
|
+
from pydantic import ValidationError
|
13
|
+
|
14
|
+
from airtrain.core.skills import Skill
|
15
|
+
from airtrain.core.errors import ProcessingError
|
16
|
+
from .credentials import ExaCredentials
|
17
|
+
from .schemas import ExaSearchInputSchema, ExaSearchOutputSchema, ExaSearchResult
|
18
|
+
|
19
|
+
|
20
|
+
# Module-level logger; ExaSearchSkill reports request failures through it.
logger: logging.Logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class ExaSearchSkill(Skill[ExaSearchInputSchema, ExaSearchOutputSchema]):
    """Skill for searching the web using the Exa search API.

    The input schema is sent as a JSON body to ``EXA_API_ENDPOINT``, and the
    JSON response is parsed into an ``ExaSearchOutputSchema``. Every failure
    is surfaced to the caller as ``ProcessingError``.
    """

    input_schema = ExaSearchInputSchema
    output_schema = ExaSearchOutputSchema

    EXA_API_ENDPOINT = "https://api.exa.ai/search"

    def __init__(
        self,
        credentials: ExaCredentials,
        timeout: float = 60.0,
        max_retries: int = 3,
        **kwargs,
    ):
        """
        Initialize the Exa search skill.

        Args:
            credentials: Credentials for accessing the Exa API
            timeout: Timeout for each API request attempt, in seconds
            max_retries: Maximum number of extra attempts made when a
                connection to the Exa API cannot be established (connect
                error / connect timeout). Negative values are treated as 0.
            **kwargs: Forwarded to the ``Skill`` base class.
        """
        super().__init__(**kwargs)
        self.credentials = credentials
        self.timeout = timeout
        self.max_retries = max_retries

    async def process(self, input_data: ExaSearchInputSchema) -> ExaSearchOutputSchema:
        """
        Process a search request using the Exa API.

        Args:
            input_data: Search input parameters

        Returns:
            Search results from Exa

        Raises:
            ProcessingError: If the request cannot be sent, times out, returns
                a non-200 status, or the response does not match the output
                schema.
        """
        import asyncio  # only needed for the retry back-off

        try:
            # Unset optional fields are omitted from the request payload.
            payload = input_data.model_dump(exclude_none=True)

            headers = {
                "content-type": "application/json",
                # ExaCredentials declares the key as `exa_api_key`.
                "Authorization": f"Bearer {self.credentials.exa_api_key.get_secret_value()}",
            }

            attempts = max(0, self.max_retries) + 1
            async with httpx.AsyncClient() as client:
                for attempt in range(attempts):
                    try:
                        response = await client.post(
                            self.EXA_API_ENDPOINT,
                            headers=headers,
                            json=payload,
                            timeout=self.timeout,
                        )
                        break
                    except (httpx.ConnectError, httpx.ConnectTimeout) as e:
                        # The request never reached the server, so retrying
                        # cannot duplicate work. Re-raise after the last try.
                        if attempt == attempts - 1:
                            raise
                        logger.warning(
                            "Could not connect to Exa API (%s); retry %d of %d",
                            e,
                            attempt + 1,
                            attempts - 1,
                        )
                        await asyncio.sleep(0.5 * 2**attempt)

            if response.status_code != 200:
                error_message = f"Exa API returned status code {response.status_code}: {response.text}"
                logger.error(error_message)
                raise ProcessingError(error_message)

            result_data = response.json()

            # Nested dicts are validated into ExaSearchResult / ExaCostDetails.
            return ExaSearchOutputSchema(
                results=result_data.get("results", []),
                query=input_data.query,
                autopromptString=result_data.get("autopromptString"),
                costDollars=result_data.get("costDollars"),
            )

        except ProcessingError:
            # Already logged with a specific message. Without this clause the
            # generic handler below would re-wrap it as an "Unexpected error".
            raise

        except httpx.TimeoutException as e:
            error_message = f"Timeout while querying Exa API (timeout={self.timeout}s)"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except ValidationError as e:
            error_message = f"Schema validation error: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e

        except Exception as e:
            error_message = f"Unexpected error while querying Exa API: {str(e)}"
            logger.error(error_message)
            raise ProcessingError(error_message) from e
|
@@ -0,0 +1,33 @@
|
|
1
|
+
"""Together AI integration module"""
|
2
|
+
|
3
|
+
from .credentials import TogetherAICredentials
|
4
|
+
from .skills import TogetherAIChatSkill, TogetherAIInput, TogetherAIOutput
|
5
|
+
from .models_config import (
|
6
|
+
get_model_config_with_capabilities,
|
7
|
+
get_max_completion_tokens,
|
8
|
+
supports_tool_use,
|
9
|
+
supports_json_mode,
|
10
|
+
TOGETHER_MODELS_CONFIG,
|
11
|
+
)
|
12
|
+
from .list_models import (
|
13
|
+
TogetherListModelsSkill,
|
14
|
+
TogetherListModelsInput,
|
15
|
+
TogetherListModelsOutput,
|
16
|
+
)
|
17
|
+
from .models import TogetherModel
|
18
|
+
|
19
|
+
# Public names of the Together AI integration, grouped by purpose.
__all__ = [
    # Credentials
    "TogetherAICredentials",
    # Chat skill and its input/output schemas
    "TogetherAIChatSkill",
    "TogetherAIInput",
    "TogetherAIOutput",
    # Model listing skill, its schemas, and the model record type
    "TogetherListModelsSkill",
    "TogetherListModelsInput",
    "TogetherListModelsOutput",
    "TogetherModel",
    # Model capability helpers and configuration table
    "get_model_config_with_capabilities",
    "get_max_completion_tokens",
    "supports_tool_use",
    "supports_json_mode",
    "TOGETHER_MODELS_CONFIG",
]
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Dict, NamedTuple


class AudioModelConfig(NamedTuple):
    """Static metadata describing one Together AI audio model."""

    organization: str  # publisher; matched case-insensitively when filtering
    display_name: str  # human-readable label


# Registry of known audio models, keyed by Together AI model ID.
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
    "Cartesia/Sonic": AudioModelConfig("Cartesia", "Cartesia/Sonic"),
}


def get_audio_model_config(model_id: str) -> AudioModelConfig:
    """Look up the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_AUDIO_MODELS``.
    """
    config = TOGETHER_AUDIO_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI audio models")
    return config


def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
    """Return the audio models whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, AudioModelConfig] = {}
    for model_id, config in TOGETHER_AUDIO_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_audio_model() -> str:
    """Return the default audio model ID."""
    return "Cartesia/Sonic"
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import together
|
4
|
+
|
5
|
+
|
6
|
+
class TogetherAICredentials(BaseCredentials):
    """Together AI credentials"""

    together_api_key: SecretStr = Field(..., description="Together AI API key")

    _required_credentials = {"together_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate Together AI credentials by listing the available models.

        Side effect: sets the module-level ``together.api_key`` to this key
        before calling ``together.Models.list()``.

        Returns:
            True if the model listing succeeds.

        Raises:
            CredentialValidationError: If the listing call fails for any reason.
        """
        import asyncio  # local: only needed to off-load the blocking SDK call

        try:
            together.api_key = self.together_api_key.get_secret_value()
            # `together.Models.list()` returns a plain list, not an awaitable
            # (NOTE(review): confirm for the pinned `together` SDK version).
            # Awaiting its result raised TypeError, so valid keys were always
            # rejected. Run it in a worker thread so the event loop stays free.
            await asyncio.to_thread(together.Models.list)
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Together AI credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from typing import Dict, NamedTuple


class EmbeddingModelConfig(NamedTuple):
    """Static metadata describing one Together AI embedding model."""

    organization: str  # publisher; matched case-insensitively when filtering
    display_name: str  # human-readable label
    model_size: str  # parameter count as a label, e.g. "80M"
    embedding_dimension: int  # length of the produced embedding vector
    context_window: int  # maximum input length recorded for the model


# Registry of known embedding models, keyed by Together AI model ID.
# Entries are positional: (organization, display_name, model_size,
# embedding_dimension, context_window).
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
    "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-2K-Retrieval", "80M", 768, 2048
    ),
    "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-8K-Retrieval", "80M", 768, 8192
    ),
    "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-32K-Retrieval", "80M", 768, 32768
    ),
    "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
        "WhereIsAI", "UAE-Large-v1", "326M", 1024, 512
    ),
    "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Large-EN-v1.5", "326M", 1024, 512
    ),
    "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Base-EN-v1.5", "102M", 768, 512
    ),
    "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
        "sentence-transformers", "Sentence-BERT", "110M", 768, 512
    ),
    "bert-base-uncased": EmbeddingModelConfig(
        "Hugging Face", "BERT", "110M", 768, 512
    ),
}


def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
    """Look up the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_EMBEDDING_MODELS``.
    """
    config = TOGETHER_EMBEDDING_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI embedding models")
    return config


def list_embedding_models_by_organization(
    organization: str,
) -> Dict[str, EmbeddingModelConfig]:
    """Return the embedding models whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, EmbeddingModelConfig] = {}
    for model_id, config in TOGETHER_EMBEDDING_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_embedding_model() -> str:
    """Return the default embedding model ID."""
    return "togethercomputer/m2-bert-80M-32k-retrieval"
|
@@ -0,0 +1,69 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional


class ImageModelConfig(NamedTuple):
    """Static metadata describing one Together AI image-generation model."""

    organization: str  # publisher; matched case-insensitively when filtering
    display_name: str  # human-readable label
    # Default number of generation steps, or None when none is recorded.
    default_steps: Optional[int]


# Registry of known image models, keyed by Together AI model ID.
# Entries are positional: (organization, display_name, default_steps).
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
    "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (free)", None
    ),
    "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (Turbo)", 4
    ),
    "black-forest-labs/FLUX.1-dev": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Dev", 28
    ),
    "black-forest-labs/FLUX.1-canny": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Canny", 28
    ),
    "black-forest-labs/FLUX.1-depth": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Depth", 28
    ),
    "black-forest-labs/FLUX.1-redux": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Redux", 28
    ),
    "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux1.1 [pro]", None
    ),
    "black-forest-labs/FLUX.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [pro]", None
    ),
    "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
        "Stability AI", "Stable Diffusion XL 1.0", None
    ),
}


def get_image_model_config(model_id: str) -> ImageModelConfig:
    """Look up the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``TOGETHER_IMAGE_MODELS``.
    """
    config = TOGETHER_IMAGE_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI image models")
    return config


def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
    """Return the image models whose organization matches, ignoring case."""
    wanted = organization.lower()
    matches: Dict[str, ImageModelConfig] = {}
    for model_id, config in TOGETHER_IMAGE_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches


def get_default_image_model() -> str:
    """Return the default image model ID."""
    return "black-forest-labs/FLUX.1-schnell-Free"
|