airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +146 -6
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +62 -44
- airtrain/core/skills.py +102 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.3.dist-info/METADATA +0 -106
- airtrain-0.1.3.dist-info/RECORD +0 -9
- {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/core/skills.py
CHANGED
@@ -1,8 +1,17 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from typing import Any, Dict, Optional, Type, Generic, TypeVar
|
3
3
|
from uuid import UUID, uuid4
|
4
|
+
import time
|
5
|
+
import functools
|
4
6
|
from .schemas import InputSchema, OutputSchema
|
5
7
|
|
8
|
+
# Import telemetry
|
9
|
+
from airtrain.telemetry import (
|
10
|
+
telemetry,
|
11
|
+
SkillInitTelemetryEvent,
|
12
|
+
SkillProcessTelemetryEvent,
|
13
|
+
)
|
14
|
+
|
6
15
|
# Generic type variables for input and output schemas
|
7
16
|
InputT = TypeVar("InputT", bound=InputSchema)
|
8
17
|
OutputT = TypeVar("OutputT", bound=OutputSchema)
|
@@ -17,6 +26,92 @@ class Skill(ABC, Generic[InputT, OutputT]):
|
|
17
26
|
input_schema: Type[InputT]
|
18
27
|
output_schema: Type[OutputT]
|
19
28
|
_skill_id: Optional[UUID] = None
|
29
|
+
_original_process = None
|
30
|
+
|
31
|
+
def __init__(self):
    """Initialize the skill, install process-telemetry, and report creation.

    The first time an instance of a given class is created, that class's
    ``process`` method is replaced (at class level) by a wrapper that records
    a ``SkillProcessTelemetryEvent`` for every call: duration, error text and
    a best-effort serialisation of the input. A ``SkillInitTelemetryEvent``
    is then captured for this instance. Telemetry failures never propagate
    to the caller.
    """
    # Initialize skill_id if not already set
    if not self._skill_id:
        self._skill_id = uuid4()

    cls = self.__class__

    # Decide whether to wrap by inspecting the method that ``cls.process``
    # actually resolves to, not a class flag. ``hasattr(cls, "_patched_process")``
    # also sees the flag inherited from an already-patched parent, which left
    # subclasses that override ``process`` without telemetry. Checking the
    # wrapper marker also avoids double-wrapping an inherited, already-wrapped
    # ``process``.
    if not getattr(cls.process, "_airtrain_telemetry_wrapped", False):
        # Keep the unwrapped implementation reachable, as before
        cls._original_process = cls.process

        def _serialize_input(input_data):
            """Best-effort conversion of ``input_data`` into a dict for telemetry."""
            try:
                if hasattr(input_data, "model_dump"):  # Pydantic v2
                    return input_data.model_dump()
                if hasattr(input_data, "dict"):  # Pydantic v1 / dict()-style objects
                    return input_data.dict()
                if hasattr(input_data, "__dataclass_fields__"):
                    from dataclasses import asdict

                    return asdict(input_data)
                return {"__str__": str(input_data)}
            except Exception:
                # If serialization fails, provide simple info
                return {"error": "Failed to serialize input data"}

        def _create_wrapper(original_method):
            """Return ``original_method`` wrapped with process telemetry."""

            @functools.wraps(original_method)
            def wrapped_process(instance, input_data, *args, **kwargs):
                # perf_counter is monotonic, so durations are immune to clock changes
                start_time = time.perf_counter()
                error = None
                try:
                    return original_method(instance, input_data, *args, **kwargs)
                except Exception as e:
                    error = str(e)
                    raise
                finally:
                    duration = time.perf_counter() - start_time
                    try:
                        # NOTE(review): the full input payload (possibly user
                        # prompts) is sent to telemetry — confirm this is intended.
                        telemetry.capture(
                            SkillProcessTelemetryEvent(
                                skill_id=str(instance.skill_id),
                                skill_class=instance.__class__.__name__,
                                input_schema=instance.input_schema.__name__,
                                output_schema=instance.output_schema.__name__,
                                input_data=_serialize_input(input_data),
                                duration_seconds=duration,
                                error=error,
                            )
                        )
                    except Exception:
                        # Telemetry is best-effort and must never break process()
                        pass

            wrapped_process._airtrain_telemetry_wrapped = True
            return wrapped_process

        # Replace the process method with the wrapped version at class level
        cls.process = _create_wrapper(cls._original_process)
        # Kept for backward compatibility with code that inspects this flag
        cls._patched_process = True

    # Capture telemetry for initialization
    try:
        telemetry.capture(
            SkillInitTelemetryEvent(
                skill_id=str(self.skill_id),
                skill_class=self.__class__.__name__,
            )
        )
    except Exception:
        # Telemetry is best-effort; a failure must not prevent construction
        pass
|
20
115
|
|
21
116
|
@abstractmethod
|
22
117
|
def process(self, input_data: InputT) -> OutputT:
|
@@ -34,6 +129,13 @@ class Skill(ABC, Generic[InputT, OutputT]):
|
|
34
129
|
"""
|
35
130
|
pass
|
36
131
|
|
132
|
+
def __call__(self, input_data: InputT) -> OutputT:
    """Run the skill end to end: check the input, process it, check the output.

    Args:
        input_data: Payload handed to ``validate_input`` and then ``process``.

    Returns:
        Whatever ``process`` produced, after ``validate_output`` accepted it.
    """
    self.validate_input(input_data)
    output = self.process(input_data)
    self.validate_output(output)
    return output
|
138
|
+
|
37
139
|
def validate_input(self, input_data: Any) -> None:
|
38
140
|
"""
|
39
141
|
Validate input data before processing.
|
@@ -0,0 +1,74 @@
|
|
1
|
+
"""Airtrain integrations package"""
|
2
|
+
|
3
|
+
# Credentials imports
|
4
|
+
from .openai.credentials import OpenAICredentials
|
5
|
+
from .aws.credentials import AWSCredentials
|
6
|
+
from .google.credentials import GoogleCloudCredentials
|
7
|
+
from .anthropic.credentials import AnthropicCredentials
|
8
|
+
from .groq.credentials import GroqCredentials
|
9
|
+
from .together.credentials import TogetherAICredentials
|
10
|
+
from .ollama.credentials import OllamaCredentials
|
11
|
+
from .sambanova.credentials import SambanovaCredentials
|
12
|
+
from .cerebras.credentials import CerebrasCredentials
|
13
|
+
from .perplexity.credentials import PerplexityCredentials
|
14
|
+
|
15
|
+
# Skills imports
|
16
|
+
from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
17
|
+
from .anthropic.skills import AnthropicChatSkill
|
18
|
+
from .aws.skills import AWSBedrockSkill
|
19
|
+
from .google.skills import GoogleChatSkill
|
20
|
+
from .groq.skills import GroqChatSkill
|
21
|
+
from .together.skills import TogetherAIChatSkill
|
22
|
+
from .ollama.skills import OllamaChatSkill
|
23
|
+
from .sambanova.skills import SambanovaChatSkill
|
24
|
+
from .cerebras.skills import CerebrasChatSkill
|
25
|
+
from .perplexity.skills import PerplexityChatSkill, PerplexityStreamingChatSkill
|
26
|
+
|
27
|
+
# Model configurations
|
28
|
+
from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
|
29
|
+
from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
|
30
|
+
from .perplexity.models_config import PERPLEXITY_MODELS_CONFIG
|
31
|
+
|
32
|
+
# Combined modules
|
33
|
+
from .combined.list_models_factory import (
|
34
|
+
ListModelsSkillFactory,
|
35
|
+
GenericListModelsInput,
|
36
|
+
GenericListModelsOutput,
|
37
|
+
)
|
38
|
+
|
39
|
+
__all__ = [
|
40
|
+
# Credentials
|
41
|
+
"OpenAICredentials",
|
42
|
+
"AWSCredentials",
|
43
|
+
"GoogleCloudCredentials",
|
44
|
+
"AnthropicCredentials",
|
45
|
+
"GroqCredentials",
|
46
|
+
"TogetherAICredentials",
|
47
|
+
"OllamaCredentials",
|
48
|
+
"SambanovaCredentials",
|
49
|
+
"CerebrasCredentials",
|
50
|
+
"PerplexityCredentials",
|
51
|
+
# Skills
|
52
|
+
"OpenAIChatSkill",
|
53
|
+
"OpenAIParserSkill",
|
54
|
+
"AnthropicChatSkill",
|
55
|
+
"AWSBedrockSkill",
|
56
|
+
"GoogleChatSkill",
|
57
|
+
"GroqChatSkill",
|
58
|
+
"TogetherAIChatSkill",
|
59
|
+
"OllamaChatSkill",
|
60
|
+
"SambanovaChatSkill",
|
61
|
+
"CerebrasChatSkill",
|
62
|
+
"PerplexityChatSkill",
|
63
|
+
"PerplexityStreamingChatSkill",
|
64
|
+
# Model configurations
|
65
|
+
"OPENAI_MODELS",
|
66
|
+
"OpenAIModelConfig",
|
67
|
+
"ANTHROPIC_MODELS",
|
68
|
+
"AnthropicModelConfig",
|
69
|
+
"PERPLEXITY_MODELS_CONFIG",
|
70
|
+
# Combined modules
|
71
|
+
"ListModelsSkillFactory",
|
72
|
+
"GenericListModelsInput",
|
73
|
+
"GenericListModelsOutput",
|
74
|
+
]
|
@@ -0,0 +1,33 @@
|
|
1
|
+
"""Anthropic integration for Airtrain"""
|
2
|
+
|
3
|
+
from .credentials import AnthropicCredentials
|
4
|
+
from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
|
5
|
+
from .models_config import (
|
6
|
+
ANTHROPIC_MODELS,
|
7
|
+
AnthropicModelConfig,
|
8
|
+
get_model_config,
|
9
|
+
get_default_model,
|
10
|
+
calculate_cost,
|
11
|
+
)
|
12
|
+
from .list_models import (
|
13
|
+
AnthropicListModelsSkill,
|
14
|
+
AnthropicListModelsInput,
|
15
|
+
AnthropicListModelsOutput,
|
16
|
+
AnthropicModel,
|
17
|
+
)
|
18
|
+
|
19
|
+
__all__ = [
|
20
|
+
"AnthropicCredentials",
|
21
|
+
"AnthropicChatSkill",
|
22
|
+
"AnthropicInput",
|
23
|
+
"AnthropicOutput",
|
24
|
+
"ANTHROPIC_MODELS",
|
25
|
+
"AnthropicModelConfig",
|
26
|
+
"get_model_config",
|
27
|
+
"get_default_model",
|
28
|
+
"calculate_cost",
|
29
|
+
"AnthropicListModelsSkill",
|
30
|
+
"AnthropicListModelsInput",
|
31
|
+
"AnthropicListModelsOutput",
|
32
|
+
"AnthropicModel",
|
33
|
+
]
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from pydantic import Field, SecretStr, validator
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from anthropic import Anthropic
|
4
|
+
|
5
|
+
|
6
|
+
class AnthropicCredentials(BaseCredentials):
    """Anthropic API credentials"""

    anthropic_api_key: SecretStr = Field(..., description="Anthropic API key")
    version: str = Field(default="2023-06-01", description="API Version")

    _required_credentials = {"anthropic_api_key"}

    @validator("anthropic_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys that do not start with Anthropic's ``sk-ant-`` prefix.

        Raises:
            ValueError: If the key has the wrong prefix.
        """
        key = v.get_secret_value()
        if not key.startswith("sk-ant-"):
            raise ValueError("Anthropic API key must start with 'sk-ant-'")
        return v

    async def validate_credentials(self) -> bool:
        """Validate Anthropic credentials with a one-token test request.

        Uses the asynchronous client. The previous synchronous client performed
        blocking network I/O inside this coroutine and stalled the event loop.

        Returns:
            True if the API accepted the request.

        Raises:
            CredentialValidationError: If the request fails for any reason.
        """
        # Local import: only this coroutine needs the async client
        from anthropic import AsyncAnthropic

        try:
            client = AsyncAnthropic(api_key=self.anthropic_api_key.get_secret_value())
            await client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=1,
                messages=[{"role": "user", "content": "Hi"}],
            )
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Anthropic credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,110 @@
|
|
1
|
+
from typing import Optional, List, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
|
4
|
+
from airtrain.core.skills import Skill, ProcessingError
|
5
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
6
|
+
from .credentials import AnthropicCredentials
|
7
|
+
from .models_config import ANTHROPIC_MODELS, AnthropicModelConfig
|
8
|
+
|
9
|
+
|
10
|
+
class AnthropicModel:
    """Plain, serialisable view of one entry of the Anthropic model table."""

    def __init__(self, model_id: str, config: AnthropicModelConfig):
        """Copy identity and pricing fields from ``config``.

        Args:
            model_id: Key the model is registered under.
            config: Record supplying display name, base model and prices.
        """
        self.id = model_id
        # Identity
        self.display_name = config.display_name
        self.base_model = config.base_model
        # Pricing (the two cache prices may be None)
        self.input_price = config.input_price
        self.cached_write_price = config.cached_write_price
        self.cached_read_price = config.cached_read_price
        self.output_price = config.output_price

    def dict(self, exclude_none=False):
        """Return the model as a plain dict with prices converted to float.

        Args:
            exclude_none: When True, cache prices that are None are omitted
                instead of being emitted as None.
        """
        payload = {
            "id": self.id,
            "display_name": self.display_name,
            "base_model": self.base_model,
            "input_price": float(self.input_price),
            "output_price": float(self.output_price),
        }
        for price_key in ("cached_write_price", "cached_read_price"):
            price = getattr(self, price_key)
            if price is not None:
                payload[price_key] = float(price)
            elif not exclude_none:
                payload[price_key] = None
        return payload
|
44
|
+
|
45
|
+
|
46
|
+
class AnthropicListModelsInput(InputSchema):
    """Input for listing Anthropic models.

    A single switch chooses between an API-backed listing (True) and the
    bundled local model table (False, the default).
    """

    api_models_only: bool = Field(
        description=(
            "If True, fetch models from the API only. If False, use local config."
        ),
        default=False,
    )
|
55
|
+
|
56
|
+
|
57
|
+
class AnthropicListModelsOutput(OutputSchema):
    """Output of the Anthropic list-models skill: one dict per model."""

    models: List[Dict[str, Any]] = Field(
        description="List of Anthropic models",
        default_factory=list,
    )
|
64
|
+
|
65
|
+
|
66
|
+
class AnthropicListModelsSkill(
    Skill[AnthropicListModelsInput, AnthropicListModelsOutput]
):
    """Skill for listing Anthropic models from the local ``ANTHROPIC_MODELS`` table."""

    input_schema = AnthropicListModelsInput
    output_schema = AnthropicListModelsOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Only checked when ``api_models_only`` is requested;
                the local-config listing needs none.
        """
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: AnthropicListModelsInput
    ) -> AnthropicListModelsOutput:
        """Return the configured Anthropic models.

        Raises:
            ProcessingError: If an API listing is requested (unsupported here,
                and credentials are required first), or if building the list
                fails.
        """
        try:
            if input_data.api_models_only:
                # Require credentials if using API models
                if not self.credentials:
                    raise ProcessingError(
                        "Anthropic credentials required for API models"
                    )
                # NOTE(review): Anthropic now documents a GET /v1/models
                # endpoint — consider implementing this branch.
                raise ProcessingError(
                    "Anthropic API does not provide a models list endpoint. "
                    "Use api_models_only=False to list models from local config."
                )

            # Use local model config - no credentials needed
            models = [
                AnthropicModel(model_id, config).dict()
                for model_id, config in ANTHROPIC_MODELS.items()
            ]
            return AnthropicListModelsOutput(models=models)

        except ProcessingError:
            # Already descriptive; previously these were re-wrapped as
            # "Failed to list Anthropic models: ..." by the handler below.
            raise
        except Exception as e:
            raise ProcessingError(f"Failed to list Anthropic models: {str(e)}") from e
|
@@ -0,0 +1,100 @@
|
|
1
|
+
from typing import Dict, NamedTuple, Optional
|
2
|
+
from decimal import Decimal
|
3
|
+
|
4
|
+
|
5
|
+
class AnthropicModelConfig(NamedTuple):
    """Static description and pricing of one Anthropic model.

    Field order is positional and must not change. A cache price of None means
    the model has no dedicated cached-token rate; ``calculate_cost`` then bills
    at ``input_price``.
    """

    # Human-readable name, e.g. "Claude 3 Haiku"
    display_name: str
    # Model family identifier
    base_model: str
    # Price of regular input tokens (values look like USD per million tokens — confirm)
    input_price: Decimal
    # Price of prompt-cache writes, if the model has one
    cached_write_price: Optional[Decimal]
    # Price of prompt-cache reads, if the model has one
    cached_read_price: Optional[Decimal]
    # Price of output tokens
    output_price: Decimal
|
12
|
+
|
13
|
+
|
14
|
+
# Local model table, keyed by model id. Iteration order is the listing order.
# Price values match Anthropic's published per-million-token rates — confirm
# before relying on the unit.
ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
    "claude-3-7-sonnet": AnthropicModelConfig(
        display_name="Claude 3.7 Sonnet",
        base_model="claude-3-7-sonnet",
        input_price=Decimal("3.00"),
        output_price=Decimal("15.00"),
        cached_write_price=Decimal("3.75"),
        cached_read_price=Decimal("0.30"),
    ),
    "claude-3-5-haiku": AnthropicModelConfig(
        display_name="Claude 3.5 Haiku",
        base_model="claude-3-5-haiku",
        input_price=Decimal("0.80"),
        output_price=Decimal("4.00"),
        cached_write_price=Decimal("1.00"),
        cached_read_price=Decimal("0.08"),
    ),
    "claude-3-opus": AnthropicModelConfig(
        display_name="Claude 3 Opus",
        base_model="claude-3-opus",
        input_price=Decimal("15.00"),
        output_price=Decimal("75.00"),
        cached_write_price=Decimal("18.75"),
        cached_read_price=Decimal("1.50"),
    ),
    "claude-3-sonnet": AnthropicModelConfig(
        display_name="Claude 3 Sonnet",
        base_model="claude-3-sonnet",
        input_price=Decimal("3.00"),
        output_price=Decimal("15.00"),
        cached_write_price=Decimal("3.75"),
        cached_read_price=Decimal("0.30"),
    ),
    "claude-3-haiku": AnthropicModelConfig(
        display_name="Claude 3 Haiku",
        base_model="claude-3-haiku",
        input_price=Decimal("0.25"),
        output_price=Decimal("1.25"),
        cached_write_price=Decimal("0.31"),
        cached_read_price=Decimal("0.025"),
    ),
}
|
56
|
+
|
57
|
+
|
58
|
+
def get_model_config(model_id: str) -> AnthropicModelConfig:
    """Look up the configuration registered for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``ANTHROPIC_MODELS``.
    """
    config = ANTHROPIC_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Anthropic models")
    return config
