airtrain 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
airtrain/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
+
3
+ __version__ = "0.1.1"
4
+
5
+ from .core.skills import Skill
6
+ from .core.schemas import InputSchema, OutputSchema
7
+ from .core.credentials import BaseCredentials
8
+
9
+ __all__ = ["Skill", "InputSchema", "OutputSchema", "BaseCredentials"]
@@ -0,0 +1,7 @@
1
+ """Core modules for Airtrain"""
2
+
3
+ from .skills import Skill, ProcessingError
4
+ from .schemas import InputSchema, OutputSchema
5
+ from .credentials import BaseCredentials
6
+
7
+ __all__ = ["Skill", "ProcessingError", "InputSchema", "OutputSchema", "BaseCredentials"]
@@ -0,0 +1,153 @@
1
+ from typing import Dict, List, Optional, Set
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ from abc import ABC, abstractmethod
6
+ import dotenv
7
+ from pydantic import BaseModel, Field, SecretStr
8
+ import yaml # type: ignore
9
+
10
+
11
class CredentialError(Exception):
    """Root of the credential exception hierarchy; catch this to handle any credential failure."""


class CredentialNotFoundError(CredentialError):
    """Signals that a credential the caller needs could not be located."""


class CredentialValidationError(CredentialError):
    """Signals that a set of credentials did not pass validation."""
27
+
28
+
29
class BaseCredentials(BaseModel):
    """Base class for all credential configurations.

    Subclasses declare their credentials as pydantic fields and list the
    names that must be non-empty in ``_required_credentials``. Each field
    maps to an environment variable named after the upper-cased field name.
    """

    # Underscore-prefixed attributes are pydantic private attributes: they
    # are not model fields, so they are neither validated nor dumped.
    _loaded: bool = False
    _required_credentials: Set[str] = set()

    def load_to_env(self) -> None:
        """Export every set credential to an upper-cased environment variable.

        ``SecretStr`` values are exported in plain text. Fields whose value is
        ``None`` (an unset optional credential) are skipped instead of being
        exported as the literal string ``"None"``.
        """
        for field_name, field_value in self:
            if field_value is None:
                continue
            if isinstance(field_value, SecretStr):
                value = field_value.get_secret_value()
            else:
                value = str(field_value)
            os.environ[field_name.upper()] = value
        self._loaded = True

    @classmethod
    def from_env(cls) -> "BaseCredentials":
        """Create a credentials instance from environment variables.

        Each field is read from the upper-cased field name. Unset or empty
        variables are left out, so the field's default applies (or pydantic
        raises a validation error for a required field).
        """
        field_values = {}
        for field_name in cls.model_fields:
            if env_value := os.getenv(field_name.upper()):
                field_values[field_name] = env_value
        return cls(**field_values)

    @classmethod
    def from_file(cls, file_path: Path) -> "BaseCredentials":
        """Load credentials from a ``.env``, ``.json``, ``.yaml`` or ``.yml`` file.

        Args:
            file_path: Path to the credentials file (a ``str`` is also accepted).

        Returns:
            An instance of ``cls`` populated from the file.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file extension is not supported.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Credentials file not found: {file_path}")

        # Path(".env").suffix is "" (a leading dot does not start a suffix),
        # so a bare ".env" file has to be recognised by its name.
        if file_path.name == ".env" or file_path.suffix == ".env":
            # NOTE: python-dotenv's load_dotenv defaults to override=False, so
            # variables already present in the environment win over the file.
            dotenv.load_dotenv(file_path)
            return cls.from_env()

        if file_path.suffix == ".json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
        elif file_path.suffix in {".yaml", ".yml"}:
            with open(file_path, encoding="utf-8") as f:
                data = yaml.safe_load(f)
        else:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        # An empty YAML document loads as None; treat it as "no values given".
        return cls(**(data or {}))

    def save_to_file(self, file_path: Path) -> None:
        """Save credentials to a ``.env``, ``.json``, ``.yaml`` or ``.yml`` file.

        Secret values are written in plain text. In ``.env`` output, fields
        whose value is ``None`` are omitted, so that ``from_file`` reads them
        back as unset rather than as the string ``"None"``.

        Args:
            file_path: Destination path (a ``str`` is also accepted).

        Raises:
            ValueError: If the file extension is not supported.
        """
        file_path = Path(file_path)
        # Unwrap SecretStr values so that they serialise as their real content.
        data = {
            key: value.get_secret_value() if isinstance(value, SecretStr) else value
            for key, value in self.model_dump().items()
        }

        if file_path.name == ".env" or file_path.suffix == ".env":
            with open(file_path, "w", encoding="utf-8") as f:
                for key, value in data.items():
                    if value is not None:
                        f.write(f"{key.upper()}={value}\n")

        elif file_path.suffix == ".json":
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

        elif file_path.suffix in {".yaml", ".yml"}:
            with open(file_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f)

        else:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

    def validate_credentials(self) -> None:
        """Check that every name in ``_required_credentials`` has a non-empty value.

        A credential counts as missing when it is ``None``, an empty string,
        or a ``SecretStr`` that wraps an empty string.

        Raises:
            CredentialValidationError: Lists the missing names in sorted order.
        """
        missing = []
        # Sorted so the error message is deterministic (set order is not).
        for field_name in sorted(self._required_credentials):
            value = getattr(self, field_name, None)
            if isinstance(value, SecretStr):
                value = value.get_secret_value()
            if value is None or value == "":
                missing.append(field_name)

        if missing:
            raise CredentialValidationError(
                f"Missing required credentials: {', '.join(missing)}"
            )

    def clear_from_env(self) -> None:
        """Remove this model's upper-cased credential variables from the environment."""
        # Read model_fields from the class: accessing it on an instance is
        # deprecated since pydantic 2.11.
        for field_name in type(self).model_fields:
            os.environ.pop(field_name.upper(), None)
        self._loaded = False
