airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
|
|
1
|
+
from typing import Optional, List, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
|
4
|
+
from airtrain.core.skills import Skill, ProcessingError
|
5
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
6
|
+
from .credentials import AnthropicCredentials
|
7
|
+
from .models_config import ANTHROPIC_MODELS, AnthropicModelConfig
|
8
|
+
|
9
|
+
|
10
|
+
class AnthropicModel:
    """Plain-object view of one entry of the local Anthropic model table.

    Copies the naming and pricing fields out of an ``AnthropicModelConfig`` so
    they can be rendered as a JSON-friendly dictionary via :meth:`dict`.
    """

    def __init__(self, model_id: str, config: AnthropicModelConfig):
        """Capture ``model_id`` and the fields of ``config``.

        Args:
            model_id: Key of the model in the model table.
            config: Record supplying the display name, base model and prices.
        """
        self.id = model_id
        self.display_name = config.display_name
        self.base_model = config.base_model
        # Prices are copied verbatim; the two cache prices may be None.
        self.input_price = config.input_price
        self.cached_write_price = config.cached_write_price
        self.cached_read_price = config.cached_read_price
        self.output_price = config.output_price

    def dict(self, exclude_none=False):
        """Return the model as a dictionary with prices converted to float.

        Args:
            exclude_none: When true, a cache price that is ``None`` is omitted
                entirely instead of being emitted as ``None``.

        Returns:
            Dict with ``id``, ``display_name``, ``base_model``,
            ``input_price``, ``output_price`` and (unless excluded)
            ``cached_write_price`` / ``cached_read_price``.
        """
        payload = {
            "id": self.id,
            "display_name": self.display_name,
            "base_model": self.base_model,
            "input_price": float(self.input_price),
            "output_price": float(self.output_price),
        }

        optional_prices = (
            ("cached_write_price", self.cached_write_price),
            ("cached_read_price", self.cached_read_price),
        )
        for key, price in optional_prices:
            if price is not None:
                payload[key] = float(price)
            elif not exclude_none:
                payload[key] = None

        return payload
|
44
|
+
|
45
|
+
|
46
|
+
class AnthropicListModelsInput(InputSchema):
    """Schema for Anthropic list models input"""

    # Selects the model source: the remote API (True) or the bundled
    # ANTHROPIC_MODELS table (False, the default).
    api_models_only: bool = Field(
        False,
        description="If True, fetch models from the API only. If False, use local config.",
    )
|
55
|
+
|
56
|
+
|
57
|
+
class AnthropicListModelsOutput(OutputSchema):
    """Schema for Anthropic list models output"""

    # One dictionary per model (the skill fills this from AnthropicModel.dict()).
    models: List[Dict[str, Any]] = Field(
        description="List of Anthropic models",
        default_factory=list,
    )
