airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
1
+ from typing import Optional, List, Dict, Any
2
+ from pydantic import Field
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.core.schemas import InputSchema, OutputSchema
6
+ from .credentials import AnthropicCredentials
7
+ from .models_config import ANTHROPIC_MODELS, AnthropicModelConfig
8
+
9
+
10
class AnthropicModel:
    """Plain-object view of one entry of the Anthropic model price table."""

    def __init__(self, model_id: str, config: AnthropicModelConfig):
        """Copy identifying and pricing fields from *config*.

        Args:
            model_id: Key under which the model is registered.
            config: Metadata/pricing record for that model; its cached prices
                may be None.
        """
        self.id = model_id

        # Descriptive metadata.
        self.display_name = config.display_name
        self.base_model = config.base_model

        # Pricing (kept as the original Decimal / None values).
        self.input_price = config.input_price
        self.cached_write_price = config.cached_write_price
        self.cached_read_price = config.cached_read_price
        self.output_price = config.output_price

    def dict(self, exclude_none=False):
        """Return a JSON-friendly dict with every price converted to float.

        Args:
            exclude_none: When True, cached-price keys whose value is None are
                left out instead of being emitted as None.

        Returns:
            Dict with ``id``, ``display_name``, ``base_model``, ``input_price``,
            ``output_price`` and, unless excluded, ``cached_write_price`` and
            ``cached_read_price``.
        """
        payload = {
            "id": self.id,
            "display_name": self.display_name,
            "base_model": self.base_model,
            "input_price": float(self.input_price),
            "output_price": float(self.output_price),
        }

        optional_prices = (
            ("cached_write_price", self.cached_write_price),
            ("cached_read_price", self.cached_read_price),
        )
        for key, price in optional_prices:
            if price is not None:
                payload[key] = float(price)
            elif not exclude_none:
                payload[key] = None

        return payload
44
+
45
+
46
class AnthropicListModelsInput(InputSchema):
    """Schema for Anthropic list models input"""

    # Selects where models are listed from. AnthropicListModelsSkill.process
    # raises ProcessingError when this is True (no API listing is implemented);
    # False lists the entries of the local ANTHROPIC_MODELS config.
    api_models_only: bool = Field(
        default=False,
        description=(
            "If True, fetch models from the API only. If False, use local config."
        )
    )
55
+
56
+
57
class AnthropicListModelsOutput(OutputSchema):
    """Schema for Anthropic list models output"""

    # One dict per model, in the shape produced by AnthropicModel.dict()
    # (id, display_name, base_model, prices as floats).
    models: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of Anthropic models"
    )
