airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
1
+ from typing import Any, Dict, Optional, Type, Union, cast, get_args, get_origin
2
+ from pydantic import BaseModel, ValidationError, create_model
3
+ import json
4
+ from uuid import UUID, uuid4
5
+
6
+
7
class AirtrainSchema(BaseModel):
    """Base schema class for all Airtrain schemas"""

    _schema_id: Optional[UUID] = None
    _schema_version: str = "1.0.0"

    @classmethod
    def _extract_field_type(cls, field_props: Dict) -> Type:
        """
        Extract Python type from field properties

        Args:
            field_props: Field properties from JSON schema

        Returns:
            Python type for the field, or ``Any`` if it cannot be determined

        Raises:
            ValueError: If a JSON type name is not supported
        """
        # Direct type specification: a single name or a list of names
        if "type" in field_props:
            return cls._map_json_type_to_python(field_props["type"])

        # anyOf/oneOf: collect every option that declares a type
        for union_key in ("anyOf", "oneOf"):
            if union_key in field_props:
                types = [
                    cls._map_json_type_to_python(option["type"])
                    for option in field_props[union_key]
                    if "type" in option
                ]
                if types:
                    return Union[tuple(types)] if len(types) > 1 else types[0]

        # Default to Any if type cannot be determined
        return Any

    @classmethod
    def _get_field_config(cls, field_props: Dict) -> tuple:
        """
        Get field type and default value configuration

        Args:
            field_props: Field properties from JSON schema

        Returns:
            Tuple of (field_type, field_default); ``...`` marks a required field
        """
        field_type = cls._extract_field_type(field_props)

        # An explicit default always wins
        if "default" in field_props:
            return (field_type, field_props["default"])

        # Optional (Union containing None) fields default to None
        if get_origin(field_type) is Union and type(None) in get_args(field_type):
            return (field_type, None)

        # No default value
        return (field_type, ...)

    @classmethod
    def from_json_schema(cls, json_schema: str | Dict) -> Type["AirtrainSchema"]:
        """
        Create an AirtrainSchema subclass from a JSON schema

        Args:
            json_schema: JSON schema string or dictionary

        Returns:
            A dynamically created subclass of ``cls`` (a class, not an instance)

        Raises:
            json.JSONDecodeError: If ``json_schema`` is a string that is not valid JSON
            TypeError: If the decoded schema is not a JSON object
            ValueError: If a property uses an unsupported JSON type
        """
        if isinstance(json_schema, str):
            json_schema = json.loads(json_schema)

        # Explicit check instead of ``assert`` so it still runs under ``python -O``
        if not isinstance(json_schema, dict):
            raise TypeError(
                f"JSON schema must be an object, got {type(json_schema).__name__}"
            )

        required_fields = set(json_schema.get("required", []))
        model_fields = {}

        # A schema without "properties" yields a model with no fields
        for field_name, field_props in json_schema.get("properties", {}).items():
            field_type, field_default = cls._get_field_config(field_props)

            # Override default for required fields
            if field_name in required_fields:
                field_default = ...

            model_fields[field_name] = (field_type, field_default)

        # create_model returns a new class deriving from cls
        return create_model("DynamicSchema", __base__=cls, **model_fields)

    @classmethod
    def from_pydantic_schema(
        cls, pydantic_schema: Type[BaseModel]
    ) -> Type["AirtrainSchema"]:
        """
        Create an AirtrainSchema subclass from a Pydantic model

        Args:
            pydantic_schema: Pydantic model class

        Returns:
            A dynamically created subclass of ``cls``
        """
        # Round-trip through the model's JSON schema
        return cls.from_json_schema(pydantic_schema.model_json_schema())

    @staticmethod
    def _map_json_type_to_python(json_type: Union[str, list]) -> Type:
        """
        Map JSON schema type name(s) to a Python type

        Args:
            json_type: A JSON schema type name (e.g. ``"string"``), or a list of
                names such as ``["string", "null"]``

        Returns:
            The matching Python type; a list of names maps to a ``Union``

        Raises:
            ValueError: If a type name is not supported or the list is empty
        """
        if isinstance(json_type, (list, tuple)):
            members = tuple(
                AirtrainSchema._map_json_type_to_python(name) for name in json_type
            )
            if not members:
                raise ValueError("Unsupported JSON type: empty type list")
            return Union[members] if len(members) > 1 else members[0]

        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
            "null": type(None),
        }
        # Explicit error instead of ``assert`` so it still runs under ``python -O``
        if json_type not in type_mapping:
            raise ValueError(f"Unsupported JSON type: {json_type}")
        return type_mapping[json_type]

    def validate_custom(self) -> None:
        """
        Perform custom validation beyond Pydantic's built-in validation
        To be implemented by subclasses

        Raises:
            ValidationError: If custom validation fails
        """
        pass

    def validate_all(self) -> None:
        """
        Perform all validations including Pydantic and custom

        Raises:
            ValidationError: If any validation fails
        """
        # Pydantic field validation already ran when the instance was built
        try:
            self.validate_custom()
        except ValidationError:
            # Already the documented type; pass it through unchanged
            raise
        except Exception as e:
            # Pydantic v2's ValidationError has no ``__init__(message)``:
            # ``ValidationError("...")`` raises TypeError and hides the real
            # failure, so build it through the supported factory instead.
            raise ValidationError.from_exception_data(
                type(self).__name__,
                [
                    {
                        "type": "value_error",
                        "loc": (),
                        "input": self,
                        "ctx": {"error": f"Custom validation failed: {e}"},
                    }
                ],
            ) from e

    def publish(self) -> UUID:
        """
        Publish schema to make it available for use

        Returns:
            UUID: Unique identifier for the published schema
        """
        if not self._schema_id:
            self._schema_id = uuid4()
        # TODO: Implement actual publishing logic
        return self._schema_id

    @classmethod
    def get_by_id(cls, schema_id: UUID) -> "AirtrainSchema":
        """
        Retrieve a published schema by ID

        Args:
            schema_id: UUID of the published schema

        Returns:
            AirtrainSchema instance

        Raises:
            NotImplementedError: Retrieval is not implemented yet
        """
        # TODO: Implement schema retrieval logic
        raise NotImplementedError("Schema retrieval not implemented yet")