|
64
|
+
|
65
|
+
|
66
|
+
class AnthropicListModelsSkill(
    Skill[AnthropicListModelsInput, AnthropicListModelsOutput]
):
    """Skill that lists Anthropic models from the local ``ANTHROPIC_MODELS`` table.

    Requesting API-only models (``api_models_only=True``) is rejected with a
    ``ProcessingError``.
    """

    input_schema = AnthropicListModelsInput
    output_schema = AnthropicListModelsOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Optional Anthropic credentials. They are only checked
                when ``api_models_only`` is requested; listing the local
                config needs none.
        """
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: AnthropicListModelsInput
    ) -> AnthropicListModelsOutput:
        """Return the known Anthropic models.

        Args:
            input_data: Selects the API (unsupported) or the local config.

        Returns:
            Output whose ``models`` holds one ``AnthropicModel.dict()`` per
            entry of ``ANTHROPIC_MODELS``.

        Raises:
            ProcessingError: If API listing is requested (with or without
                credentials) or if building the list fails.
        """
        try:
            if input_data.api_models_only:
                if not self.credentials:
                    raise ProcessingError(
                        "Anthropic credentials required for API models"
                    )
                # NOTE(review): Anthropic's API reference now lists a Models
                # endpoint (GET /v1/models); confirm and consider calling it
                # here instead of failing.
                raise ProcessingError(
                    "Anthropic API does not provide a models list endpoint. "
                    "Use api_models_only=False to list models from local config."
                )

            # Local model config - no credentials needed.
            models = [
                AnthropicModel(model_id, config).dict()
                for model_id, config in ANTHROPIC_MODELS.items()
            ]
            return AnthropicListModelsOutput(models=models)

        except ProcessingError:
            # Already a descriptive ProcessingError; re-wrapping it would
            # prepend a second "Failed to list..." prefix.
            raise
        except Exception as e:
            raise ProcessingError(
                f"Failed to list Anthropic models: {str(e)}"
            ) from e
|
@@ -0,0 +1,100 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
from decimal import Decimal
|
3
|
+
|
4
|
+
|
5
|
+
class AnthropicModelConfig(NamedTuple):
    """Immutable naming and pricing record for a single Anthropic model.

    Fields (in positional order):
        display_name: Human-readable model name.
        base_model: Base model identifier.
        input_price: Rate for regular input tokens.
        cached_write_price: Alternative input rate for cache writes, or None.
        cached_read_price: Alternative input rate for cache reads, or None.
        output_price: Rate for output tokens.

    All prices share one unit; ``calculate_cost`` defines how they are scaled.
    """

    # --- naming ---
    display_name: str
    base_model: str
    # --- pricing ---
    input_price: Decimal
    cached_write_price: Optional[Decimal]
    cached_read_price: Optional[Decimal]
    output_price: Decimal
|
12
|
+
|
13
|
+
|
14
|
+
# Local model table keyed by model ID (insertion order is the listing order).
# cached_read_price / cached_write_price are the alternative input rates that
# calculate_cost uses when use_cached=True.
ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
    "claude-3-7-sonnet": AnthropicModelConfig(
        display_name="Claude 3.7 Sonnet",
        base_model="claude-3-7-sonnet",
        input_price=Decimal("3.00"),
        output_price=Decimal("15.00"),
        cached_read_price=Decimal("0.30"),
        cached_write_price=Decimal("3.75"),
    ),
    "claude-3-5-haiku": AnthropicModelConfig(
        display_name="Claude 3.5 Haiku",
        base_model="claude-3-5-haiku",
        input_price=Decimal("0.80"),
        output_price=Decimal("4.00"),
        cached_read_price=Decimal("0.08"),
        cached_write_price=Decimal("1.00"),
    ),
    "claude-3-opus": AnthropicModelConfig(
        display_name="Claude 3 Opus",
        base_model="claude-3-opus",
        input_price=Decimal("15.00"),
        output_price=Decimal("75.00"),
        cached_read_price=Decimal("1.50"),
        cached_write_price=Decimal("18.75"),
    ),
    "claude-3-sonnet": AnthropicModelConfig(
        display_name="Claude 3 Sonnet",
        base_model="claude-3-sonnet",
        input_price=Decimal("3.00"),
        output_price=Decimal("15.00"),
        cached_read_price=Decimal("0.30"),
        cached_write_price=Decimal("3.75"),
    ),
    "claude-3-haiku": AnthropicModelConfig(
        display_name="Claude 3 Haiku",
        base_model="claude-3-haiku",
        input_price=Decimal("0.25"),
        output_price=Decimal("1.25"),
        cached_read_price=Decimal("0.025"),
        cached_write_price=Decimal("0.31"),
    ),
}
|
56
|
+
|
57
|
+
|
58
|
+
def get_model_config(model_id: str) -> AnthropicModelConfig:
    """Look up the configuration record for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``ANTHROPIC_MODELS``.
    """
    config = ANTHROPIC_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Anthropic models")
    return config
|
63
|
+
|
64
|
+
|
65
|
+
def get_default_model() -> str:
    """Return the default Anthropic model ID (a key of ``ANTHROPIC_MODELS``)."""
    default_model_id = "claude-3-sonnet"
    return default_model_id
|
68
|
+
|
69
|
+
|
70
|
+
def calculate_cost(
    model_id: str,
    input_tokens: int,
    output_tokens: int,
    use_cached: bool = False,
    cache_type: str = "read",
) -> Decimal:
    """Calculate the cost of a request.

    The prices in ``ANTHROPIC_MODELS`` are per *million* tokens; they match
    Anthropic's published $/MTok rates, so presumably USD. The token-weighted
    total is therefore divided by 1,000,000. The previous divisor of 1000
    overstated cost by 1000x.

    Args:
        model_id: The model ID to calculate costs for.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        use_cached: Whether to price input tokens at a cached rate.
        cache_type: Either "read" or "write" for the cached pricing type. Any
            other value, or a model lacking that cached price, falls back to
            the regular input price.

    Returns:
        Total cost as a ``Decimal``.

    Raises:
        ValueError: If ``model_id`` is unknown.
    """
    tokens_per_price_unit = Decimal(1_000_000)
    config = get_model_config(model_id)

    input_rate = config.input_price
    if use_cached:
        cached_rate = {
            "read": config.cached_read_price,
            "write": config.cached_write_price,
        }.get(cache_type)
        if cached_rate is not None:
            input_rate = cached_rate

    # str() keeps the original conversion semantics for non-int token counts.
    input_cost = input_rate * Decimal(str(input_tokens))
    output_cost = config.output_price * Decimal(str(output_tokens))

    return (input_cost + output_cost) / tokens_per_price_unit
|
@@ -0,0 +1,155 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
|
+
from pydantic import Field
|
3
|
+
from anthropic import Anthropic
|
4
|
+
import base64
|
5
|
+
from pathlib import Path
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import AnthropicCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class AnthropicInput(InputSchema):
    """Schema for Anthropic chat input"""

    # Current user turn; appended after any conversation history.
    user_input: str = Field(..., description="User's input text")
    # Sent as the top-level `system` parameter of the Messages API call.
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="claude-3-opus-20240229", description="Anthropic model to use"
    )
    # The Messages API rejects max_tokens above the model's output limit.
    # The old default (131072) exceeded every Claude 3 model's limit; 4096
    # is the output limit of the Claude 3 family, including the default model.
    max_tokens: Optional[int] = Field(
        default=4096, description="Maximum tokens in response"
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    images: List[Path] = Field(
        default_factory=list,
        description="List of image paths to include in the message",
    )
    stream: bool = Field(
        default=False, description="Whether to stream the response progressively"
    )
|
41
|
+
|
42
|
+
|
43
|
+
class AnthropicOutput(OutputSchema):
    """Schema for Anthropic chat output"""

    # Generated text.
    response: str = Field(..., description="Model's response text")
    # The model name that was requested for this call.
    used_model: str = Field(..., description="Model used for generation")
    # Token accounting reported by the API; empty when none was collected.
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
|
51
|
+
|
52
|
+
|
53
|
+
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Skill for chatting with Anthropic models through the Messages API."""

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Create the skill and its Anthropic client.

        Args:
            credentials: Anthropic credentials; loaded with
                ``AnthropicCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
        """
        Build messages list from input data including conversation history.

        The history is copied first. A final user message is then appended,
        holding one base64 image block per entry of ``input_data.images``
        followed by a text block with ``input_data.user_input``.

        Args:
            input_data: The input data containing conversation history,
                images, and user input

        Returns:
            List[Dict[str, Any]]: List of messages in the format required by Anthropic
        """
        import mimetypes

        messages: List[Dict[str, Any]] = list(input_data.conversation_history)

        content: List[Dict[str, Any]] = []
        for image_path in input_data.images:
            with open(image_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode("utf-8")
            # The API checks media_type against the image bytes, so derive it
            # from the file extension instead of always claiming JPEG.
            media_type = mimetypes.guess_type(str(image_path))[0] or "image/jpeg"
            content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64_image,
                    },
                }
            )
        content.append({"type": "text", "text": input_data.user_input})

        messages.append({"role": "user", "content": content})
        return messages

    def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
        """Process the input and yield the response text chunk by chunk.

        Raises:
            ProcessingError: If building the request or streaming fails.
        """
        try:
            messages = self._build_messages(input_data)

            # Same (non-beta) Messages namespace that process() uses for create().
            with self.client.messages.stream(
                model=input_data.model,
                system=input_data.system_prompt,
                messages=messages,
                max_tokens=input_data.max_tokens,
                temperature=input_data.temperature,
            ) as stream:
                yield from stream.text_stream

        except Exception as e:
            logger.exception(f"Anthropic streaming failed: {str(e)}")
            raise ProcessingError(f"Anthropic streaming failed: {str(e)}")

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed chunks are joined and
        ``usage`` is left empty. Otherwise a single ``messages.create`` call is
        made and all of its text blocks are concatenated.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                # Previously this path called `.content` on the joined str,
                # which always raised; keep the text in its own variable.
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not collected when streaming
            else:
                messages = self._build_messages(input_data)
                response = self.client.messages.create(
                    model=input_data.model,
                    system=input_data.system_prompt,
                    messages=messages,
                    max_tokens=input_data.max_tokens,
                    temperature=input_data.temperature,
                )
                # Non-text blocks (e.g. tool use) carry no `.text`; join only text blocks.
                response_text = "".join(
                    block.text
                    for block in response.content
                    if getattr(block, "type", None) == "text"
                )
                usage = response.usage.model_dump() if response.usage else {}

            return AnthropicOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Raised (and logged) by process_stream; avoid double wrapping.
            raise
        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}")
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
from pydantic import Field, SecretStr
|
3
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
4
|
+
import boto3
|
5
|
+
|
6
|
+
|
7
|
+
class AWSCredentials(BaseCredentials):
    """AWS credentials"""

    aws_access_key_id: SecretStr = Field(..., description="AWS Access Key ID")
    aws_secret_access_key: SecretStr = Field(..., description="AWS Secret Access Key")
    aws_region: str = Field(default="us-east-1", description="AWS Region")
    # Only needed for temporary (STS-issued) credentials.
    aws_session_token: Optional[SecretStr] = Field(
        None, description="AWS Session Token"
    )

    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}

    async def validate_credentials(self) -> bool:
        """Check the credentials with an STS ``get_caller_identity`` call.

        Note: the boto3 call is synchronous and blocks the event loop while it runs.

        Returns:
            True when the call succeeds.

        Raises:
            CredentialValidationError: If creating the session or the STS
                call fails for any reason.
        """
        try:
            token = self.aws_session_token
            session = boto3.Session(
                aws_access_key_id=self.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
                aws_session_token=token.get_secret_value() if token else None,
                region_name=self.aws_region,
            )
            session.client("sts").get_caller_identity()
        except Exception as e:
            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
        return True
