pyrmute-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyrmute/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """pyrmute - versioned Pydantic models and schemas with migrations.
2
+
3
+ A package for managing versioned Pydantic models with automatic migrations
4
+ and schema management.
5
+ """
6
+
7
+ from ._version import __version__
8
+ from .model_manager import ModelManager
9
+ from .model_version import ModelVersion
10
+ from .types import JsonSchema, MigrationData, MigrationFunc, ModelMetadata
11
+
12
# Names re-exported at the package root; this is what
# ``from pyrmute import *`` exposes.
__all__ = [
    "JsonSchema", "MigrationData", "MigrationFunc",
    "ModelManager", "ModelMetadata", "ModelVersion",
    "__version__",
]
@@ -0,0 +1,306 @@
1
+ """Migrations manager."""
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any, Self, get_args, get_origin
5
+
6
+ from pydantic import BaseModel
7
+ from pydantic.fields import FieldInfo
8
+
9
+ from ._registry import Registry
10
+ from .model_version import ModelVersion
11
+ from .types import MigrationData, MigrationFunc, ModelName
12
+
13
+
14
class MigrationManager:
    """Manager for data migrations between model versions.

    Handles registration and execution of migration functions, including
    support for nested Pydantic models. Nested models are located by walking
    the target field's type annotation together with the value:

    - directly (``address: Address``);
    - through ``Optional`` / other unions (``Address | None``);
    - through homogeneous sequence item types (``list[Address]``,
      ``tuple[Address, ...]``, ``Optional[list[Address]]``);
    - through the value type of mapping fields (``dict[str, Address]``).

    Attributes:
        registry: Registry holding the versioned models and the registered
            migration functions.
    """

    def __init__(self: Self, registry: Registry) -> None:
        """Initialize the migration manager.

        Args:
            registry: Registry instance to use.
        """
        self.registry = registry

    def _as_version(self: Self, version: str | ModelVersion) -> ModelVersion:
        """Return ``version`` parsed via ``ModelVersion.parse`` if it is a str.

        Non-string values are returned unchanged.
        """
        return ModelVersion.parse(version) if isinstance(version, str) else version

    def register_migration(
        self: Self,
        name: ModelName,
        from_version: str | ModelVersion,
        to_version: str | ModelVersion,
    ) -> Callable[[MigrationFunc], MigrationFunc]:
        """Register a migration function between two versions.

        Args:
            name: Name of the model.
            from_version: Source version for migration.
            to_version: Target version for migration.

        Returns:
            Decorator function for migration function. The decorator stores
            the function in the registry and returns it unchanged.

        Example:
            >>> manager = MigrationManager(registry)
            >>> @manager.register_migration("User", "1.0.0", "2.0.0")
            ... def migrate_v1_to_v2(data: dict[str, Any]) -> dict[str, Any]:
            ...     return {**data, "email": "unknown@example.com"}
        """
        from_ver = self._as_version(from_version)
        to_ver = self._as_version(to_version)

        def decorator(func: MigrationFunc) -> MigrationFunc:
            # Keyed by the (from, to) hop; migrate() looks each hop up by it.
            self.registry._migrations[name][(from_ver, to_ver)] = func
            return func

        return decorator

    def migrate(
        self: Self,
        data: MigrationData,
        name: ModelName,
        from_version: str | ModelVersion,
        to_version: str | ModelVersion,
    ) -> MigrationData:
        """Migrate data from one version to another.

        The data is moved one registered version at a time along the path
        returned by ``_find_migration_path``. For each hop an explicitly
        registered migration function is used if one exists; otherwise the
        hop is handled by ``_auto_migrate``.

        Args:
            data: Data dictionary to migrate.
            name: Name of the model.
            from_version: Source version.
            to_version: Target version.

        Returns:
            Migrated data dictionary. If both versions are equal, ``data`` is
            returned as-is.

        Raises:
            ValueError: If the model is not registered or either version is
                not a registered version of it.
        """
        from_ver = self._as_version(from_version)
        to_ver = self._as_version(to_version)

        if from_ver == to_ver:
            return data

        path = self._find_migration_path(name, from_ver, to_ver)
        migrations = self.registry._migrations[name]

        current_data = data
        for step_from, step_to in zip(path, path[1:]):
            migration_key = (step_from, step_to)
            if migration_key in migrations:
                current_data = migrations[migration_key](current_data)
            else:
                current_data = self._auto_migrate(
                    current_data, name, step_from, step_to
                )

        return current_data

    def _find_migration_path(
        self: Self,
        name: ModelName,
        from_ver: ModelVersion,
        to_ver: ModelVersion,
    ) -> list[ModelVersion]:
        """Find migration path between versions.

        Args:
            name: Name of the model.
            from_ver: Source version.
            to_ver: Target version.

        Returns:
            List of every registered version between ``from_ver`` and
            ``to_ver`` inclusive, ordered in the direction of migration
            (descending when migrating backwards).

        Raises:
            ValueError: If the model or either version is not registered.
        """
        versions = sorted(self.registry.get_versions(name))

        if from_ver not in versions or to_ver not in versions:
            raise ValueError(f"Invalid version range for {name}")

        from_idx = versions.index(from_ver)
        to_idx = versions.index(to_ver)

        if from_idx < to_idx:
            return versions[from_idx : to_idx + 1]
        return versions[to_idx : from_idx + 1][::-1]

    def _auto_migrate(
        self: Self,
        data: MigrationData,
        name: ModelName,
        from_ver: ModelVersion,
        to_ver: ModelVersion,
    ) -> MigrationData:
        """Automatically migrate data when no explicit migration exists.

        Only fields declared on the target model and present in ``data`` are
        kept. Keys absent from the target model are dropped. Target fields
        missing from ``data`` are left out (presumably so the target model's
        defaults apply on validation; confirm with ModelManager). Nested
        Pydantic models are migrated recursively to their corresponding
        versions.

        Args:
            data: Data dictionary to migrate.
            name: Name of the model.
            from_ver: Source version.
            to_ver: Target version.

        Returns:
            Migrated data dictionary.
        """
        from_fields = self.registry.get_model(name, from_ver).model_fields
        to_fields = self.registry.get_model(name, to_ver).model_fields

        result: MigrationData = {}
        for field_name, to_field_info in to_fields.items():
            if field_name not in data:
                continue
            # from_field is None when the field is new in the target version.
            result[field_name] = self._migrate_field_value(
                data[field_name], from_fields.get(field_name), to_field_info
            )

        return result

    def _migrate_field_value(
        self: Self,
        value: Any,
        from_field: FieldInfo | None,
        to_field: FieldInfo,
    ) -> Any:
        """Migrate a single field value, handling nested models.

        Args:
            value: The field value to migrate.
            from_field: Source field info (None if field is new).
            to_field: Target field info.

        Returns:
            Migrated field value.
        """
        from_annotation = from_field.annotation if from_field is not None else None
        return self._migrate_value(value, from_annotation, to_field.annotation)

    def _migrate_value(
        self: Self,
        value: Any,
        from_annotation: Any,
        to_annotation: Any,
    ) -> Any:
        """Recursively migrate ``value``, guided by its type annotations.

        The annotations are descended together with the value, so a model
        nested inside a mapping or sequence is matched against the element
        type rather than the container type.

        Args:
            value: The (possibly nested) value to migrate.
            from_annotation: Annotation of the value in the source model, or
                None if unknown.
            to_annotation: Annotation of the value in the target model.

        Returns:
            Migrated value. Dicts and lists are rebuilt as new containers.
        """
        if value is None:
            return None

        if isinstance(value, dict):
            to_value_type = self._mapping_value_annotation(to_annotation)
            if to_value_type is not None:
                # ``dict[K, Model]``-style field: this dict is a container,
                # so migrate each entry against the value type instead of
                # treating the container itself as model data.
                from_value_type = self._mapping_value_annotation(from_annotation)
                if from_value_type is None:
                    from_value_type = from_annotation
                return {
                    key: self._migrate_value(item, from_value_type, to_value_type)
                    for key, item in value.items()
                }

            nested_info = self._nested_model_info(from_annotation, to_annotation)
            if nested_info:
                nested_name, nested_from_ver, nested_to_ver = nested_info
                return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)

            # Not a registered model: rebuild the dict, migrating its values.
            return {
                key: self._migrate_value(item, from_annotation, to_annotation)
                for key, item in value.items()
            }

        if isinstance(value, list):
            to_item_type = self._sequence_item_annotation(to_annotation)
            if to_item_type is None:
                # No single item type (e.g. fixed-length tuple or unknown):
                # check each item against the whole annotation, which still
                # finds a model listed among its type arguments.
                to_item_type, from_item_type = to_annotation, from_annotation
            else:
                from_item_type = self._sequence_item_annotation(from_annotation)
                if from_item_type is None:
                    from_item_type = from_annotation
            return [
                self._migrate_value(item, from_item_type, to_item_type)
                for item in value
            ]

        return value

    def _extract_nested_model_info(
        self: Self,
        value: MigrationData,
        from_field: FieldInfo | None,
        to_field: FieldInfo,
    ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
        """Extract nested model migration information from field infos.

        Args:
            value: The nested model data (not inspected).
            from_field: Source field info.
            to_field: Target field info.

        Returns:
            Tuple of (model_name, from_version, to_version) if this is a
            versioned nested model, None otherwise.
        """
        from_annotation = from_field.annotation if from_field is not None else None
        return self._nested_model_info(from_annotation, to_field.annotation)

    def _nested_model_info(
        self: Self,
        from_annotation: Any,
        to_annotation: Any,
    ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
        """Determine which registered model/versions a nested value maps to.

        Args:
            from_annotation: Source annotation, or None if unknown.
            to_annotation: Target annotation.

        Returns:
            Tuple of (model_name, from_version, to_version) if the target
            annotation names a registered model, None otherwise. If the
            source version cannot be determined (no source annotation, or it
            names a different model), from_version equals to_version, which
            makes the subsequent ``migrate`` call a no-op.
        """
        to_model_type = self._get_model_type_from_annotation(to_annotation)
        if to_model_type is None:
            return None

        to_info = self.registry.get_model_info(to_model_type)
        if not to_info:
            return None
        model_name, to_version = to_info

        from_model_type = self._get_model_type_from_annotation(from_annotation)
        if from_model_type is not None:
            from_info = self.registry.get_model_info(from_model_type)
            if from_info and from_info[0] == model_name:
                return (model_name, from_info[1], to_version)

        return (model_name, to_version, to_version)

    def _get_model_type_from_field(
        self: Self, field: FieldInfo
    ) -> type[BaseModel] | None:
        """Extract the Pydantic model type from a field.

        Handles Optional, List, and other generic types.

        Args:
            field: The field info to extract from.

        Returns:
            The model type if found, None otherwise.
        """
        return self._get_model_type_from_annotation(field.annotation)

    def _get_model_type_from_annotation(
        self: Self, annotation: Any
    ) -> type[BaseModel] | None:
        """Extract a Pydantic model type from an annotation.

        Returns the annotation itself if it is a BaseModel subclass, else the
        first BaseModel subclass among its direct generic arguments (e.g.
        ``Optional[M]``, ``list[M]``). Deeper nesting is not searched.

        Args:
            annotation: The annotation to inspect (may be None).

        Returns:
            The model type if found, None otherwise.
        """
        if annotation is None:
            return None

        if self._is_model_type(annotation):
            return annotation

        if get_origin(annotation) is not None:
            for arg in get_args(annotation):
                if self._is_model_type(arg):
                    return arg

        return None

    def _is_model_type(self: Self, annotation: Any) -> bool:
        """Return True if ``annotation`` is a class deriving from BaseModel."""
        return isinstance(annotation, type) and issubclass(annotation, BaseModel)

    def _union_members(self: Self, annotation: Any) -> tuple[Any, ...]:
        """Return the members of a union annotation, else ``(annotation,)``.

        Covers both ``typing.Union`` / ``Optional`` and PEP 604 ``X | Y``.
        """
        import types
        from typing import Union

        origin = get_origin(annotation)
        if origin is Union or origin is types.UnionType:
            return get_args(annotation)
        return (annotation,)

    def _mapping_value_annotation(self: Self, annotation: Any) -> Any:
        """Return ``V`` when ``annotation`` is, or unions in, ``Mapping[K, V]``.

        Returns None when no parameterized mapping type is found, or when
        the union also lists a Pydantic model directly. In that case a dict
        value is taken to be that model's data.
        """
        from collections.abc import Mapping

        members = self._union_members(annotation)
        if any(self._is_model_type(member) for member in members):
            return None
        for member in members:
            origin = get_origin(member)
            if isinstance(origin, type) and issubclass(origin, Mapping):
                args = get_args(member)
                if len(args) == 2:
                    return args[1]
        return None

    def _sequence_item_annotation(self: Self, annotation: Any) -> Any:
        """Return the item type of a homogeneous sequence/set annotation.

        Matches ``list[T]``, ``set[T]``, ``frozenset[T]``, ``Sequence[T]``,
        ``tuple[T, ...]`` and unions containing one of them. Returns None
        otherwise, including for fixed-length tuples, whose items have
        different types.
        """
        from collections.abc import Sequence, Set

        for member in self._union_members(annotation):
            origin = get_origin(member)
            if not (isinstance(origin, type) and issubclass(origin, (Sequence, Set))):
                continue
            args = get_args(member)
            if origin is tuple:
                if len(args) == 2 and args[1] is Ellipsis:
                    return args[0]
            elif len(args) == 1:
                return args[0]
        return None
pyrmute/_registry.py ADDED
@@ -0,0 +1,172 @@
1
+ """Model registry."""
2
+
3
+ from collections import defaultdict
4
+ from collections.abc import Callable
5
+ from typing import Self
6
+
7
+ from pydantic import BaseModel
8
+
9
+ from .model_version import ModelVersion
10
+ from .types import (
11
+ DecoratedBaseModel,
12
+ JsonSchemaGenerator,
13
+ MigrationMap,
14
+ ModelMetadata,
15
+ ModelName,
16
+ SchemaGenerators,
17
+ VersionedModels,
18
+ )
19
+
20
+
21
class Registry:
    """Central store of versioned Pydantic model classes.

    Every model class is filed under a model name and a semantic version.
    Alongside the class itself the registry keeps optional per-version JSON
    schema generators, the migration functions registered for each model,
    and the set of versions that may be referenced via ``$ref``.

    Attributes:
        _models: Model name -> {version: model class}.
        _migrations: Model name -> {(from_version, to_version): migration
            function}.
        _schema_generators: Model name -> {version: custom schema generator}.
        _model_metadata: Model class -> (name, version) it was registered as.
        _ref_enabled: Model name -> versions registered with enable_ref=True.
    """

    def __init__(self: Self) -> None:
        """Create an empty registry."""
        # defaultdicts let register() write into a model's nested mapping
        # without first checking whether the model name is known.
        self._models: dict[ModelName, VersionedModels] = defaultdict(dict)
        self._migrations: dict[ModelName, MigrationMap] = defaultdict(dict)
        self._schema_generators: dict[ModelName, SchemaGenerators] = defaultdict(dict)
        self._model_metadata: dict[type[BaseModel], ModelMetadata] = {}
        self._ref_enabled: dict[ModelName, set[ModelVersion]] = defaultdict(set)

    def register(
        self: Self,
        name: ModelName,
        version: str | ModelVersion,
        schema_generator: JsonSchemaGenerator | None = None,
        enable_ref: bool = False,
    ) -> Callable[[type[DecoratedBaseModel]], type[DecoratedBaseModel]]:
        """Build a class decorator that records a model under name/version.

        Args:
            name: Name of the model.
            version: Semantic version string or ModelVersion instance.
            schema_generator: Optional custom schema generator function,
                stored only when truthy.
            enable_ref: If True, this model can be referenced via $ref in
                separate schema files. If False, it will always be inlined.

        Returns:
            Decorator that registers the class and hands it back unchanged.

        Example:
            >>> registry = Registry()
            >>> @registry.register("User", "1.0.0", enable_ref=True)
            ... class UserV1(BaseModel):
            ...     name: str
        """
        if isinstance(version, str):
            version = ModelVersion.parse(version)

        def _record(model_cls: type[DecoratedBaseModel]) -> type[DecoratedBaseModel]:
            self._models[name][version] = model_cls
            self._model_metadata[model_cls] = (name, version)
            if schema_generator:
                self._schema_generators[name][version] = schema_generator
            if enable_ref:
                self._ref_enabled[name].add(version)
            return model_cls

        return _record

    def get_model(
        self: Self, name: ModelName, version: str | ModelVersion
    ) -> type[BaseModel]:
        """Look up the class registered for one version of a model.

        Args:
            name: Name of the model.
            version: Semantic version string or ModelVersion instance.

        Returns:
            Model class for the specified version.

        Raises:
            ValueError: If the model name or that version is not registered.
        """
        wanted = ModelVersion.parse(version) if isinstance(version, str) else version

        # .get() avoids materialising an empty entry in the defaultdict.
        by_version = self._models.get(name)
        if by_version is None or wanted not in by_version:
            raise ValueError(f"Model {name} v{wanted} not found")
        return by_version[wanted]

    def get_latest(self: Self, name: ModelName) -> type[BaseModel]:
        """Return the class registered under the highest version of a model.

        Args:
            name: Name of the model.

        Returns:
            Model class for the greatest registered version.

        Raises:
            ValueError: If the model name is not registered.
        """
        by_version = self._models.get(name)
        if by_version is None:
            raise ValueError(f"Model {name} not found")
        newest = max(by_version)
        return by_version[newest]

    def get_versions(self: Self, name: ModelName) -> list[ModelVersion]:
        """List every registered version of a model in ascending order.

        Args:
            name: Name of the model.

        Returns:
            Sorted list of available versions.

        Raises:
            ValueError: If the model name is not registered.
        """
        by_version = self._models.get(name)
        if by_version is None:
            raise ValueError(f"Model {name} not found")
        return sorted(by_version)

    def list_models(self: Self) -> list[ModelName]:
        """Return the names of all registered models.

        Returns:
            List of model names, in insertion order.
        """
        return [*self._models]

    def get_model_info(
        self: Self, model_class: type[BaseModel]
    ) -> ModelMetadata | None:
        """Reverse lookup: find the (name, version) a class was registered as.

        Args:
            model_class: The model class to look up.

        Returns:
            Tuple of (name, version) if found, None otherwise.
        """
        return self._model_metadata.get(model_class)

    def is_ref_enabled(
        self: Self, name: ModelName, version: str | ModelVersion
    ) -> bool:
        """Report whether a model version was registered with enable_ref=True.

        Args:
            name: Name of the model.
            version: Semantic version string or ModelVersion instance.

        Returns:
            True if this model can be referenced via $ref, False otherwise.
        """
        wanted = ModelVersion.parse(version) if isinstance(version, str) else version
        enabled_versions = self._ref_enabled.get(name)
        return enabled_versions is not None and wanted in enabled_versions