125
+
126
+
127
class OpenAICredentials(BaseCredentials):
    """Credentials for the OpenAI API."""

    api_key: SecretStr = Field(description="OpenAI API key")
    organization_id: Optional[str] = Field(
        default=None, description="OpenAI organization ID"
    )

    # Names that validate_credentials() checks for a non-empty value.
    _required_credentials = {"api_key"}


class AWSCredentials(BaseCredentials):
    """Credentials for Amazon Web Services."""

    aws_access_key_id: SecretStr
    aws_secret_access_key: SecretStr
    aws_region: str = Field(default="us-east-1")
    aws_session_token: Optional[SecretStr] = Field(default=None)

    # Names that validate_credentials() checks for a non-empty value.
    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}


class GoogleCloudCredentials(BaseCredentials):
    """Credentials for Google Cloud."""

    project_id: str
    service_account_key: SecretStr

    # Names that validate_credentials() checks for a non-empty value.
    _required_credentials = {"project_id", "service_account_key"}
@@ -0,0 +1,237 @@
1
+ from typing import Any, Dict, Optional, Type, Union, cast, get_args, get_origin
2
+ from pydantic import BaseModel, ValidationError, create_model
3
+ import json
4
+ from uuid import UUID, uuid4
5
+
6
+
7
class AirtrainSchema(BaseModel):
    """Base schema class for all Airtrain schemas"""

    # Private attributes (not model fields): the identifier assigned by
    # publish() and the schema format version.
    _schema_id: Optional[UUID] = None
    _schema_version: str = "1.0.0"

    @classmethod
    def _resolve_json_type(cls, json_type: Any) -> Any:
        """
        Convert a JSON schema ``type`` value to a Python type annotation.

        ``json_type`` may be a single type name or a list of names (e.g.
        ``["string", "null"]``); a list becomes a ``Union`` of the mapped types.
        """
        if isinstance(json_type, list):
            mapped = [cls._map_json_type_to_python(t) for t in json_type]
            # Union[()] is an error, so an empty list means "unknown".
            return Union[tuple(mapped)] if mapped else Any
        return cls._map_json_type_to_python(json_type)

    @classmethod
    def _extract_field_type(cls, field_props: Dict) -> Type:
        """
        Extract Python type from field properties

        Args:
            field_props: Field properties from JSON schema

        Returns:
            Python type for the field (``Any`` if it cannot be determined)

        Raises:
            ValueError: If a field uses an unsupported JSON type
        """
        # Handle direct type specification (single name or list of names)
        if "type" in field_props:
            return cls._resolve_json_type(field_props["type"])

        # Handle anyOf/oneOf cases
        for union_key in ("anyOf", "oneOf"):
            options = field_props.get(union_key)
            if not options:
                continue
            # Options without a "type" (e.g. a "$ref" to a nested model) are
            # typed as Any. Dropping them would turn Optional[Model] into
            # NoneType, and every non-null value would then be rejected.
            types = [
                cls._resolve_json_type(option["type"]) if "type" in option else Any
                for option in options
            ]
            # Only use the union if at least one option has a concrete type.
            if any(t is not Any for t in types):
                return Union[tuple(types)]

        # Default to Any if type cannot be determined
        return Any

    @classmethod
    def _get_field_config(cls, field_props: Dict) -> tuple:
        """
        Get field type and default value configuration

        Args:
            field_props: Field properties from JSON schema

        Returns:
            Tuple of (field_type, field_default); ``...`` marks a field
            without a default
        """
        field_type = cls._extract_field_type(field_props)

        # Handle default values
        if "default" in field_props:
            return (field_type, field_props["default"])

        # Nullable types default to None
        if get_origin(field_type) is Union and type(None) in get_args(field_type):
            return (field_type, None)

        # No default value
        return (field_type, ...)

    @classmethod
    def from_json_schema(cls, json_schema: str | Dict) -> "AirtrainSchema":
        """
        Create an AirtrainSchema from a JSON schema

        Args:
            json_schema: JSON schema string or dictionary

        Returns:
            A dynamically created subclass of ``cls`` named "DynamicSchema"

        Raises:
            json.JSONDecodeError: If a string schema is not valid JSON
            TypeError: If the schema does not decode to a dictionary
            ValueError: If a property uses an unsupported JSON type
        """
        if isinstance(json_schema, str):
            json_schema = json.loads(json_schema)

        # Explicit check rather than assert, which is stripped under python -O.
        if not isinstance(json_schema, dict):
            raise TypeError(
                f"JSON schema must be an object, got {type(json_schema).__name__}"
            )

        required_fields = set(json_schema.get("required", []))
        model_fields = {}

        # A schema without "properties" yields a model with no extra fields.
        for field_name, field_props in json_schema.get("properties", {}).items():
            field_type, field_default = cls._get_field_config(field_props)

            # "required" means the key must be present, even if the type
            # also allows null, so drop any implied default.
            if field_name in required_fields:
                field_default = ...

            model_fields[field_name] = (field_type, field_default)

        # Create dynamic model using create_model
        DynamicSchema = create_model("DynamicSchema", __base__=cls, **model_fields)

        return cast(AirtrainSchema, DynamicSchema)

    @classmethod
    def from_pydantic_schema(cls, pydantic_schema: Type[BaseModel]) -> "AirtrainSchema":
        """
        Create an AirtrainSchema from a Pydantic model

        Args:
            pydantic_schema: Pydantic model class

        Returns:
            A dynamically created subclass of ``cls`` (see from_json_schema)
        """
        # Round-trip through the model's JSON schema
        schema = pydantic_schema.model_json_schema()
        return cls.from_json_schema(schema)

    @staticmethod
    def _map_json_type_to_python(json_type: str) -> Type:
        """
        Map a single JSON schema type name to a Python type

        Raises:
            ValueError: If ``json_type`` is not a known JSON schema type name
        """
        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
            "null": type(None),
        }
        # Explicit check rather than assert (stripped under -O); the isinstance
        # guard also avoids a TypeError when an unhashable value is passed.
        if not isinstance(json_type, str) or json_type not in type_mapping:
            raise ValueError(f"Unsupported JSON type: {json_type}")
        return type_mapping[json_type]

    def validate_custom(self) -> None:
        """
        Perform custom validation beyond Pydantic's built-in validation
        To be implemented by subclasses; the base implementation does nothing.

        Raises:
            ValidationError: If custom validation fails
        """
        pass

    def validate_all(self) -> None:
        """
        Perform all validations including Pydantic and custom

        Pydantic field validation has already run when the instance was
        built; this runs ``validate_custom`` and reports any failure as a
        pydantic ``ValidationError``.

        Raises:
            ValidationError: If any validation fails
        """
        try:
            self.validate_custom()
        except ValidationError:
            # Already in the documented form; don't wrap it a second time.
            raise
        except Exception as e:
            # pydantic v2's ValidationError cannot be built from a message
            # (ValidationError("...") raises TypeError), so build it from
            # structured error data instead.
            raise ValidationError.from_exception_data(
                title=type(self).__name__,
                line_errors=[
                    {
                        "type": "value_error",
                        "loc": (),
                        "input": self,
                        "ctx": {
                            "error": ValueError(
                                f"Custom validation failed: {str(e)}"
                            )
                        },
                    }
                ],
            ) from e

    def publish(self) -> UUID:
        """
        Publish schema to make it available for use

        Returns:
            UUID: Unique identifier for the published schema (generated on
            first call and reused afterwards)
        """
        if self._schema_id is None:
            self._schema_id = uuid4()
        # TODO: Implement actual publishing logic
        return self._schema_id

    @classmethod
    def get_by_id(cls, schema_id: UUID) -> "AirtrainSchema":
        """
        Retrieve a published schema by ID

        Args:
            schema_id: UUID of the published schema

        Returns:
            AirtrainSchema instance

        Raises:
            NotImplementedError: Always; retrieval is not implemented yet
        """
        # TODO: Implement schema retrieval logic
        raise NotImplementedError("Schema retrieval not implemented yet")