|
@@ -0,0 +1,98 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
import boto3
|
4
|
+
from pathlib import Path
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import AWSCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class AWSBedrockInput(InputSchema):
    """Schema for AWS Bedrock chat input"""

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="anthropic.claude-3-sonnet-20240229-v1:0",
        description="AWS Bedrock model to use",
    )
    # Bedrock rejects max_tokens above the model's output limit. The old
    # default (131072) exceeded it; 4096 is the output limit of the Claude 3
    # family, including the default model above.
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    images: Optional[List[Path]] = Field(
        default=None,
        description="Optional list of image paths to include in the message",
    )
|
32
|
+
|
33
|
+
|
34
|
+
class AWSBedrockOutput(OutputSchema):
    """Schema for AWS Bedrock chat output"""

    # Generated text.
    response: str = Field(..., description="Model's response text")
    # The Bedrock model ID that was requested.
    used_model: str = Field(..., description="Model used for generation")
    # Token counts reported by the model, when available.
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
|
42
|
+
|
43
|
+
|
44
|
+
class AWSBedrockSkill(Skill[AWSBedrockInput, AWSBedrockOutput]):
    """Skill for interacting with AWS Bedrock models.

    Only model IDs containing ``"anthropic"`` are supported; they are invoked
    with Bedrock's Anthropic Messages API request/response format.
    """

    input_schema = AWSBedrockInput
    output_schema = AWSBedrockOutput

    def __init__(self, credentials: Optional[AWSCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: AWS credentials; loaded with
                ``AWSCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or AWSCredentials.from_env()
        session_token = (
            self.credentials.aws_session_token.get_secret_value()
            if self.credentials.aws_session_token
            else None
        )
        self.client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=self.credentials.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=self.credentials.aws_secret_access_key.get_secret_value(),
            # Temporary (STS) credentials are rejected without their session token.
            aws_session_token=session_token,
            region_name=self.credentials.aws_region,
        )

    def _build_user_content(self, input_data: AWSBedrockInput) -> List[Dict[str, Any]]:
        """Build the user message content: one base64 block per image, then the text."""
        import base64
        import mimetypes

        content: List[Dict[str, Any]] = []
        for image_path in input_data.images or []:
            with open(image_path, "rb") as img_file:
                encoded = base64.b64encode(img_file.read()).decode("utf-8")
            media_type = mimetypes.guess_type(str(image_path))[0] or "image/jpeg"
            content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": encoded,
                    },
                }
            )
        content.append({"type": "text", "text": input_data.user_input})
        return content

    def process(self, input_data: AWSBedrockInput) -> AWSBedrockOutput:
        """Process the input using the AWS Bedrock ``invoke_model`` API.

        Raises:
            ProcessingError: For unsupported models or any request/parse failure.
        """
        import json

        try:
            logger.info(f"Processing request with model {input_data.model}")

            if "anthropic" not in input_data.model:
                raise ProcessingError(f"Unsupported model: {input_data.model}")

            request_body = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": input_data.max_tokens,
                "temperature": input_data.temperature,
                "system": input_data.system_prompt,
                "messages": [
                    {"role": "user", "content": self._build_user_content(input_data)}
                ],
            }

            # invoke_model expects a serialized (bytes/str) body, not a dict.
            response = self.client.invoke_model(
                modelId=input_data.model,
                body=json.dumps(request_body),
                contentType="application/json",
                accept="application/json",
            )

            # response["body"] is a StreamingBody holding the JSON payload.
            payload = json.loads(response["body"].read())

            # Messages API responses carry a list of content blocks, not "completion".
            response_text = "".join(
                block.get("text", "")
                for block in payload.get("content", [])
                if block.get("type") == "text"
            )
            raw_usage = payload.get("usage") or {}
            usage = {
                "input_tokens": raw_usage.get("input_tokens"),
                "output_tokens": raw_usage.get("output_tokens"),
            }

            return AWSBedrockOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            logger.exception(f"AWS Bedrock processing failed: {str(e)}")
            raise ProcessingError(f"AWS Bedrock processing failed: {str(e)}")
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
|
4
|
+
|
5
|
+
class CerebrasCredentials(BaseCredentials):
    """Cerebras credentials"""

    cerebras_api_key: SecretStr = Field(..., description="Cerebras API key")

    _required_credentials = {"cerebras_api_key"}

    async def validate_credentials(self) -> bool:
        """Validate the Cerebras credentials.

        Placeholder: no remote check is performed and the method always
        returns True.

        Raises:
            CredentialValidationError: Reserved for a future real check.
        """
        try:
            # TODO: call a cheap Cerebras endpoint to verify the key; this
            # depends on the Cerebras API client.
            is_valid = True
        except Exception as e:
            raise CredentialValidationError(f"Invalid Cerebras credentials: {str(e)}")
        return is_valid