64
+
65
+
66
class AnthropicListModelsSkill(
    Skill[AnthropicListModelsInput, AnthropicListModelsOutput]
):
    """Skill for listing Anthropic models.

    Models are listed from the local ``ANTHROPIC_MODELS`` table. Listing via
    the Anthropic API is not implemented; requesting it raises.
    """

    input_schema = AnthropicListModelsInput
    output_schema = AnthropicListModelsOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Only checked when ``api_models_only`` is requested;
                listing from the local config needs none.
        """
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: AnthropicListModelsInput
    ) -> AnthropicListModelsOutput:
        """Return the known Anthropic models.

        Args:
            input_data: Selects local-config listing (default) or API listing.

        Returns:
            Output whose ``models`` holds one ``AnthropicModel.dict()`` per entry
            of ``ANTHROPIC_MODELS``.

        Raises:
            ProcessingError: If ``api_models_only`` is True (credentials missing,
                or API listing unsupported), or if building the list fails.
        """
        # These errors are raised outside the try block so they reach the caller
        # with their own message instead of being re-wrapped as
        # "Failed to list Anthropic models: ...".
        if input_data.api_models_only:
            if not self.credentials:
                raise ProcessingError(
                    "Anthropic credentials required for API models"
                )
            # NOTE(review): newer Anthropic API versions expose a models list
            # endpoint (GET /v1/models); this branch could call it instead of
            # raising — confirm against the installed SDK version.
            raise ProcessingError(
                "Anthropic API does not provide a models list endpoint. "
                "Use api_models_only=False to list models from local config."
            )

        try:
            # Local model config - no credentials needed.
            models = [
                AnthropicModel(model_id, config).dict()
                for model_id, config in ANTHROPIC_MODELS.items()
            ]
            return AnthropicListModelsOutput(models=models)
        except Exception as e:
            raise ProcessingError(
                f"Failed to list Anthropic models: {str(e)}"
            ) from e
@@ -0,0 +1,100 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+ from decimal import Decimal
3
+
4
+
5
class AnthropicModelConfig(NamedTuple):
    """Static metadata and pricing for one Anthropic model.

    Field order is the positional constructor order of this NamedTuple.
    NOTE(review): the price unit is not stated here; the values in
    ANTHROPIC_MODELS match Anthropic's published USD-per-million-token rates —
    confirm before relying on computed totals.
    """

    display_name: str  # Human-readable model name.
    base_model: str  # Model identifier this entry describes.
    input_price: Decimal  # Price of regular (uncached) input tokens.
    # Cached prices are Optional; calculate_cost falls back to input_price
    # when the requested cached price is None.
    cached_write_price: Optional[Decimal]
    cached_read_price: Optional[Decimal]
    output_price: Decimal  # Price of generated output tokens.
12
+
13
+
14
# Pricing table, one row per model:
#   (model id, display name, input, cached write, cached read, output)
# Every entry uses its model id as base_model. Prices are written as strings so
# Decimal keeps their exact textual value.
# NOTE(review): the figures match Anthropic's published USD-per-million-token
# rates — confirm the unit assumed by calculate_cost.
ANTHROPIC_MODELS: Dict[str, AnthropicModelConfig] = {
    model_id: AnthropicModelConfig(
        display_name=display_name,
        base_model=model_id,
        input_price=Decimal(input_price),
        cached_write_price=Decimal(cached_write_price),
        cached_read_price=Decimal(cached_read_price),
        output_price=Decimal(output_price),
    )
    for (
        model_id,
        display_name,
        input_price,
        cached_write_price,
        cached_read_price,
        output_price,
    ) in (
        ("claude-3-7-sonnet", "Claude 3.7 Sonnet", "3.00", "3.75", "0.30", "15.00"),
        ("claude-3-5-haiku", "Claude 3.5 Haiku", "0.80", "1.00", "0.08", "4.00"),
        ("claude-3-opus", "Claude 3 Opus", "15.00", "18.75", "1.50", "75.00"),
        ("claude-3-sonnet", "Claude 3 Sonnet", "3.00", "3.75", "0.30", "15.00"),
        ("claude-3-haiku", "Claude 3 Haiku", "0.25", "0.31", "0.025", "1.25"),
    )
}
56
+
57
+
58
def get_model_config(model_id: str) -> AnthropicModelConfig:
    """Look up the configuration registered under *model_id*.

    Args:
        model_id: Key into ``ANTHROPIC_MODELS``.

    Returns:
        The matching ``AnthropicModelConfig``.

    Raises:
        ValueError: If *model_id* is not a registered model.
    """
    # Config values are NamedTuple instances, never None, so a None result
    # from .get() can only mean the key is absent.
    config = ANTHROPIC_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Anthropic models")
    return config
63
+
64
+
65
def get_default_model() -> str:
    """Return the model ID to use when the caller does not pick one."""
    default_model_id = "claude-3-sonnet"
    return default_model_id
68
+
69
+
70
def calculate_cost(
    model_id: str,
    input_tokens: int,
    output_tokens: int,
    use_cached: bool = False,
    cache_type: str = "read"
) -> Decimal:
    """Calculate the cost of a request's token usage.

    The prices in ``ANTHROPIC_MODELS`` are per million tokens (they equal
    Anthropic's published per-MTok rates, e.g. 3.00 input / 15.00 output for
    Claude 3.7 Sonnet), so the token-weighted total is divided by 1,000,000.

    Args:
        model_id: The model ID to calculate costs for.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        use_cached: Whether to bill the input tokens at a cached rate.
        cache_type: Either "read" or "write" for the cached pricing type. Any
            other value, or a model lacking that cached price, falls back to
            the regular input price.

    Returns:
        Total cost as a ``Decimal``, in the currency of the configured prices.

    Raises:
        ValueError: If ``model_id`` is not a known model.
    """
    config = get_model_config(model_id)

    input_rate = config.input_price
    if use_cached:
        if cache_type == "read" and config.cached_read_price is not None:
            input_rate = config.cached_read_price
        elif cache_type == "write" and config.cached_write_price is not None:
            input_rate = config.cached_write_price

    input_cost = input_rate * Decimal(str(input_tokens))
    output_cost = config.output_price * Decimal(str(output_tokens))

    # Per-million-token prices: dividing by 1000 (as before) overstated every
    # cost by a factor of 1000.
    return (input_cost + output_cost) / Decimal("1000000")
@@ -0,0 +1,155 @@
1
+ from typing import List, Optional, Dict, Any, Generator
2
+ from pydantic import Field
3
+ from anthropic import Anthropic
4
+ import base64
5
+ from pathlib import Path
6
+ from loguru import logger
7
+
8
+ from airtrain.core.skills import Skill, ProcessingError
9
+ from airtrain.core.schemas import InputSchema, OutputSchema
10
+ from .credentials import AnthropicCredentials
11
+
12
+
13
class AnthropicInput(InputSchema):
    """Schema for Anthropic chat input"""

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prepended to the new user turn by AnthropicChatSkill._build_messages.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="claude-3-opus-20240229", description="Anthropic model to use"
    )
    # The Messages API rejects max_tokens above a model's output limit. Claude 3
    # models, including the default claude-3-opus-20240229, allow at most 4096
    # output tokens, so the previous 131072 default made default requests fail.
    max_tokens: Optional[int] = Field(
        default=4096, description="Maximum tokens in response"
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    # Each file is base64-encoded and sent as an image block before the text.
    images: List[Path] = Field(
        default_factory=list,
        description="List of image paths to include in the message",
    )
    stream: bool = Field(
        default=False, description="Whether to stream the response progressively"
    )
41
+
42
+
43
class AnthropicOutput(OutputSchema):
    """Schema for Anthropic chat output"""

    # Generated text returned to the caller.
    response: str = Field(..., description="Model's response text")
    # AnthropicChatSkill fills this with the requested model name (input_data.model).
    used_model: str = Field(..., description="Model used for generation")
    # Token usage dumped from the API response; AnthropicChatSkill leaves it
    # empty when the response was streamed.
    usage: Dict[str, Any] = Field(
        default_factory=dict, description="Usage statistics from the API"
    )
51
+
52
+
53
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Skill for Anthropic chat through the Messages API."""

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Create the Anthropic client.

        Args:
            credentials: Explicit credentials; when omitted they are loaded
                with ``AnthropicCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    @staticmethod
    def _image_block(image_path: Path) -> Dict[str, Any]:
        """Read one image file and return it as a base64 image content block."""
        import mimetypes

        # The API checks media_type against the image data, so derive it from
        # the file extension instead of always claiming JPEG.
        media_type, _ = mimetypes.guess_type(str(image_path))
        if media_type is None or not media_type.startswith("image/"):
            media_type = "image/jpeg"  # unknown extension: keep the old default
        with open(image_path, "rb") as img_file:
            encoded = base64.b64encode(img_file.read()).decode("utf-8")
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": encoded,
            },
        }

    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
        """
        Build messages list from input data including conversation history.

        The system prompt is not included; it is passed separately as ``system``.

        Args:
            input_data: The input data containing conversation history, images
                and user input

        Returns:
            List[Dict[str, Any]]: Prior history followed by one user message
            whose content is the image blocks (if any) and then the text block.
        """
        messages: List[Dict[str, Any]] = list(input_data.conversation_history)
        content = [self._image_block(path) for path in input_data.images]
        content.append({"type": "text", "text": input_data.user_input})
        messages.append({"role": "user", "content": content})
        return messages

    def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
        """Process the input and yield response text chunks as they arrive.

        Raises:
            ProcessingError: If building the request or streaming fails.
        """
        try:
            messages = self._build_messages(input_data)

            # Standard (non-beta) streaming helper, matching messages.create below.
            with self.client.messages.stream(
                model=input_data.model,
                system=input_data.system_prompt,
                messages=messages,
                max_tokens=input_data.max_tokens,
                temperature=input_data.temperature,
            ) as stream:
                for chunk in stream.text_stream:
                    yield chunk

        except Exception as e:
            logger.exception(f"Anthropic streaming failed: {str(e)}")
            raise ProcessingError(f"Anthropic streaming failed: {str(e)}") from e

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed chunks are joined and
        ``usage`` is left empty.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                # Streamed output is already text; it has no .content attribute.
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not collected when streaming
            else:
                messages = self._build_messages(input_data)
                response = self.client.messages.create(
                    model=input_data.model,
                    system=input_data.system_prompt,
                    messages=messages,
                    max_tokens=input_data.max_tokens,
                    temperature=input_data.temperature,
                )
                # Join every text block; content may be empty or start with a
                # non-text block.
                response_text = "".join(
                    block.text
                    for block in response.content
                    if getattr(block, "type", None) == "text"
                )
                usage = response.usage.model_dump() if response.usage else {}

            return AnthropicOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Already logged and wrapped by process_stream; don't wrap twice.
            raise
        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}") from e