190
+
191
+
192
class InputSchema(AirtrainSchema):
    """Base class for the inputs accepted by tasks and skills."""

    def validate_input_specific(self) -> None:
        """
        Hook for checks that only apply to inputs.

        The base implementation accepts everything; subclasses override it.

        Raises:
            ValidationError: If validation fails
        """

    def validate_custom(self) -> None:
        """
        Run the inherited custom checks first, then the input-specific hook.

        Raises:
            ValidationError: If validation fails
        """
        super().validate_custom()
        self.validate_input_specific()
214
+
215
+
216
class OutputSchema(AirtrainSchema):
    """Base class for the outputs produced by tasks and skills."""

    def validate_output_specific(self) -> None:
        """
        Hook for checks that only apply to outputs.

        The base implementation accepts everything; subclasses override it.

        Raises:
            ValidationError: If validation fails
        """

    def validate_custom(self) -> None:
        """
        Run the inherited custom checks first, then the output-specific hook.

        Raises:
            ValidationError: If validation fails
        """
        super().validate_custom()
        self.validate_output_specific()
@@ -0,0 +1,269 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Optional, Type, Generic, TypeVar
3
+ from uuid import UUID, uuid4
4
+ import time
5
+ import functools
6
+ from .schemas import InputSchema, OutputSchema
7
+
8
+ # Import telemetry
9
+ from airtrain.telemetry import (
10
+ telemetry,
11
+ SkillInitTelemetryEvent,
12
+ SkillProcessTelemetryEvent,
13
+ )
14
+
15
+ # Generic type variables for input and output schemas
16
# Type parameters binding a Skill to its concrete input/output schema classes.
InputT = TypeVar("InputT", bound=InputSchema)
OutputT = TypeVar("OutputT", bound=OutputSchema)
18
+
19
+
20
class Skill(ABC, Generic[InputT, OutputT]):
    """
    Abstract base class for all skills in Airtrain.
    Each skill must define input/output schemas and implement core processing logic.

    The first time each concrete subclass is instantiated, its ``process``
    method is replaced at class level by a wrapper that reports every call
    (duration, error message, serialized input) to telemetry.
    """

    input_schema: Type[InputT]
    output_schema: Type[OutputT]
    _skill_id: Optional[UUID] = None
    _original_process = None

    def __init__(self):
        """Initialize the skill, install the telemetry wrapper, and capture telemetry."""
        # Initialize skill_id if not already set
        if not self._skill_id:
            self._skill_id = uuid4()

        cls = self.__class__
        # Look only at this class's own namespace. ``hasattr`` would also see
        # a flag inherited from an already-patched parent, and a subclass that
        # overrides ``process`` would then never be wrapped.
        if not cls.__dict__.get("_patched_process", False):
            current_process = cls.process
            # A subclass that does not override ``process`` may resolve to a
            # parent's wrapper; wrapping it again would double-report calls.
            if not getattr(current_process, "_airtrain_telemetry_wrapped", False):
                cls._original_process = current_process
                cls.process = Skill._telemetry_wrap_process(current_process)
            cls._patched_process = True

        # Telemetry is best-effort: a failure here must never break the skill.
        try:
            telemetry.capture(
                SkillInitTelemetryEvent(
                    skill_id=str(self.skill_id),
                    skill_class=cls.__name__,
                )
            )
        except Exception:
            pass

    @staticmethod
    def _telemetry_wrap_process(original_method):
        """Return ``original_method`` wrapped so each call is reported to telemetry."""

        @functools.wraps(original_method)
        def wrapped_process(instance, input_data, *args, **kwargs):
            # perf_counter is monotonic, unlike time.time(), so durations are
            # not skewed by wall-clock adjustments.
            start_time = time.perf_counter()
            error = None
            try:
                return original_method(instance, input_data, *args, **kwargs)
            except Exception as e:
                error = str(e)
                raise
            finally:
                Skill._telemetry_report_process(
                    instance, input_data, time.perf_counter() - start_time, error
                )

        # Marker used by __init__ to avoid wrapping a wrapper
        wrapped_process._airtrain_telemetry_wrapped = True
        return wrapped_process

    @staticmethod
    def _telemetry_serialize_input(input_data: Any) -> Dict:
        """Convert ``input_data`` to a dict for telemetry; never raises."""
        try:
            # Pydantic v2 models: model_dump() (``dict()`` is deprecated there)
            if hasattr(input_data, "model_dump"):
                return input_data.model_dump()
            # Pydantic v1-style models
            if hasattr(input_data, "dict"):
                return input_data.dict()
            # Dataclasses
            if hasattr(input_data, "__dataclass_fields__"):
                from dataclasses import asdict

                return asdict(input_data)
            # Fallback
            return {"__str__": str(input_data)}
        except Exception:
            # If serialization fails, provide simple info
            return {"error": "Failed to serialize input data"}

    @staticmethod
    def _telemetry_report_process(
        instance: "Skill", input_data: Any, duration: float, error: Optional[str]
    ) -> None:
        """Send one process-call telemetry event; failures are silently ignored."""
        try:
            telemetry.capture(
                SkillProcessTelemetryEvent(
                    skill_id=str(instance.skill_id),
                    skill_class=instance.__class__.__name__,
                    input_schema=instance.input_schema.__name__,
                    output_schema=instance.output_schema.__name__,
                    input_data=Skill._telemetry_serialize_input(input_data),
                    duration_seconds=duration,
                    error=error,
                )
            )
        except Exception:
            # Telemetry is best-effort; never let it affect the skill's result
            pass

    @abstractmethod
    def process(self, input_data: InputT) -> OutputT:
        """
        Process the input and generate output according to defined schemas.

        Args:
            input_data: Validated input conforming to input_schema

        Returns:
            Output conforming to output_schema

        Raises:
            ProcessingError: If processing fails
        """
        pass

    def __call__(self, input_data: InputT) -> OutputT:
        """Make the skill callable, with input/output validation."""
        self.validate_input(input_data)
        result = self.process(input_data)
        self.validate_output(result)
        return result

    def validate_input(self, input_data: Any) -> None:
        """
        Validate input data before processing.

        Args:
            input_data: Raw input data

        Raises:
            InputValidationError: If input_data is not an input_schema instance
        """
        if not isinstance(input_data, self.input_schema):
            raise InputValidationError(
                f"Input must be an instance of {self.input_schema.__name__}"
            )
        input_data.validate_all()

    def validate_output(self, output_data: Any) -> None:
        """
        Validate output data after processing.

        Args:
            output_data: Processed output data

        Raises:
            OutputValidationError: If output_data is not an output_schema instance
        """
        if not isinstance(output_data, self.output_schema):
            raise OutputValidationError(
                f"Output must be an instance of {self.output_schema.__name__}"
            )
        output_data.validate_all()

    def evaluate(self, test_dataset: Optional["Dataset"] = None) -> "EvaluationResult":
        """
        Evaluate skill performance.

        Args:
            test_dataset: Optional dataset for evaluation; when omitted (None),
                get_default_test_dataset() is used

        Returns:
            EvaluationResult containing one entry per test case
        """
        # ``is None`` so an explicitly passed empty dataset is honoured rather
        # than silently replaced by the default one.
        if test_dataset is None:
            test_dataset = self.get_default_test_dataset()

        results = []
        for test_case in test_dataset:
            try:
                output = self.process(test_case.input)
                results.append(self.compare_output(output, test_case.expected))
            except Exception as e:
                results.append(EvaluationError(str(e)))

        return EvaluationResult(results)

    def get_default_test_dataset(self) -> "Dataset":
        """Get default test dataset for evaluation"""
        raise NotImplementedError("No default test dataset provided")

    def compare_output(self, actual: OutputT, expected: OutputT) -> Dict:
        """
        Compare actual output with expected output

        Args:
            actual: Actual output from processing
            expected: Expected output from test case

        Returns:
            Dictionary containing comparison metrics
        """
        raise NotImplementedError("Output comparison not implemented")

    @property
    def skill_id(self) -> UUID:
        """Unique identifier for the skill, generated lazily on first access"""
        if not self._skill_id:
            self._skill_id = uuid4()
        return self._skill_id