|
@@ -0,0 +1,127 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
|
+
from pydantic import Field
|
3
|
+
from cerebras.cloud.sdk import Cerebras
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import CerebrasCredentials
|
9
|
+
|
10
|
+
|
11
|
+
class CerebrasInput(InputSchema):
    """Schema for Cerebras chat input"""

    # Current user turn; appended after the system prompt and history.
    user_input: str = Field(..., description="User's input text")
    # Sent as the leading {"role": "system"} message.
    system_prompt: str = Field(
        "You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )
    model: str = Field("llama3.1-8b", description="Cerebras model to use")
    # NOTE(review): 131072 may exceed the output/context limit of smaller
    # Cerebras models such as the default; confirm against the Cerebras docs.
    max_tokens: Optional[int] = Field(
        131072, description="Maximum tokens in response"
    )
    temperature: float = Field(
        0.7, description="Temperature for response generation", ge=0, le=1
    )
    # When True, the chat skill assembles the reply from streamed chunks.
    stream: bool = Field(
        False, description="Whether to stream the response progressively"
    )
|
33
|
+
|
34
|
+
|
35
|
+
class CerebrasOutput(OutputSchema):
    """Schema for Cerebras chat output"""

    # Generated text.
    response: str = Field(..., description="Model's response text")
    # The model name that was requested.
    used_model: str = Field(..., description="Model used for generation")
    # Token accounting; empty when none was collected.
    usage: Dict[str, Any] = Field(
        description="Usage statistics", default_factory=dict
    )
|
41
|
+
|
42
|
+
|
43
|
+
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Skill for chatting with Cerebras-hosted models via chat completions."""

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the skill and its Cerebras client.

        Args:
            credentials: Cerebras credentials; loaded with
                ``CerebrasCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or CerebrasCredentials.from_env()
        self.client = Cerebras(
            api_key=self.credentials.cerebras_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Order: system prompt, then the conversation history, then the
        current user input.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: List of messages in the format required by Cerebras
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]
        messages.extend(input_data.conversation_history)
        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def process_stream(self, input_data: CerebrasInput) -> Generator[str, None, None]:
        """Process the input and yield the response text chunk by chunk.

        Raises:
            ProcessingError: If the request or streaming fails.
        """
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # Guard against chunks without choices (presumably e.g. a
                # trailing usage-only chunk) instead of raising IndexError.
                if not chunk.choices:
                    continue
                delta_text = chunk.choices[0].delta.content
                if delta_text is not None:
                    yield delta_text

        except Exception as e:
            logger.exception(f"Cerebras streaming failed: {str(e)}")
            raise ProcessingError(f"Cerebras streaming failed: {str(e)}")

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed chunks are joined and
        ``usage`` is left empty.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                # Previously this path called `.choices` on the joined str,
                # which always raised; keep the text in its own variable.
                response_text = "".join(self.process_stream(input_data))
                usage = {}  # Usage stats not available in streaming
            else:
                messages = self._build_messages(input_data)
                response = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=messages,
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                )
                # `usage` may be missing or None; only dump it when present.
                raw_usage = getattr(response, "usage", None)
                usage = raw_usage.model_dump() if raw_usage is not None else {}
                response_text = response.choices[0].message.content or ""

            return CerebrasOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Raised (and logged) by process_stream; avoid double wrapping.
            raise
        except Exception as e:
            logger.exception(f"Cerebras processing failed: {str(e)}")
            raise ProcessingError(f"Cerebras processing failed: {str(e)}")
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""Combined integration modules for Airtrain"""
|
2
|
+
|
3
|
+
from .groq_fireworks_skills import (
|
4
|
+
GroqFireworksSkill,
|
5
|
+
GroqFireworksInput,
|
6
|
+
GroqFireworksOutput
|
7
|
+
)
|
8
|
+
from .list_models_factory import (
|
9
|
+
ListModelsSkillFactory,
|
10
|
+
GenericListModelsInput,
|
11
|
+
GenericListModelsOutput
|
12
|
+
)
|
13
|
+
|
14
|
+
__all__ = [
    # Combined Groq/Fireworks skill and its schemas
    "GroqFireworksSkill",
    "GroqFireworksInput",
    "GroqFireworksOutput",
    # List-models skill factory and its generic schemas
    "ListModelsSkillFactory",
    "GenericListModelsInput",
    "GenericListModelsOutput",
]
|