airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/core/skills.py CHANGED
@@ -1,8 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Type, Generic, TypeVar
 from uuid import UUID, uuid4
+import time
+import functools
 from .schemas import InputSchema, OutputSchema
 
+# Import telemetry
+from airtrain.telemetry import (
+    telemetry,
+    SkillInitTelemetryEvent,
+    SkillProcessTelemetryEvent,
+)
+
 # Generic type variables for input and output schemas
 InputT = TypeVar("InputT", bound=InputSchema)
 OutputT = TypeVar("OutputT", bound=OutputSchema)
@@ -17,6 +26,92 @@ class Skill(ABC, Generic[InputT, OutputT]):
     input_schema: Type[InputT]
     output_schema: Type[OutputT]
     _skill_id: Optional[UUID] = None
+    _original_process = None
+
+    def __init__(self):
+        """Initialize the skill and capture telemetry."""
+        # Initialize skill_id if not already set
+        if not self._skill_id:
+            self._skill_id = uuid4()
+
+        # Monkey patch the process method if it hasn't been patched yet
+        # This allows us to add telemetry without changing the API
+        if not hasattr(self.__class__, '_patched_process'):
+            # Store the original process method implementation from this instance
+            # This is crucial for proper behavior with inheritance
+            self.__class__._original_process = self.__class__.process
+
+            # Create a wrapper function that will capture telemetry
+            def _create_wrapper(original_method):
+                @functools.wraps(original_method)
+                def wrapped_process(instance, input_data):
+                    start_time = time.time()
+                    error = None
+
+                    try:
+                        # Call the original process method
+                        result = original_method(instance, input_data)
+                        return result
+                    except Exception as e:
+                        error = str(e)
+                        raise
+                    finally:
+                        duration = time.time() - start_time
+
+                        try:
+                            # Serialize input data for telemetry
+                            serialized_input = None
+                            try:
+                                # Convert input_data to dict if it's a Pydantic model
+                                if hasattr(input_data, "dict"):
+                                    serialized_input = input_data.dict()
+                                # If it's a dataclass
+                                elif hasattr(input_data, "__dataclass_fields__"):
+                                    from dataclasses import asdict
+                                    serialized_input = asdict(input_data)
+                                # Fallback
+                                else:
+                                    serialized_input = {
+                                        "__str__": str(input_data)
+                                    }
+                            except Exception:
+                                # If serialization fails, provide simple info
+                                serialized_input = {"error": "Failed to serialize input data"}
+
+                            telemetry.capture(
+                                SkillProcessTelemetryEvent(
+                                    skill_id=str(instance.skill_id),
+                                    skill_class=instance.__class__.__name__,
+                                    input_schema=instance.input_schema.__name__,
+                                    output_schema=instance.output_schema.__name__,
+                                    input_data=serialized_input,
+                                    duration_seconds=duration,
+                                    error=error,
+                                )
+                            )
+                        except Exception:
+                            # Silently continue if telemetry fails
+                            pass
+
+                return wrapped_process
+
+            # Replace the process method with our wrapped version at the class level
+            self.__class__.process = _create_wrapper(self.__class__._original_process)
+
+            # Mark this class as patched to prevent double-patching
+            self.__class__._patched_process = True
+
+        # Capture telemetry for initialization
+        try:
+            telemetry.capture(
+                SkillInitTelemetryEvent(
+                    skill_id=str(self.skill_id),
+                    skill_class=self.__class__.__name__,
+                )
+            )
+        except Exception:
+            # Silently continue if telemetry fails
+            pass
 
     @abstractmethod
     def process(self, input_data: InputT) -> OutputT:
@@ -34,6 +129,13 @@ class Skill(ABC, Generic[InputT, OutputT]):
         """
         pass
 
+    def __call__(self, input_data: InputT) -> OutputT:
+        """Make the skill callable, with input/output validation."""
+        self.validate_input(input_data)
+        result = self.process(input_data)
+        self.validate_output(result)
+        return result
+
     def validate_input(self, input_data: Any) -> None:
         """
         Validate input data before processing.
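
The new Skill.__init__ wraps process() at the class level the first time a subclass is instantiated, and __call__ adds input/output validation around it. A minimal sketch of how a subclass picks this up (EchoInput, EchoOutput, and EchoSkill are hypothetical names, assuming InputSchema and OutputSchema behave like ordinary Pydantic-style models):

from airtrain.core.schemas import InputSchema, OutputSchema
from airtrain.core.skills import Skill


class EchoInput(InputSchema):
    text: str


class EchoOutput(OutputSchema):
    text: str


class EchoSkill(Skill[EchoInput, EchoOutput]):
    input_schema = EchoInput
    output_schema = EchoOutput

    def process(self, input_data: EchoInput) -> EchoOutput:
        return EchoOutput(text=input_data.text)


# Instantiating runs Skill.__init__, which assigns a skill_id, patches process()
# with the telemetry wrapper, and emits SkillInitTelemetryEvent.
skill = EchoSkill()

# __call__ validates the input, runs the wrapped process() (which emits
# SkillProcessTelemetryEvent with timing and serialized input), then validates the output.
result = skill(EchoInput(text="hello"))
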
airtrain/integrations/__init__.py ADDED
@@ -0,0 +1,74 @@
+"""Airtrain integrations package"""
+
+# Credentials imports
+from .openai.credentials import OpenAICredentials
+from .aws.credentials import AWSCredentials
+from .google.credentials import GoogleCloudCredentials
+from .anthropic.credentials import AnthropicCredentials
+from .groq.credentials import GroqCredentials
+from .together.credentials import TogetherAICredentials
+from .ollama.credentials import OllamaCredentials
+from .sambanova.credentials import SambanovaCredentials
+from .cerebras.credentials import CerebrasCredentials
+from .perplexity.credentials import PerplexityCredentials
+
+# Skills imports
+from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
+from .anthropic.skills import AnthropicChatSkill
+from .aws.skills import AWSBedrockSkill
+from .google.skills import GoogleChatSkill
+from .groq.skills import GroqChatSkill
+from .together.skills import TogetherAIChatSkill
+from .ollama.skills import OllamaChatSkill
+from .sambanova.skills import SambanovaChatSkill
+from .cerebras.skills import CerebrasChatSkill
+from .perplexity.skills import PerplexityChatSkill, PerplexityStreamingChatSkill
+
+# Model configurations
+from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
+from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
+from .perplexity.models_config import PERPLEXITY_MODELS_CONFIG
+
+# Combined modules
+from .combined.list_models_factory import (
+    ListModelsSkillFactory,
+    GenericListModelsInput,
+    GenericListModelsOutput,
+)
+
+__all__ = [
+    # Credentials
+    "OpenAICredentials",
+    "AWSCredentials",
+    "GoogleCloudCredentials",
+    "AnthropicCredentials",
+    "GroqCredentials",
+    "TogetherAICredentials",
+    "OllamaCredentials",
+    "SambanovaCredentials",
+    "CerebrasCredentials",
+    "PerplexityCredentials",
+    # Skills
+    "OpenAIChatSkill",
+    "OpenAIParserSkill",
+    "AnthropicChatSkill",
+    "AWSBedrockSkill",
+    "GoogleChatSkill",
+    "GroqChatSkill",
+    "TogetherAIChatSkill",
+    "OllamaChatSkill",
+    "SambanovaChatSkill",
+    "CerebrasChatSkill",
+    "PerplexityChatSkill",
+    "PerplexityStreamingChatSkill",
+    # Model configurations
+    "OPENAI_MODELS",
+    "OpenAIModelConfig",
+    "ANTHROPIC_MODELS",
+    "AnthropicModelConfig",
+    "PERPLEXITY_MODELS_CONFIG",
+    # Combined modules
+    "ListModelsSkillFactory",
+    "GenericListModelsInput",
+    "GenericListModelsOutput",
+]
airtrain/integrations/anthropic/__init__.py ADDED
@@ -0,0 +1,33 @@
+"""Anthropic integration for Airtrain"""
+
+from .credentials import AnthropicCredentials
+from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
+from .models_config import (
+    ANTHROPIC_MODELS,
+    AnthropicModelConfig,
+    get_model_config,
+    get_default_model,
+    calculate_cost,
+)
+from .list_models import (
+    AnthropicListModelsSkill,
+    AnthropicListModelsInput,
+    AnthropicListModelsOutput,
+    AnthropicModel,
+)
+
+__all__ = [
+    "AnthropicCredentials",
+    "AnthropicChatSkill",
+    "AnthropicInput",
+    "AnthropicOutput",
+    "ANTHROPIC_MODELS",
+    "AnthropicModelConfig",
+    "get_model_config",
+    "get_default_model",
+    "calculate_cost",
+    "AnthropicListModelsSkill",
+    "AnthropicListModelsInput",
+    "AnthropicListModelsOutput",
+    "AnthropicModel",
+]
airtrain/integrations/anthropic/credentials.py ADDED
@@ -0,0 +1,32 @@
+from pydantic import Field, SecretStr, validator
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from anthropic import Anthropic
+
+
+class AnthropicCredentials(BaseCredentials):
+    """Anthropic API credentials"""
+
+    anthropic_api_key: SecretStr = Field(..., description="Anthropic API key")
+    version: str = Field(default="2023-06-01", description="API Version")
+
+    _required_credentials = {"anthropic_api_key"}
+
+    @validator("anthropic_api_key")
+    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
+        key = v.get_secret_value()
+        if not key.startswith("sk-ant-"):
+            raise ValueError("Anthropic API key must start with 'sk-ant-'")
+        return v
+
+    async def validate_credentials(self) -> bool:
+        """Validate Anthropic credentials"""
+        try:
+            client = Anthropic(api_key=self.anthropic_api_key.get_secret_value())
+            client.messages.create(
+                model="claude-3-opus-20240229",
+                max_tokens=1,
+                messages=[{"role": "user", "content": "Hi"}],
+            )
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Anthropic credentials: {str(e)}")
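
A short usage sketch for the new credentials class; the key below is a placeholder, and validate_credentials() issues a real messages.create call, so it only succeeds with a valid key:

import asyncio
from pydantic import SecretStr
from airtrain.integrations.anthropic import AnthropicCredentials

creds = AnthropicCredentials(anthropic_api_key=SecretStr("sk-ant-your-key-here"))
# Raises CredentialValidationError if the key is rejected by the API.
asyncio.run(creds.validate_credentials())
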
airtrain/integrations/anthropic/list_models.py ADDED
@@ -0,0 +1,110 @@
+from typing import Optional, List, Dict, Any
+from pydantic import Field
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import AnthropicCredentials
+from .models_config import ANTHROPIC_MODELS, AnthropicModelConfig
+
+
+class AnthropicModel:
+    """Class to represent an Anthropic model."""
+
+    def __init__(self, model_id: str, config: AnthropicModelConfig):
+        """Initialize the Anthropic model."""
+        self.id = model_id
+        self.display_name = config.display_name
+        self.base_model = config.base_model
+        self.input_price = config.input_price
+        self.cached_write_price = config.cached_write_price
+        self.cached_read_price = config.cached_read_price
+        self.output_price = config.output_price
+
+    def dict(self, exclude_none=False):
+        """Convert the model to a dictionary."""
+        result = {
+            "id": self.id,
+            "display_name": self.display_name,
+            "base_model": self.base_model,
+            "input_price": float(self.input_price),
+            "output_price": float(self.output_price),
+        }
+
+        if self.cached_write_price is not None:
+            result["cached_write_price"] = float(self.cached_write_price)
+        elif not exclude_none:
+            result["cached_write_price"] = None
+
+        if self.cached_read_price is not None:
+            result["cached_read_price"] = float(self.cached_read_price)
+        elif not exclude_none:
+            result["cached_read_price"] = None
+
+        return result
+
+
+class AnthropicListModelsInput(InputSchema):
+    """Schema for Anthropic list models input"""
+
+    api_models_only: bool = Field(
+        default=False,
+        description=(
+            "If True, fetch models from the API only. If False, use local config."
+        )
+    )
+
+
+class AnthropicListModelsOutput(OutputSchema):
+    """Schema for Anthropic list models output"""
+
+    models: List[Dict[str, Any]] = Field(
+        default_factory=list,
+        description="List of Anthropic models"
+    )
+
+
+class AnthropicListModelsSkill(
+    Skill[AnthropicListModelsInput, AnthropicListModelsOutput]
+):
+    """Skill for listing Anthropic models"""
+
+    input_schema = AnthropicListModelsInput
+    output_schema = AnthropicListModelsOutput
+
+    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials
+
+    def process(
+        self, input_data: AnthropicListModelsInput
+    ) -> AnthropicListModelsOutput:
+        """Process the input and return a list of models."""
+        try:
+            models = []
+
+            if input_data.api_models_only:
+                # Fetch models from Anthropic API
+                # Require credentials if using API models
+                if not self.credentials:
+                    raise ProcessingError(
+                        "Anthropic credentials required for API models"
+                    )
+
+                # Note: Anthropic doesn't have a public models list endpoint
+                # We'll raise an error instead
+                raise ProcessingError(
+                    "Anthropic API does not provide a models list endpoint. "
+                    "Use api_models_only=False to list models from local config."
+                )
+            else:
+                # Use local model config - no credentials needed
+                for model_id, config in ANTHROPIC_MODELS.items():
+                    model = AnthropicModel(model_id, config)
+                    models.append(model.dict())
+
+            # Return the output
+            return AnthropicListModelsOutput(models=models)
+
+        except Exception as e:
+            raise ProcessingError(f"Failed to list Anthropic models: {str(e)}")
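
Since the default api_models_only=False path reads from the bundled ANTHROPIC_MODELS config, the skill can be exercised without credentials; a brief sketch:

from airtrain.integrations.anthropic.list_models import (
    AnthropicListModelsSkill,
    AnthropicListModelsInput,
)

skill = AnthropicListModelsSkill()  # no credentials needed for the local-config path
output = skill(AnthropicListModelsInput())
for model in output.models:
    print(model["id"], model["input_price"], model["output_price"])
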
airtrain/integrations/anthropic/models_config.py ADDED
@@ -0,0 +1,100 @@
+from typing import Dict, NamedTuple, Optional
+from decimal import Decimal
+
+
+class AnthropicModelConfig(NamedTuple):
+    display_name: str
+    base_model: str
+    input_price: Decimal
+    cached_write_price: Optional[Decimal]
+    cached_read_price: Optional[Decimal]
+    output_price: Decimal
+
+
+ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
+    "claude-3-7-sonnet": AnthropicModelConfig(
+        display_name="Claude 3.7 Sonnet",
+        base_model="claude-3-7-sonnet",
+        input_price=Decimal("3.00"),
+        cached_write_price=Decimal("3.75"),
+        cached_read_price=Decimal("0.30"),
+        output_price=Decimal("15.00"),
+    ),
+    "claude-3-5-haiku": AnthropicModelConfig(
+        display_name="Claude 3.5 Haiku",
+        base_model="claude-3-5-haiku",
+        input_price=Decimal("0.80"),
+        cached_write_price=Decimal("1.00"),
+        cached_read_price=Decimal("0.08"),
+        output_price=Decimal("4.00"),
+    ),
+    "claude-3-opus": AnthropicModelConfig(
+        display_name="Claude 3 Opus",
+        base_model="claude-3-opus",
+        input_price=Decimal("15.00"),
+        cached_write_price=Decimal("18.75"),
+        cached_read_price=Decimal("1.50"),
+        output_price=Decimal("75.00"),
+    ),
+    "claude-3-sonnet": AnthropicModelConfig(
+        display_name="Claude 3 Sonnet",
+        base_model="claude-3-sonnet",
+        input_price=Decimal("3.00"),
+        cached_write_price=Decimal("3.75"),
+        cached_read_price=Decimal("0.30"),
+        output_price=Decimal("15.00"),
+    ),
+    "claude-3-haiku": AnthropicModelConfig(
+        display_name="Claude 3 Haiku",
+        base_model="claude-3-haiku",
+        input_price=Decimal("0.25"),
+        cached_write_price=Decimal("0.31"),
+        cached_read_price=Decimal("0.025"),
+        output_price=Decimal("1.25"),
+    ),
+}
+
+
+def get_model_config(model_id: str) -> AnthropicModelConfig:
+    """Get model configuration by model ID"""
+    if model_id not in ANTHROPIC_MODELS:
+        raise ValueError(f"Model {model_id} not found in Anthropic models")
+    return ANTHROPIC_MODELS[model_id]
+
+
+def get_default_model() -> str:
+    """Get the default model ID"""
+    return "claude-3-sonnet"
+
+
+def calculate_cost(
+    model_id: str,
+    input_tokens: int,
+    output_tokens: int,
+    use_cached: bool = False,
+    cache_type: str = "read"
+) -> Decimal:
+    """Calculate cost for token usage
+
+    Args:
+        model_id: The model ID to calculate costs for
+        input_tokens: Number of input tokens
+        output_tokens: Number of output tokens
+        use_cached: Whether to use cached pricing
+        cache_type: Either "read" or "write" for cached pricing type
+    """
+    config = get_model_config(model_id)
+
+    if not use_cached:
+        input_cost = config.input_price * Decimal(str(input_tokens))
+    else:
+        if cache_type == "read" and config.cached_read_price is not None:
+            input_cost = config.cached_read_price * Decimal(str(input_tokens))
+        elif cache_type == "write" and config.cached_write_price is not None:
+            input_cost = config.cached_write_price * Decimal(str(input_tokens))
+        else:
+            input_cost = config.input_price * Decimal(str(input_tokens))
+
+    output_cost = config.output_price * Decimal(str(output_tokens))
+
+    return (input_cost + output_cost) / Decimal("1000")
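
A worked example of the pricing helpers above; the figures follow directly from the Decimal prices in ANTHROPIC_MODELS and the (input_cost + output_cost) / 1000 formula:

from airtrain.integrations.anthropic.models_config import calculate_cost, get_default_model

model_id = get_default_model()  # "claude-3-sonnet"

# Regular pricing: (3.00 * 1000 + 15.00 * 500) / 1000 = Decimal("10.5")
print(calculate_cost(model_id, input_tokens=1000, output_tokens=500))

# Cached-read pricing swaps in the cheaper input rate:
# (0.30 * 1000 + 15.00 * 500) / 1000 = Decimal("7.8")
print(calculate_cost(model_id, 1000, 500, use_cached=True, cache_type="read"))
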
airtrain/integrations/anthropic/skills.py ADDED
@@ -0,0 +1,155 @@
+from typing import List, Optional, Dict, Any, Generator
+from pydantic import Field
+from anthropic import Anthropic
+import base64
+from pathlib import Path
+from loguru import logger
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import AnthropicCredentials
+
+
+class AnthropicInput(InputSchema):
+    """Schema for Anthropic chat input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
+    model: str = Field(
+        default="claude-3-opus-20240229", description="Anthropic model to use"
+    )
+    max_tokens: Optional[int] = Field(
+        default=131072, description="Maximum tokens in response"
+    )
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+    images: List[Path] = Field(
+        default_factory=list,
+        description="List of image paths to include in the message",
+    )
+    stream: bool = Field(
+        default=False, description="Whether to stream the response progressively"
+    )
+
+
+class AnthropicOutput(OutputSchema):
+    """Schema for Anthropic chat output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(
+        default_factory=dict, description="Usage statistics from the API"
+    )
+
+
+class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
+    """Skill for Anthropic chat"""
+
+    input_schema = AnthropicInput
+    output_schema = AnthropicOutput
+
+    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
+        super().__init__()
+        self.credentials = credentials or AnthropicCredentials.from_env()
+        self.client = Anthropic(
+            api_key=self.credentials.anthropic_api_key.get_secret_value()
+        )
+
+    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, Any]]: List of messages in the format required by Anthropic
+        """
+        messages = []
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Prepare user message content
+        user_message = {"type": "text", "text": input_data.user_input}
+
+        # Add images if present
+        if input_data.images:
+            content = []
+            for image_path in input_data.images:
+                with open(image_path, "rb") as img_file:
+                    base64_image = base64.b64encode(img_file.read()).decode("utf-8")
+                    content.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": "image/jpeg",
+                                "data": base64_image,
+                            },
+                        }
+                    )
+            content.append(user_message)
+            messages.append({"role": "user", "content": content})
+        else:
+            messages.append({"role": "user", "content": [user_message]})
+
+        return messages
+
+    def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
+        """Process the input and stream the response token by token."""
+        try:
+            messages = self._build_messages(input_data)
+
+            with self.client.beta.messages.stream(
+                model=input_data.model,
+                system=input_data.system_prompt,
+                messages=messages,
+                max_tokens=input_data.max_tokens,
+                temperature=input_data.temperature,
+            ) as stream:
+                for chunk in stream.text_stream:
+                    yield chunk
+
+        except Exception as e:
+            logger.exception(f"Anthropic streaming failed: {str(e)}")
+            raise ProcessingError(f"Anthropic streaming failed: {str(e)}")
+
+    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
+        """Process the input and return the complete response."""
+        try:
+            if input_data.stream:
+                response_chunks = []
+                for chunk in self.process_stream(input_data):
+                    response_chunks.append(chunk)
+                response = "".join(response_chunks)
+                usage = {}  # Usage stats not available in streaming
+            else:
+                messages = self._build_messages(input_data)
+                response = self.client.messages.create(
+                    model=input_data.model,
+                    system=input_data.system_prompt,
+                    messages=messages,
+                    max_tokens=input_data.max_tokens,
+                    temperature=input_data.temperature,
+                )
+                usage = response.usage.model_dump() if response.usage else {}
+
+            return AnthropicOutput(
+                response=response.content[0].text,
+                used_model=input_data.model,
+                usage=usage,
+            )
+
+        except Exception as e:
+            logger.exception(f"Anthropic processing failed: {str(e)}")
+            raise ProcessingError(f"Anthropic processing failed: {str(e)}")
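
A minimal usage sketch for the chat skill; it assumes AnthropicCredentials.from_env() (inherited from BaseCredentials and not shown in this diff) can resolve the API key from the environment:

from airtrain.integrations.anthropic import AnthropicChatSkill, AnthropicInput

skill = AnthropicChatSkill()  # falls back to AnthropicCredentials.from_env()

# Non-streaming call through Skill.__call__ (validates input/output, captures telemetry).
output = skill(AnthropicInput(user_input="Say hello in one sentence."))
print(output.response, output.usage)

# Streaming variant: process_stream() yields text chunks as they arrive.
for chunk in skill.process_stream(AnthropicInput(user_input="Say hello in one sentence.")):
    print(chunk, end="", flush=True)
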
airtrain/integrations/aws/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""AWS integration module"""
+
+from .credentials import AWSCredentials
+from .skills import AWSBedrockSkill
+
+__all__ = ["AWSCredentials", "AWSBedrockSkill"]
airtrain/integrations/aws/credentials.py ADDED
@@ -0,0 +1,36 @@
+from typing import Optional
+from pydantic import Field, SecretStr
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+import boto3
+
+
+class AWSCredentials(BaseCredentials):
+    """AWS credentials"""
+
+    aws_access_key_id: SecretStr = Field(..., description="AWS Access Key ID")
+    aws_secret_access_key: SecretStr = Field(..., description="AWS Secret Access Key")
+    aws_region: str = Field(default="us-east-1", description="AWS Region")
+    aws_session_token: Optional[SecretStr] = Field(
+        None, description="AWS Session Token"
+    )
+
+    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate AWS credentials by making a test API call"""
+        try:
+            session = boto3.Session(
+                aws_access_key_id=self.aws_access_key_id.get_secret_value(),
+                aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
+                aws_session_token=(
+                    self.aws_session_token.get_secret_value()
+                    if self.aws_session_token
+                    else None
+                ),
+                region_name=self.aws_region,
+            )
+            sts = session.client("sts")
+            sts.get_caller_identity()
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
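
And a matching sketch for AWSCredentials; the values are placeholders, and validate_credentials() calls STS get_caller_identity, so it only succeeds with real credentials:

import asyncio
from pydantic import SecretStr
from airtrain.integrations.aws import AWSCredentials

creds = AWSCredentials(
    aws_access_key_id=SecretStr("AKIA..."),
    aws_secret_access_key=SecretStr("placeholder-secret"),
    aws_region="us-east-1",
)
# Raises CredentialValidationError if STS rejects the credentials.
asyncio.run(creds.validate_credentials())
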