190
+
191
+
192
class InputSchema(AirtrainSchema):
    """Base schema for the data a task or skill receives."""

    def validate_custom(self) -> None:
        """
        Run the inherited custom checks, then the input-specific hook.

        Raises:
            ValidationError: If validation fails
        """
        super().validate_custom()
        self.validate_input_specific()

    def validate_input_specific(self) -> None:
        """
        Hook for input-only validation; does nothing unless a subclass overrides it.

        Raises:
            ValidationError: If validation fails
        """
        return None


class OutputSchema(AirtrainSchema):
    """Base schema for the data a task or skill produces."""

    def validate_custom(self) -> None:
        """
        Run the inherited custom checks, then the output-specific hook.

        Raises:
            ValidationError: If validation fails
        """
        super().validate_custom()
        self.validate_output_specific()

    def validate_output_specific(self) -> None:
        """
        Hook for output-only validation; does nothing unless a subclass overrides it.

        Raises:
            ValidationError: If validation fails
        """
        return None
@@ -0,0 +1,167 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Optional, Type, Generic, TypeVar
3
+ from uuid import UUID, uuid4
4
+ from .schemas import InputSchema, OutputSchema
5
+
6
# Generic type variables for input and output schemas
InputT = TypeVar("InputT", bound=InputSchema)
OutputT = TypeVar("OutputT", bound=OutputSchema)