|
63
|
+
|
64
|
+
|
65
|
+
def get_default_model() -> str:
    """Return the model id used when the caller does not choose one."""
    default_model_id = "claude-3-sonnet"
    return default_model_id
|
68
|
+
|
69
|
+
|
70
|
+
def calculate_cost(
    model_id: str,
    input_tokens: int,
    output_tokens: int,
    use_cached: bool = False,
    cache_type: str = "read"
) -> Decimal:
    """Calculate the cost of a request from its token counts.

    The prices in ``ANTHROPIC_MODELS`` are per **million** tokens. They match
    Anthropic's published $/MTok rates, e.g. Claude 3.7 Sonnet at 3.00 input
    and 15.00 output. The token-weighted sum is therefore divided by
    1,000,000. The previous divisor of 1,000 overstated every cost by 1000x.

    Args:
        model_id: The model ID to calculate costs for.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        use_cached: Whether to bill input tokens at a prompt-cache rate.
        cache_type: Either "read" or "write" for the cached pricing type. Any
            other value, or a model without that cache rate, falls back to the
            regular input price.

    Returns:
        The total cost as a ``Decimal`` in the same currency as the prices.

    Raises:
        ValueError: If ``model_id`` is unknown (raised by ``get_model_config``).
    """
    tokens_per_price_unit = Decimal("1000000")
    config = get_model_config(model_id)

    if use_cached and cache_type == "read" and config.cached_read_price is not None:
        input_rate = config.cached_read_price
    elif use_cached and cache_type == "write" and config.cached_write_price is not None:
        input_rate = config.cached_write_price
    else:
        input_rate = config.input_price

    # str() keeps the conversion exact even if a float token count slips in
    input_cost = input_rate * Decimal(str(input_tokens))
    output_cost = config.output_price * Decimal(str(output_tokens))

    return (input_cost + output_cost) / tokens_per_price_unit