217
+
218
+
219
class ProcessingError(Exception):
    """Signals that a skill failed while processing its input."""
223
+
224
+
225
class InputValidationError(Exception):
    """Signals that data handed to a skill does not satisfy its input schema."""
229
+
230
+
231
class OutputValidationError(Exception):
    """Signals that a skill produced data violating its output schema."""
235
+
236
+
237
class EvaluationError:
    """Represents an error during evaluation.

    This is a plain result marker, not an exception: ``Skill.evaluate`` stores
    one in its results list when a test case raises.
    """

    def __init__(self, message: str):
        self.message = message

    def __repr__(self) -> str:
        # Readable form for the raw results list returned by get_metrics()
        return f"{type(self).__name__}({self.message!r})"
242
+
243
+
244
class EvaluationResult:
    """Collection of per-test-case outcomes gathered during skill evaluation."""

    def __init__(self, results: list):
        self.results = results

    def get_metrics(self) -> Dict:
        """Summarize the outcomes.

        Returns:
            Dict with ``total_cases``, ``successful`` and ``failed`` counts,
            plus the raw ``results`` list itself.
        """
        total = len(self.results)
        # Every outcome is either an EvaluationError or a success
        failures = sum(
            1 for outcome in self.results if isinstance(outcome, EvaluationError)
        )
        return {
            "total_cases": total,
            "successful": total - failures,
            "failed": failures,
            "results": self.results,
        }