class Skill(ABC, Generic[InputT, OutputT]):
    """
    Abstract base class for all skills in Airtrain.
    Each skill must define input/output schemas and implement core processing logic.

    Subclasses set the ``input_schema`` / ``output_schema`` class attributes
    and implement ``process``.
    """

    input_schema: Type[InputT]
    output_schema: Type[OutputT]
    # Assigned lazily by the ``skill_id`` property.
    _skill_id: Optional[UUID] = None

    @abstractmethod
    def process(self, input_data: InputT) -> OutputT:
        """
        Process the input and generate output according to defined schemas.

        Args:
            input_data: Validated input conforming to input_schema

        Returns:
            Output conforming to output_schema

        Raises:
            ProcessingError: If processing fails
        """

    def validate_input(self, input_data: Any) -> None:
        """
        Validate input data before processing.

        Args:
            input_data: Raw input data

        Raises:
            InputValidationError: If input_data is not an input_schema instance
            ValidationError: If the schema's own validate_all() fails
        """
        if not isinstance(input_data, self.input_schema):
            raise InputValidationError(
                f"Input must be an instance of {self.input_schema.__name__}"
            )
        input_data.validate_all()

    def validate_output(self, output_data: Any) -> None:
        """
        Validate output data after processing.

        Args:
            output_data: Processed output data

        Raises:
            OutputValidationError: If output_data is not an output_schema instance
            ValidationError: If the schema's own validate_all() fails
        """
        if not isinstance(output_data, self.output_schema):
            raise OutputValidationError(
                f"Output must be an instance of {self.output_schema.__name__}"
            )
        output_data.validate_all()

    def evaluate(self, test_dataset: Optional["Dataset"] = None) -> "EvaluationResult":
        """
        Evaluate skill performance.

        Each test case must expose ``input`` and ``expected`` attributes. Any
        exception raised by ``process`` or ``compare_output`` for a case is
        recorded as an ``EvaluationError`` entry instead of propagating.

        Args:
            test_dataset: Dataset to evaluate on; when None, the dataset from
                get_default_test_dataset() is used

        Returns:
            EvaluationResult containing one entry per test case
        """
        # Compare against None explicitly: an empty dataset that the caller
        # passed in must be used as-is, not replaced by the default one.
        if test_dataset is None:
            test_dataset = self.get_default_test_dataset()

        results = []
        for test_case in test_dataset:
            try:
                output = self.process(test_case.input)
                results.append(self.compare_output(output, test_case.expected))
            except Exception as e:
                # Deliberately broad: one failing case must not abort the run.
                results.append(EvaluationError(str(e)))

        return EvaluationResult(results)

    def get_default_test_dataset(self) -> "Dataset":
        """Get default test dataset for evaluation (subclasses must override)."""
        raise NotImplementedError("No default test dataset provided")

    def compare_output(self, actual: OutputT, expected: OutputT) -> Dict:
        """
        Compare actual output with expected output

        The base implementation raises, so evaluate() records every case as
        an EvaluationError until a subclass overrides this.

        Args:
            actual: Actual output from processing
            expected: Expected output from test case

        Returns:
            Dictionary containing comparison metrics
        """
        raise NotImplementedError("Output comparison not implemented")

    @property
    def skill_id(self) -> UUID:
        """Unique identifier for the skill, generated on first access."""
        if self._skill_id is None:
            self._skill_id = uuid4()
        return self._skill_id