@@ -0,0 +1,6 @@
1
+ """AWS integration module"""
2
+
3
+ from .credentials import AWSCredentials
4
+ from .skills import AWSBedrockSkill
5
+
6
# Public names re-exported by this package for `from ... import *`.
__all__ = ["AWSCredentials", "AWSBedrockSkill"]
@@ -0,0 +1,36 @@
1
+ from typing import Optional
2
+ from pydantic import Field, SecretStr
3
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
4
+ import boto3
5
+
6
+
7
class AWSCredentials(BaseCredentials):
    """AWS access-key credentials with an optional session token and region."""

    aws_access_key_id: SecretStr = Field(..., description="AWS Access Key ID")
    aws_secret_access_key: SecretStr = Field(..., description="AWS Secret Access Key")
    aws_region: str = Field(default="us-east-1", description="AWS Region")
    aws_session_token: Optional[SecretStr] = Field(
        None, description="AWS Session Token"
    )

    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}

    async def validate_credentials(self) -> bool:
        """Check the credentials with an STS ``GetCallerIdentity`` call.

        Returns:
            True when the call succeeds.

        Raises:
            CredentialValidationError: If creating the session or the STS call
                fails for any reason.
        """
        token = self.aws_session_token
        try:
            session_kwargs = {
                "aws_access_key_id": self.aws_access_key_id.get_secret_value(),
                "aws_secret_access_key": self.aws_secret_access_key.get_secret_value(),
                "aws_session_token": token.get_secret_value() if token else None,
                "region_name": self.aws_region,
            }
            boto3.Session(**session_kwargs).client("sts").get_caller_identity()
        except Exception as e:
            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
        return True