|
@@ -0,0 +1,155 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
|
+
from pydantic import Field
|
3
|
+
from anthropic import Anthropic
|
4
|
+
import base64
|
5
|
+
from pathlib import Path
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import AnthropicCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class AnthropicInput(InputSchema):
    """Schema for Anthropic chat input"""

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="claude-3-opus-20240229", description="Anthropic model to use"
    )
    # The Messages API rejects max_tokens values above a model's output limit.
    # The previous default of 131072 exceeds the limit of every Claude 3.x model
    # in this package, so requests using the default failed. 4096 is within
    # all of those limits (per Anthropic's model docs — re-check when adding models).
    max_tokens: Optional[int] = Field(
        default=4096, description="Maximum tokens in response"
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    images: List[Path] = Field(
        default_factory=list,
        description="List of image paths to include in the message",
    )
    stream: bool = Field(
        default=False, description="Whether to stream the response progressively"
    )
|
41
|
+
|
42
|
+
|
43
|
+
class AnthropicOutput(OutputSchema):
    """Result of an Anthropic chat call."""

    # Generated text
    response: str = Field(..., description="Model's response text")
    # Model identifier echoed from the request
    used_model: str = Field(..., description="Model used for generation")
    # Token usage reported by the API (left empty on the streaming path)
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
|
51
|
+
|
52
|
+
|
53
|
+
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Skill for Anthropic chat via the Messages API.

    Supports a text prompt, optional conversation history, optional local
    image attachments, and optional streaming of the reply.
    """

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Create the skill and its synchronous Anthropic client.

        Args:
            credentials: Anthropic credentials. When omitted they are loaded
                with ``AnthropicCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    @staticmethod
    def _encode_image(image_path: Path) -> Dict[str, Any]:
        """Read ``image_path`` and return it as a base64 image content block.

        The media type is guessed from the file extension instead of always
        sending ``image/jpeg``. Anthropic's API expects the declared type to
        match the image data, so PNG/GIF/WebP files were mislabelled before.
        Unknown extensions still fall back to ``image/jpeg``.
        """
        import mimetypes

        media_type, _ = mimetypes.guess_type(str(image_path))
        if not media_type or not media_type.startswith("image/"):
            media_type = "image/jpeg"
        with open(image_path, "rb") as img_file:
            data = base64.b64encode(img_file.read()).decode("utf-8")
        return {
            "type": "image",
            "source": {"type": "base64", "media_type": media_type, "data": data},
        }

    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing conversation history, images and user input

        Returns:
            List[Dict[str, Any]]: History messages followed by one user message
            whose content is the image blocks (if any) and then the text block.
        """
        # Copy so the caller's history list is never mutated
        messages: List[Dict[str, Any]] = list(input_data.conversation_history)

        content: List[Dict[str, Any]] = [
            self._encode_image(image_path) for image_path in input_data.images
        ]
        content.append({"type": "text", "text": input_data.user_input})
        messages.append({"role": "user", "content": content})

        return messages

    def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
        """Process the input and yield the response text chunk by chunk.

        Raises:
            ProcessingError: If the request or the stream fails.
        """
        try:
            messages = self._build_messages(input_data)

            with self.client.beta.messages.stream(
                model=input_data.model,
                system=input_data.system_prompt,
                messages=messages,
                max_tokens=input_data.max_tokens,
                temperature=input_data.temperature,
            ) as stream:
                for chunk in stream.text_stream:
                    yield chunk

        except Exception as e:
            logger.exception(f"Anthropic streaming failed: {str(e)}")
            raise ProcessingError(f"Anthropic streaming failed: {str(e)}") from e

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Process the input and return the complete response.

        With ``input_data.stream`` set, the streamed chunks are joined. Before
        this fix, that path crashed by calling ``.content`` on the joined
        string.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not available in streaming
            else:
                messages = self._build_messages(input_data)
                response = self.client.messages.create(
                    model=input_data.model,
                    system=input_data.system_prompt,
                    messages=messages,
                    max_tokens=input_data.max_tokens,
                    temperature=input_data.temperature,
                )
                # Join every text block rather than assuming content[0] exists
                # and is text (it may be empty or a non-text block).
                response_text = "".join(
                    block.text
                    for block in response.content
                    if getattr(block, "type", None) == "text"
                )
                usage = response.usage.model_dump() if response.usage else {}

            return AnthropicOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Raised (and already logged) by process_stream; don't wrap it twice
            raise
        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}") from e
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
from pydantic import Field, SecretStr
|
3
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
4
|
+
import boto3
|
5
|
+
|
6
|
+
|
7
|
+
class AWSCredentials(BaseCredentials):
    """AWS credentials: access-key pair, region and optional session token."""

    aws_access_key_id: SecretStr = Field(..., description="AWS Access Key ID")
    aws_secret_access_key: SecretStr = Field(..., description="AWS Secret Access Key")
    aws_region: str = Field(default="us-east-1", description="AWS Region")
    aws_session_token: Optional[SecretStr] = Field(None, description="AWS Session Token")

    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}

    async def validate_credentials(self) -> bool:
        """Check the credentials by calling STS ``get_caller_identity``.

        Returns:
            True when the call succeeds.

        Raises:
            CredentialValidationError: If building the session or the call fails.
        """
        try:
            token = self.aws_session_token
            session_token = token.get_secret_value() if token else None
            session = boto3.Session(
                aws_access_key_id=self.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
                aws_session_token=session_token,
                region_name=self.aws_region,
            )
            session.client("sts").get_caller_identity()
        except Exception as e:
            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
        return True
|