115
+
116
+
117
class ProcessingError(Exception):
    """Raised when skill processing fails"""


class InputValidationError(Exception):
    """Raised when input validation fails"""


class OutputValidationError(Exception):
    """Raised when output validation fails"""


class EvaluationError:
    """Represents an error during evaluation

    Stored in ``EvaluationResult.results`` in place of a comparison result
    for a test case that failed.
    """

    def __init__(self, message: str):
        self.message = message

    def __repr__(self) -> str:
        # Readable form for metrics dumps instead of an object address.
        return f"{type(self).__name__}({self.message!r})"

    def __str__(self) -> str:
        return str(self.message)


class EvaluationResult:
    """Contains results from skill evaluation"""

    def __init__(self, results: list):
        self.results = results

    def get_metrics(self) -> Dict:
        """Calculate evaluation metrics

        Returns:
            Dict with ``total_cases``, ``successful`` (entries that are not
            ``EvaluationError``), ``failed`` (``EvaluationError`` entries) and
            the raw ``results`` list.
        """
        total = len(self.results)
        # Single pass; every entry is either an error or a success.
        failed = sum(1 for r in self.results if isinstance(r, EvaluationError))
        return {
            "total_cases": total,
            "successful": total - failed,
            "failed": failed,
            "results": self.results,
        }


class Dataset:
    """Represents a test dataset for skill evaluation

    Iterating yields the stored test cases in order.
    """

    def __init__(self, test_cases: list):
        self.test_cases = test_cases

    def __iter__(self):
        return iter(self.test_cases)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: airtrain
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
@@ -95,7 +95,7 @@ creds.load_to_env()
95
95
 
96
96
  ## Documentation
97
97
 
98
- For detailed documentation, visit [our documentation site](https://airtrain.readthedocs.io/).
98
+ For detailed documentation, visit [our documentation site](https://docs.airtrain.dev/).
99
99
 
100
100
  ## Contributing
101
101
 
@@ -0,0 +1,9 @@
1
+ airtrain/__init__.py,sha256=dqQKBcKKk6Xis8BNi-BygiK1W51cppG3Sh5rdefqBys,312
2
+ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
3
+ airtrain/core/credentials.py,sha256=CzUZkAFxrSMC0nq70zybkkJmeIZDYiNBuzfivOTEgH0,4773
4
+ airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
5
+ airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
6
+ airtrain-0.1.1.dist-info/METADATA,sha256=3qlD2n866n3emRI9yEnfYC4dzLhwXGyeb6lqqFnwCyM,2786
7
+ airtrain-0.1.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
8
+ airtrain-0.1.1.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
9
+ airtrain-0.1.1.dist-info/RECORD,,
@@ -0,0 +1 @@
1
+ airtrain
@@ -1,4 +0,0 @@
1
- airtrain-0.1.0.dist-info/METADATA,sha256=2wLV8r1XTMiq6YUJpSCANE5A2KbXlfDUjSAxHOwq8iw,2792
2
- airtrain-0.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
3
- airtrain-0.1.0.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
4
- airtrain-0.1.0.dist-info/RECORD,,
@@ -1 +0,0 @@
1
-