@@ -0,0 +1,98 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field
3
+ import boto3
4
+ from pathlib import Path
5
+ from loguru import logger
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import AWSCredentials
10
+
11
+
12
class AWSBedrockInput(InputSchema):
    """Schema for AWS Bedrock chat input"""

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="anthropic.claude-3-sonnet-20240229-v1:0",
        description="AWS Bedrock model to use",
    )
    # Claude 3 models (including the default Claude 3 Sonnet) allow at most 4096
    # output tokens; the previous 131072 default caused validation errors.
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    # NOTE(review): AWSBedrockSkill currently does not send images to the model.
    images: Optional[List[Path]] = Field(
        default=None,
        description="Optional list of image paths to include in the message",
    )
32
+
33
+
34
class AWSBedrockOutput(OutputSchema):
    """Schema for AWS Bedrock chat output"""

    # Generated text returned to the caller.
    response: str = Field(..., description="Model's response text")
    # AWSBedrockSkill fills this with the requested model id (input_data.model).
    used_model: str = Field(..., description="Model used for generation")
    # AWSBedrockSkill stores the input_tokens / output_tokens counts here.
    usage: Dict[str, Any] = Field(
        default_factory=dict, description="Usage statistics from the API"
    )
42
+
43
+
44
class AWSBedrockSkill(Skill[AWSBedrockInput, AWSBedrockOutput]):
    """Skill for interacting with AWS Bedrock models (Anthropic Claude only)."""

    input_schema = AWSBedrockInput
    output_schema = AWSBedrockOutput

    def __init__(self, credentials: Optional[AWSCredentials] = None):
        """Initialize the skill with optional credentials.

        Args:
            credentials: Explicit credentials; when omitted they are loaded
                with ``AWSCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AWSCredentials.from_env()
        session_token = self.credentials.aws_session_token
        self.client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=self.credentials.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=self.credentials.aws_secret_access_key.get_secret_value(),
            # Temporary (STS) credentials are rejected without their session token.
            aws_session_token=(
                session_token.get_secret_value() if session_token else None
            ),
            region_name=self.credentials.aws_region,
        )

    def process(self, input_data: AWSBedrockInput) -> AWSBedrockOutput:
        """Send one user message to an Anthropic model on AWS Bedrock.

        Returns:
            Output with the concatenated text blocks of the reply and the
            input/output token counts.

        Raises:
            ProcessingError: If the model is not an Anthropic model, or if the
                request or response parsing fails.
        """
        import json

        # Checked outside the try block so the message is not re-wrapped.
        if "anthropic" not in input_data.model:
            raise ProcessingError(f"Unsupported model: {input_data.model}")

        try:
            logger.info(f"Processing request with model {input_data.model}")

            if input_data.images:
                logger.warning(
                    "AWSBedrockSkill does not send images; the provided images are ignored"
                )

            request_body = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": input_data.max_tokens,
                "temperature": input_data.temperature,
                "system": input_data.system_prompt,
                "messages": [{"role": "user", "content": input_data.user_input}],
            }

            # invoke_model expects the body as a JSON string/bytes, not a dict.
            response = self.client.invoke_model(
                modelId=input_data.model,
                body=json.dumps(request_body),
                contentType="application/json",
                accept="application/json",
            )

            # The response body is a streaming object: read it once and decode.
            response_data = json.loads(response["body"].read())

            # Messages API replies carry a list of content blocks (there is no
            # legacy "completion" key).
            response_text = "".join(
                block.get("text", "")
                for block in response_data.get("content", [])
                if block.get("type") == "text"
            )
            usage_data = response_data.get("usage", {})
            usage = {
                "input_tokens": usage_data.get("input_tokens"),
                "output_tokens": usage_data.get("output_tokens"),
            }

            return AWSBedrockOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            logger.exception(f"AWS Bedrock processing failed: {str(e)}")
            raise ProcessingError(f"AWS Bedrock processing failed: {str(e)}") from e
@@ -0,0 +1,6 @@
1
+ """Cerebras integration module"""
2
+
3
+ from .credentials import CerebrasCredentials
4
+ from .skills import CerebrasChatSkill
5
+
6
# Public names re-exported by this package for `from ... import *`.
__all__ = ["CerebrasCredentials", "CerebrasChatSkill"]
@@ -0,0 +1,19 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+
4
+
5
class CerebrasCredentials(BaseCredentials):
    """API-key credentials for the Cerebras cloud SDK."""

    cerebras_api_key: SecretStr = Field(..., description="Cerebras API key")

    _required_credentials = {"cerebras_api_key"}

    async def validate_credentials(self) -> bool:
        """Report the credentials as valid without contacting Cerebras.

        NOTE(review): no Cerebras-specific check is performed yet; this always
        returns True. A real check would depend on the Cerebras API client.
        """
        return True
@@ -0,0 +1,127 @@
1
+ from typing import List, Optional, Dict, Any, Generator
2
+ from pydantic import Field
3
+ from cerebras.cloud.sdk import Cerebras
4
+ from loguru import logger
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from .credentials import CerebrasCredentials
9
+
10
+
11
class CerebrasInput(InputSchema):
    """Schema for Cerebras chat input"""

    user_input: str = Field(..., description="User's input text")
    # Sent as the first message with role "system".
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Inserted between the system message and the new user message.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(default="llama3.1-8b", description="Cerebras model to use")
    # NOTE(review): 131072 may exceed the output/context limit of the default
    # llama3.1-8b model on Cerebras — confirm against Cerebras model limits.
    max_tokens: Optional[int] = Field(
        default=131072, description="Maximum tokens in response"
    )
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    # When True, CerebrasChatSkill.process collects the streamed chunks.
    stream: bool = Field(
        default=False, description="Whether to stream the response progressively"
    )
33
+
34
+
35
class CerebrasOutput(OutputSchema):
    """Schema for Cerebras chat output"""

    # Generated text returned to the caller.
    response: str = Field(..., description="Model's response text")
    # CerebrasChatSkill fills this with the requested model name (input_data.model).
    used_model: str = Field(..., description="Model used for generation")
    # Dumped API usage; CerebrasChatSkill leaves it empty for streamed responses.
    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
41
+
42
+
43
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Skill for Cerebras chat completions."""

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the Cerebras client.

        Args:
            credentials: Explicit credentials; when omitted they are loaded
                with ``CerebrasCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or CerebrasCredentials.from_env()
        self.client = Cerebras(
            api_key=self.credentials.cerebras_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: System message, then history, then the new
            user message.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]
        messages.extend(input_data.conversation_history)
        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def process_stream(self, input_data: CerebrasInput) -> Generator[str, None, None]:
        """Process the input and yield response text chunks as they arrive.

        Raises:
            ProcessingError: If the request or streaming fails.
        """
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # A chunk may carry no choices (e.g. a usage-only chunk);
                # indexing [0] would raise IndexError.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            logger.exception(f"Cerebras streaming failed: {str(e)}")
            raise ProcessingError(f"Cerebras streaming failed: {str(e)}") from e

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed chunks are joined and
        ``usage`` is left empty.

        Raises:
            ProcessingError: If the request fails.
        """
        try:
            if input_data.stream:
                # Streamed output is already text; it has no .choices attribute.
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not available in streaming
            else:
                messages = self._build_messages(input_data)
                response = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=messages,
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                )
                response_text = response.choices[0].message.content or ""
                # hasattr() alone is not enough: usage may be present but None.
                raw_usage = getattr(response, "usage", None)
                usage = raw_usage.model_dump() if raw_usage is not None else {}

            return CerebrasOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Already logged and wrapped by process_stream; don't wrap twice.
            raise
        except Exception as e:
            logger.exception(f"Cerebras processing failed: {str(e)}")
            raise ProcessingError(f"Cerebras processing failed: {str(e)}") from e
@@ -0,0 +1,21 @@
1
+ """Combined integration modules for Airtrain"""
2
+
3
+ from .groq_fireworks_skills import (
4
+ GroqFireworksSkill,
5
+ GroqFireworksInput,
6
+ GroqFireworksOutput
7
+ )
8
+ from .list_models_factory import (
9
+ ListModelsSkillFactory,
10
+ GenericListModelsInput,
11
+ GenericListModelsOutput
12
+ )
13
+
14
# Public names re-exported by this package for `from ... import *`:
# the combined Groq/Fireworks skill and the list-models factory types.
__all__ = [
    "GroqFireworksSkill",
    "GroqFireworksInput",
    "GroqFireworksOutput",
    "ListModelsSkillFactory",
    "GenericListModelsInput",
    "GenericListModelsOutput"
]