260
+
261
+
262
class Dataset:
    """Ordered collection of test cases that a skill is evaluated against."""

    def __init__(self, test_cases: list):
        self.test_cases = test_cases

    def __iter__(self):
        """Yield the stored test cases in their original order."""
        cases = self.test_cases
        return iter(cases)
@@ -0,0 +1,74 @@
1
+ """Airtrain integrations package"""
2
+
3
+ # Credentials imports
4
+ from .openai.credentials import OpenAICredentials
5
+ from .aws.credentials import AWSCredentials
6
+ from .google.credentials import GoogleCloudCredentials
7
+ from .anthropic.credentials import AnthropicCredentials
8
+ from .groq.credentials import GroqCredentials
9
+ from .together.credentials import TogetherAICredentials
10
+ from .ollama.credentials import OllamaCredentials
11
+ from .sambanova.credentials import SambanovaCredentials
12
+ from .cerebras.credentials import CerebrasCredentials
13
+ from .perplexity.credentials import PerplexityCredentials
14
+
15
+ # Skills imports
16
+ from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
17
+ from .anthropic.skills import AnthropicChatSkill
18
+ from .aws.skills import AWSBedrockSkill
19
+ from .google.skills import GoogleChatSkill
20
+ from .groq.skills import GroqChatSkill
21
+ from .together.skills import TogetherAIChatSkill
22
+ from .ollama.skills import OllamaChatSkill
23
+ from .sambanova.skills import SambanovaChatSkill
24
+ from .cerebras.skills import CerebrasChatSkill
25
+ from .perplexity.skills import PerplexityChatSkill, PerplexityStreamingChatSkill
26
+
27
+ # Model configurations
28
+ from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
29
+ from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
30
+ from .perplexity.models_config import PERPLEXITY_MODELS_CONFIG
31
+
32
+ # Combined modules
33
+ from .combined.list_models_factory import (
34
+ ListModelsSkillFactory,
35
+ GenericListModelsInput,
36
+ GenericListModelsOutput,
37
+ )
38
+
39
# Public API of the integrations package, in the same grouping as the imports.
__all__ = [
    # Credential classes, one per provider
    "OpenAICredentials", "AWSCredentials", "GoogleCloudCredentials",
    "AnthropicCredentials", "GroqCredentials", "TogetherAICredentials",
    "OllamaCredentials", "SambanovaCredentials", "CerebrasCredentials",
    "PerplexityCredentials",
    # Chat / parser skills
    "OpenAIChatSkill", "OpenAIParserSkill", "AnthropicChatSkill",
    "AWSBedrockSkill", "GoogleChatSkill", "GroqChatSkill",
    "TogetherAIChatSkill", "OllamaChatSkill", "SambanovaChatSkill",
    "CerebrasChatSkill", "PerplexityChatSkill", "PerplexityStreamingChatSkill",
    # Model configuration tables and config classes
    "OPENAI_MODELS", "OpenAIModelConfig",
    "ANTHROPIC_MODELS", "AnthropicModelConfig",
    "PERPLEXITY_MODELS_CONFIG",
    # Provider-agnostic model-listing helpers
    "ListModelsSkillFactory", "GenericListModelsInput", "GenericListModelsOutput",
]
@@ -0,0 +1,33 @@
1
+ """Anthropic integration for Airtrain"""
2
+
3
+ from .credentials import AnthropicCredentials
4
+ from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
5
+ from .models_config import (
6
+ ANTHROPIC_MODELS,
7
+ AnthropicModelConfig,
8
+ get_model_config,
9
+ get_default_model,
10
+ calculate_cost,
11
+ )
12
+ from .list_models import (
13
+ AnthropicListModelsSkill,
14
+ AnthropicListModelsInput,
15
+ AnthropicListModelsOutput,
16
+ AnthropicModel,
17
+ )
18
+
19
# Public API of the Anthropic integration.
__all__ = [
    # Credentials and chat skill with its schemas
    "AnthropicCredentials", "AnthropicChatSkill", "AnthropicInput", "AnthropicOutput",
    # Model configuration helpers
    "ANTHROPIC_MODELS", "AnthropicModelConfig",
    "get_model_config", "get_default_model", "calculate_cost",
    # Model listing
    "AnthropicListModelsSkill", "AnthropicListModelsInput",
    "AnthropicListModelsOutput", "AnthropicModel",
]
@@ -0,0 +1,32 @@
1
+ from pydantic import Field, SecretStr, validator
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ from anthropic import Anthropic
4
+
5
+
6
class AnthropicCredentials(BaseCredentials):
    """Anthropic API credentials"""

    anthropic_api_key: SecretStr = Field(..., description="Anthropic API key")
    version: str = Field(default="2023-06-01", description="API Version")

    _required_credentials = {"anthropic_api_key"}

    @validator("anthropic_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys that lack Anthropic's ``sk-ant-`` prefix."""
        key = v.get_secret_value()
        if not key.startswith("sk-ant-"):
            raise ValueError("Anthropic API key must start with 'sk-ant-'")
        return v

    async def validate_credentials(
        self, *, model: str = "claude-3-opus-20240229"
    ) -> bool:
        """
        Validate Anthropic credentials with a one-token API request.

        The async client is used so the network call does not block the
        running event loop; the client is closed afterwards.

        Args:
            model: Model used for the probe request. The default keeps the
                previous behaviour; pass a currently available (and cheaper)
                model if that one has been retired.

        Returns:
            True if the request succeeded.

        Raises:
            CredentialValidationError: If the request fails for any reason.
        """
        from anthropic import AsyncAnthropic

        client = None
        try:
            client = AsyncAnthropic(api_key=self.anthropic_api_key.get_secret_value())
            await client.messages.create(
                model=model,
                max_tokens=1,
                messages=[{"role": "user", "content": "Hi"}],
            )
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Anthropic credentials: {str(e)}"
            ) from e
        finally:
            if client is not None:
                try:
                    # Release the underlying HTTP connection pool
                    await client.close()
                except Exception:
                    # Cleanup failure must not mask the validation outcome
                    pass