pyrmute 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrmute/__init__.py +13 -11
- pyrmute/_migration_manager.py +517 -132
- pyrmute/_schema_manager.py +17 -7
- pyrmute/_version.py +2 -2
- pyrmute/migration_testing.py +8 -5
- pyrmute/model_manager.py +46 -66
- pyrmute/types.py +11 -4
- {pyrmute-0.2.0.dist-info → pyrmute-0.4.0.dist-info}/METADATA +229 -70
- pyrmute-0.4.0.dist-info/RECORD +17 -0
- pyrmute-0.2.0.dist-info/RECORD +0 -17
- {pyrmute-0.2.0.dist-info → pyrmute-0.4.0.dist-info}/WHEEL +0 -0
- {pyrmute-0.2.0.dist-info → pyrmute-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {pyrmute-0.2.0.dist-info → pyrmute-0.4.0.dist-info}/top_level.txt +0 -0
pyrmute/_migration_manager.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
"""Migrations manager."""
|
2
2
|
|
3
3
|
import contextlib
|
4
|
+
import types
|
4
5
|
from collections.abc import Callable
|
5
|
-
from typing import Any, Self, get_args, get_origin
|
6
|
+
from typing import Annotated, Any, Literal, Self, Union, get_args, get_origin
|
6
7
|
|
7
8
|
from pydantic import BaseModel
|
8
9
|
from pydantic.fields import FieldInfo
|
@@ -11,7 +12,7 @@ from pydantic_core import PydanticUndefined
|
|
11
12
|
from ._registry import Registry
|
12
13
|
from .exceptions import MigrationError, ModelNotFoundError
|
13
14
|
from .model_version import ModelVersion
|
14
|
-
from .types import
|
15
|
+
from .types import MigrationFunc, ModelData, ModelName
|
15
16
|
|
16
17
|
|
17
18
|
class MigrationManager:
|
@@ -54,16 +55,8 @@ class MigrationManager:
|
|
54
55
|
... def migrate_v1_to_v2(data: dict[str, Any]) -> dict[str, Any]:
|
55
56
|
... return {**data, "email": "unknown@example.com"}
|
56
57
|
"""
|
57
|
-
from_ver = (
|
58
|
-
|
59
|
-
if isinstance(from_version, str)
|
60
|
-
else from_version
|
61
|
-
)
|
62
|
-
to_ver = (
|
63
|
-
ModelVersion.parse(to_version)
|
64
|
-
if isinstance(to_version, str)
|
65
|
-
else to_version
|
66
|
-
)
|
58
|
+
from_ver = self._parse_version(from_version)
|
59
|
+
to_ver = self._parse_version(to_version)
|
67
60
|
|
68
61
|
def decorator(func: MigrationFunc) -> MigrationFunc:
|
69
62
|
self.registry._migrations[name][(from_ver, to_ver)] = func
|
@@ -73,11 +66,11 @@ class MigrationManager:
|
|
73
66
|
|
74
67
|
def migrate(
|
75
68
|
self: Self,
|
76
|
-
data:
|
69
|
+
data: ModelData,
|
77
70
|
name: ModelName,
|
78
71
|
from_version: str | ModelVersion,
|
79
72
|
to_version: str | ModelVersion,
|
80
|
-
) ->
|
73
|
+
) -> ModelData:
|
81
74
|
"""Migrate data from one version to another.
|
82
75
|
|
83
76
|
Args:
|
@@ -93,61 +86,14 @@ class MigrationManager:
|
|
93
86
|
ModelNotFoundError: If model or versions don't exist.
|
94
87
|
MigrationError: If migration path cannot be found.
|
95
88
|
"""
|
96
|
-
from_ver = (
|
97
|
-
|
98
|
-
if isinstance(from_version, str)
|
99
|
-
else from_version
|
100
|
-
)
|
101
|
-
to_ver = (
|
102
|
-
ModelVersion.parse(to_version)
|
103
|
-
if isinstance(to_version, str)
|
104
|
-
else to_version
|
105
|
-
)
|
89
|
+
from_ver = self._parse_version(from_version)
|
90
|
+
to_ver = self._parse_version(to_version)
|
106
91
|
|
107
92
|
if from_ver == to_ver:
|
108
93
|
return data
|
109
94
|
|
110
95
|
path = self.find_migration_path(name, from_ver, to_ver)
|
111
|
-
|
112
|
-
current_data = data
|
113
|
-
for i in range(len(path) - 1):
|
114
|
-
migration_key = (path[i], path[i + 1])
|
115
|
-
|
116
|
-
if migration_key in self.registry._migrations[name]:
|
117
|
-
migration_func = self.registry._migrations[name][migration_key]
|
118
|
-
try:
|
119
|
-
current_data = migration_func(current_data)
|
120
|
-
except Exception as e:
|
121
|
-
raise MigrationError(
|
122
|
-
name,
|
123
|
-
str(path[i]),
|
124
|
-
str(path[i + 1]),
|
125
|
-
f"Migration function raised: {type(e).__name__}: {e}",
|
126
|
-
) from e
|
127
|
-
elif path[i + 1] in self.registry._backward_compatible_enabled[name]:
|
128
|
-
try:
|
129
|
-
current_data = self._auto_migrate(
|
130
|
-
current_data, name, path[i], path[i + 1]
|
131
|
-
)
|
132
|
-
except Exception as e:
|
133
|
-
raise MigrationError(
|
134
|
-
name,
|
135
|
-
str(path[i]),
|
136
|
-
str(path[i + 1]),
|
137
|
-
f"Auto-migration failed: {type(e).__name__}: {e}",
|
138
|
-
) from e
|
139
|
-
else:
|
140
|
-
raise MigrationError(
|
141
|
-
name,
|
142
|
-
str(path[i]),
|
143
|
-
str(path[i + 1]),
|
144
|
-
(
|
145
|
-
"No migration path found. Define a migration function or mark "
|
146
|
-
"the target version as backward_compatible."
|
147
|
-
),
|
148
|
-
)
|
149
|
-
|
150
|
-
return current_data
|
96
|
+
return self._execute_migration_path(data, name, path)
|
151
97
|
|
152
98
|
def find_migration_path(
|
153
99
|
self: Self,
|
@@ -168,6 +114,9 @@ class MigrationManager:
|
|
168
114
|
Raises:
|
169
115
|
ModelNotFoundError: If the model or versions don't exist.
|
170
116
|
"""
|
117
|
+
if (from_ver, to_ver) in self.registry._migrations.get(name, {}):
|
118
|
+
return [from_ver, to_ver]
|
119
|
+
|
171
120
|
versions = sorted(self.registry.get_versions(name))
|
172
121
|
|
173
122
|
if from_ver not in versions:
|
@@ -204,14 +153,8 @@ class MigrationManager:
|
|
204
153
|
for i in range(len(path) - 1):
|
205
154
|
current_ver = path[i]
|
206
155
|
next_ver = path[i + 1]
|
207
|
-
migration_key = (current_ver, next_ver)
|
208
156
|
|
209
|
-
|
210
|
-
has_auto = next_ver in self.registry._backward_compatible_enabled.get(
|
211
|
-
name, set()
|
212
|
-
)
|
213
|
-
|
214
|
-
if not has_explicit and not has_auto:
|
157
|
+
if not self._has_migration_step(name, current_ver, next_ver):
|
215
158
|
raise MigrationError(
|
216
159
|
name,
|
217
160
|
str(current_ver),
|
@@ -222,17 +165,96 @@ class MigrationManager:
|
|
222
165
|
),
|
223
166
|
)
|
224
167
|
|
168
|
+
def _parse_version(self: Self, version: str | ModelVersion) -> ModelVersion:
|
169
|
+
"""Parse version string or return ModelVersion as-is."""
|
170
|
+
return ModelVersion.parse(version) if isinstance(version, str) else version
|
171
|
+
|
172
|
+
def _has_migration_step(
|
173
|
+
self: Self, name: ModelName, from_ver: ModelVersion, to_ver: ModelVersion
|
174
|
+
) -> bool:
|
175
|
+
"""Check if a migration step exists (explicit or auto)."""
|
176
|
+
migration_key = (from_ver, to_ver)
|
177
|
+
has_explicit = migration_key in self.registry._migrations.get(name, {})
|
178
|
+
has_auto = to_ver in self.registry._backward_compatible_enabled.get(name, set())
|
179
|
+
return has_explicit or has_auto
|
180
|
+
|
181
|
+
def _execute_migration_path(
|
182
|
+
self: Self, data: ModelData, name: ModelName, path: list[ModelVersion]
|
183
|
+
) -> ModelData:
|
184
|
+
"""Execute migration through a path of versions."""
|
185
|
+
current_data = data
|
186
|
+
|
187
|
+
for i in range(len(path) - 1):
|
188
|
+
try:
|
189
|
+
current_data = self._execute_single_step(
|
190
|
+
current_data, name, path[i], path[i + 1]
|
191
|
+
)
|
192
|
+
except Exception as e:
|
193
|
+
if isinstance(e, MigrationError):
|
194
|
+
raise
|
195
|
+
raise MigrationError(
|
196
|
+
name,
|
197
|
+
str(path[i]),
|
198
|
+
str(path[i + 1]),
|
199
|
+
f"Migration failed: {type(e).__name__}: {e}",
|
200
|
+
) from e
|
201
|
+
|
202
|
+
return current_data
|
203
|
+
|
204
|
+
def _execute_single_step(
|
205
|
+
self: Self,
|
206
|
+
data: ModelData,
|
207
|
+
name: ModelName,
|
208
|
+
from_ver: ModelVersion,
|
209
|
+
to_ver: ModelVersion,
|
210
|
+
) -> ModelData:
|
211
|
+
"""Execute a single migration step."""
|
212
|
+
migration_key = (from_ver, to_ver)
|
213
|
+
|
214
|
+
if migration_key in self.registry._migrations[name]:
|
215
|
+
migration_func = self.registry._migrations[name][migration_key]
|
216
|
+
try:
|
217
|
+
return migration_func(data)
|
218
|
+
except Exception as e:
|
219
|
+
raise MigrationError(
|
220
|
+
name,
|
221
|
+
str(from_ver),
|
222
|
+
str(to_ver),
|
223
|
+
f"Migration function raised: {type(e).__name__}: {e}",
|
224
|
+
) from e
|
225
|
+
|
226
|
+
if to_ver in self.registry._backward_compatible_enabled[name]:
|
227
|
+
try:
|
228
|
+
return self._auto_migrate(data, name, from_ver, to_ver)
|
229
|
+
except Exception as e:
|
230
|
+
raise MigrationError(
|
231
|
+
name,
|
232
|
+
str(from_ver),
|
233
|
+
str(to_ver),
|
234
|
+
f"Auto-migration failed: {type(e).__name__}: {e}",
|
235
|
+
) from e
|
236
|
+
|
237
|
+
raise MigrationError(
|
238
|
+
name,
|
239
|
+
str(from_ver),
|
240
|
+
str(to_ver),
|
241
|
+
(
|
242
|
+
"No migration path found. Define a migration function or mark "
|
243
|
+
"the target version as backward_compatible."
|
244
|
+
),
|
245
|
+
)
|
246
|
+
|
225
247
|
def _auto_migrate(
|
226
248
|
self: Self,
|
227
|
-
data:
|
249
|
+
data: ModelData,
|
228
250
|
name: ModelName,
|
229
251
|
from_ver: ModelVersion,
|
230
252
|
to_ver: ModelVersion,
|
231
|
-
) ->
|
253
|
+
) -> ModelData:
|
232
254
|
"""Automatically migrate data when no explicit migration exists.
|
233
255
|
|
234
256
|
This method handles nested Pydantic models recursively, migrating them to their
|
235
|
-
corresponding versions.
|
257
|
+
corresponding versions. Handles field aliases by building a lookup map.
|
236
258
|
|
237
259
|
Args:
|
238
260
|
data: Data dictionary to migrate.
|
@@ -249,36 +271,142 @@ class MigrationManager:
|
|
249
271
|
from_fields = from_model.model_fields
|
250
272
|
to_fields = to_model.model_fields
|
251
273
|
|
252
|
-
|
274
|
+
key_to_field_name = self._build_alias_map(from_fields)
|
275
|
+
|
276
|
+
result: ModelData = {}
|
277
|
+
processed_keys: set[str] = set()
|
253
278
|
|
254
279
|
for field_name, to_field_info in to_fields.items():
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
from_field_info = from_fields.get(field_name)
|
259
|
-
result[field_name] = self._migrate_field_value(
|
260
|
-
value, from_field_info, to_field_info
|
261
|
-
)
|
280
|
+
field_result = self._migrate_single_field(
|
281
|
+
data, field_name, to_field_info, from_fields, key_to_field_name
|
282
|
+
)
|
262
283
|
|
263
|
-
|
264
|
-
|
265
|
-
result[
|
266
|
-
|
267
|
-
with contextlib.suppress(Exception):
|
268
|
-
result[field_name] = to_field_info.default_factory() # type: ignore
|
284
|
+
if field_result is not None:
|
285
|
+
output_key, migrated_value, source_keys = field_result
|
286
|
+
result[output_key] = migrated_value
|
287
|
+
processed_keys.update(source_keys)
|
269
288
|
|
270
|
-
#
|
271
|
-
for
|
272
|
-
if
|
273
|
-
result[
|
289
|
+
# Preserve unprocessed extra fields
|
290
|
+
for data_key, value in data.items():
|
291
|
+
if data_key not in processed_keys:
|
292
|
+
result[data_key] = value
|
274
293
|
|
275
294
|
return result
|
276
295
|
|
277
|
-
def
|
296
|
+
def _build_alias_map(self: Self, fields: dict[str, FieldInfo]) -> dict[str, str]:
|
297
|
+
"""Build a mapping from all possible input keys to canonical field names.
|
298
|
+
|
299
|
+
Args:
|
300
|
+
fields: Model fields to extract aliases from.
|
301
|
+
|
302
|
+
Returns:
|
303
|
+
Dictionary mapping data keys (field names and aliases) to field names.
|
304
|
+
"""
|
305
|
+
key_to_field_name: dict[str, str] = {}
|
306
|
+
|
307
|
+
for field_name, field_info in fields.items():
|
308
|
+
key_to_field_name[field_name] = field_name
|
309
|
+
|
310
|
+
if field_info.alias:
|
311
|
+
key_to_field_name[field_info.alias] = field_name
|
312
|
+
if field_info.serialization_alias:
|
313
|
+
key_to_field_name[field_info.serialization_alias] = field_name
|
314
|
+
if isinstance(field_info.validation_alias, str):
|
315
|
+
key_to_field_name[field_info.validation_alias] = field_name
|
316
|
+
|
317
|
+
return key_to_field_name
|
318
|
+
|
319
|
+
def _migrate_single_field(
|
278
320
|
self: Self,
|
279
|
-
|
280
|
-
|
281
|
-
|
321
|
+
data: ModelData,
|
322
|
+
field_name: str,
|
323
|
+
to_field_info: FieldInfo,
|
324
|
+
from_fields: dict[str, FieldInfo],
|
325
|
+
key_to_field_name: dict[str, str],
|
326
|
+
) -> tuple[str, Any, set[str]] | None:
|
327
|
+
"""Migrate a single field, handling aliases and defaults.
|
328
|
+
|
329
|
+
Args:
|
330
|
+
data: Source data dictionary.
|
331
|
+
field_name: Target field name.
|
332
|
+
to_field_info: Target field info.
|
333
|
+
from_fields: Source model fields.
|
334
|
+
key_to_field_name: Mapping from data keys to field names.
|
335
|
+
|
336
|
+
Returns:
|
337
|
+
Tuple of (output_key, migrated_value, source_keys_to_mark_processed) if
|
338
|
+
field should be included, None if field should be skipped.
|
339
|
+
"""
|
340
|
+
value, data_key_used = self._find_field_value(
|
341
|
+
data, field_name, key_to_field_name
|
342
|
+
)
|
343
|
+
|
344
|
+
if data_key_used is not None:
|
345
|
+
from_field_info = from_fields.get(field_name)
|
346
|
+
migrated_value = self._migrate_field_value(
|
347
|
+
value, from_field_info, to_field_info
|
348
|
+
)
|
349
|
+
keys_to_process = {
|
350
|
+
k for k, v in key_to_field_name.items() if v == field_name
|
351
|
+
}
|
352
|
+
return (data_key_used, migrated_value, keys_to_process)
|
353
|
+
|
354
|
+
_NO_DEFAULT = object()
|
355
|
+
default_value = self._get_field_default(to_field_info, _NO_DEFAULT)
|
356
|
+
|
357
|
+
if default_value is not _NO_DEFAULT:
|
358
|
+
output_key = (
|
359
|
+
to_field_info.serialization_alias or to_field_info.alias or field_name
|
360
|
+
)
|
361
|
+
return (output_key, default_value, set())
|
362
|
+
|
363
|
+
return None
|
364
|
+
|
365
|
+
def _find_field_value(
|
366
|
+
self: Self, data: ModelData, field_name: str, key_to_field_name: dict[str, str]
|
367
|
+
) -> tuple[Any, str | None]:
|
368
|
+
"""Find a field's value in data, checking both field name and aliases.
|
369
|
+
|
370
|
+
Args:
|
371
|
+
data: Source data dictionary.
|
372
|
+
field_name: Target field name to find.
|
373
|
+
key_to_field_name: Mapping from data keys to field names.
|
374
|
+
|
375
|
+
Returns:
|
376
|
+
Tuple of (value, data_key) if found, (None, None) otherwise.
|
377
|
+
"""
|
378
|
+
if field_name in data:
|
379
|
+
return (data[field_name], field_name)
|
380
|
+
|
381
|
+
for data_key in data:
|
382
|
+
if key_to_field_name.get(data_key) == field_name:
|
383
|
+
return (data[data_key], data_key)
|
384
|
+
|
385
|
+
return (None, None)
|
386
|
+
|
387
|
+
def _get_field_default(
|
388
|
+
self: Self, field_info: FieldInfo, sentinel: Any = None
|
389
|
+
) -> Any:
|
390
|
+
"""Get the default value for a field.
|
391
|
+
|
392
|
+
Args:
|
393
|
+
field_info: Field info to extract default from.
|
394
|
+
sentinel: Sentinel value to return if no default is available.
|
395
|
+
|
396
|
+
Returns:
|
397
|
+
Default value if available, sentinel otherwise.
|
398
|
+
"""
|
399
|
+
if field_info.default is not PydanticUndefined:
|
400
|
+
return field_info.default
|
401
|
+
|
402
|
+
if field_info.default_factory is not None:
|
403
|
+
with contextlib.suppress(Exception):
|
404
|
+
return field_info.default_factory() # type: ignore
|
405
|
+
|
406
|
+
return sentinel
|
407
|
+
|
408
|
+
def _migrate_field_value(
|
409
|
+
self: Self, value: Any, from_field: FieldInfo | None, to_field: FieldInfo
|
282
410
|
) -> Any:
|
283
411
|
"""Migrate a single field value, handling nested models.
|
284
412
|
|
@@ -294,31 +422,104 @@ class MigrationManager:
|
|
294
422
|
return None
|
295
423
|
|
296
424
|
if isinstance(value, dict):
|
297
|
-
|
298
|
-
if nested_info:
|
299
|
-
nested_name, nested_from_ver, nested_to_ver = nested_info
|
300
|
-
return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
|
301
|
-
|
302
|
-
return {
|
303
|
-
k: self._migrate_field_value(v, from_field, to_field)
|
304
|
-
for k, v in value.items()
|
305
|
-
}
|
425
|
+
return self._migrate_dict_value(value, from_field, to_field)
|
306
426
|
|
307
427
|
if isinstance(value, list):
|
308
|
-
return
|
309
|
-
self._migrate_field_value(item, from_field, to_field) for item in value
|
310
|
-
]
|
428
|
+
return self._migrate_list_value(value, from_field, to_field)
|
311
429
|
|
312
430
|
return value
|
313
431
|
|
432
|
+
def _migrate_dict_value(
|
433
|
+
self, value: dict[str, Any], from_field: FieldInfo | None, to_field: FieldInfo
|
434
|
+
) -> dict[str, Any]:
|
435
|
+
"""Migrate a dictionary value (might be a nested model)."""
|
436
|
+
nested_info = self._extract_nested_model_info(value, from_field, to_field)
|
437
|
+
if nested_info:
|
438
|
+
nested_name, nested_from_ver, nested_to_ver = nested_info
|
439
|
+
return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
|
440
|
+
|
441
|
+
return {
|
442
|
+
k: self._migrate_field_value(v, from_field, to_field)
|
443
|
+
for k, v in value.items()
|
444
|
+
}
|
445
|
+
|
446
|
+
def _migrate_list_value(
|
447
|
+
self, value: list[Any], from_field: FieldInfo | None, to_field: FieldInfo
|
448
|
+
) -> list[Any]:
|
449
|
+
"""Migrate a list value (might contain nested models)."""
|
450
|
+
from_item_field, to_item_field = self._get_list_item_fields(
|
451
|
+
from_field, to_field
|
452
|
+
)
|
453
|
+
|
454
|
+
return [
|
455
|
+
self._migrate_list_item(item, from_item_field, to_item_field)
|
456
|
+
for item in value
|
457
|
+
]
|
458
|
+
|
459
|
+
def _migrate_list_item(
|
460
|
+
self, item: Any, from_field: FieldInfo | None, to_field: FieldInfo | None
|
461
|
+
) -> Any:
|
462
|
+
"""Migrate a single item from a list."""
|
463
|
+
if not isinstance(item, dict) or to_field is None:
|
464
|
+
return item
|
465
|
+
|
466
|
+
nested_info = self._extract_nested_model_info(item, from_field, to_field)
|
467
|
+
if nested_info:
|
468
|
+
nested_name, nested_from_ver, nested_to_ver = nested_info
|
469
|
+
return self.migrate(item, nested_name, nested_from_ver, nested_to_ver)
|
470
|
+
|
471
|
+
return item
|
472
|
+
|
473
|
+
def _get_list_item_fields(
|
474
|
+
self, from_field: FieldInfo | None, to_field: FieldInfo
|
475
|
+
) -> tuple[FieldInfo | None, FieldInfo | None]:
|
476
|
+
"""Extract field info for items in a list field."""
|
477
|
+
to_item_field = self._extract_list_item_field(to_field)
|
478
|
+
from_item_field = (
|
479
|
+
self._extract_list_item_field(from_field) if from_field else None
|
480
|
+
)
|
481
|
+
return from_item_field, to_item_field
|
482
|
+
|
483
|
+
def _extract_list_item_field(self, field: FieldInfo | None) -> FieldInfo | None:
|
484
|
+
"""Extract field info for the items of a list field."""
|
485
|
+
if field is None or field.annotation is None:
|
486
|
+
return None
|
487
|
+
|
488
|
+
origin = get_origin(field.annotation)
|
489
|
+
if origin is not list:
|
490
|
+
return None
|
491
|
+
|
492
|
+
args = get_args(field.annotation)
|
493
|
+
if not args:
|
494
|
+
return None
|
495
|
+
|
496
|
+
item_annotation = args[0]
|
497
|
+
|
498
|
+
discriminator = None
|
499
|
+
if get_origin(item_annotation) is Annotated:
|
500
|
+
annotated_args = get_args(item_annotation)
|
501
|
+
item_annotation = annotated_args[0]
|
502
|
+
for metadata in annotated_args[1:]:
|
503
|
+
if hasattr(metadata, "discriminator"):
|
504
|
+
discriminator = metadata.discriminator
|
505
|
+
|
506
|
+
synthetic_field = FieldInfo(
|
507
|
+
annotation=item_annotation, default=PydanticUndefined
|
508
|
+
)
|
509
|
+
if discriminator:
|
510
|
+
synthetic_field.discriminator = discriminator
|
511
|
+
|
512
|
+
return synthetic_field
|
513
|
+
|
314
514
|
def _extract_nested_model_info(
|
315
|
-
self:
|
316
|
-
value: MigrationData,
|
317
|
-
from_field: FieldInfo | None,
|
318
|
-
to_field: FieldInfo,
|
515
|
+
self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
|
319
516
|
) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
|
320
517
|
"""Extract nested model migration information.
|
321
518
|
|
519
|
+
Handles discriminated unions by using the discriminator field to determine which
|
520
|
+
model type to migrate.
|
521
|
+
|
522
|
+
|
322
523
|
Args:
|
323
524
|
value: The nested model data.
|
324
525
|
from_field: Source field info.
|
@@ -328,8 +529,51 @@ class MigrationManager:
|
|
328
529
|
Tuple of (model_name, from_version, to_version) if this is a
|
329
530
|
versioned nested model, None otherwise.
|
330
531
|
"""
|
532
|
+
discriminated_info = self._try_extract_discriminated_model(
|
533
|
+
value, from_field, to_field
|
534
|
+
)
|
535
|
+
if discriminated_info:
|
536
|
+
return discriminated_info
|
537
|
+
|
538
|
+
return self._try_extract_simple_nested_model(from_field, to_field)
|
539
|
+
|
540
|
+
def _try_extract_discriminated_model(
|
541
|
+
self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
|
542
|
+
) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
|
543
|
+
"""Try to extract model info from a discriminated union field."""
|
544
|
+
discriminator_key = self._get_discriminator_key(to_field)
|
545
|
+
if not discriminator_key:
|
546
|
+
return None
|
547
|
+
|
548
|
+
discriminator_value = self._find_discriminator_value(
|
549
|
+
value, discriminator_key, to_field
|
550
|
+
)
|
551
|
+
if discriminator_value is None:
|
552
|
+
return None
|
553
|
+
|
554
|
+
to_model_type = self._find_discriminated_type(
|
555
|
+
to_field, discriminator_key, discriminator_value
|
556
|
+
)
|
557
|
+
if not to_model_type:
|
558
|
+
return None
|
559
|
+
|
560
|
+
to_info = self.registry.get_model_info(to_model_type)
|
561
|
+
if not to_info:
|
562
|
+
return None
|
563
|
+
|
564
|
+
model_name, to_version = to_info
|
565
|
+
from_version = self._get_discriminated_source_version(
|
566
|
+
from_field, discriminator_key, discriminator_value, model_name, to_version
|
567
|
+
)
|
568
|
+
|
569
|
+
return (model_name, from_version, to_version)
|
570
|
+
|
571
|
+
def _try_extract_simple_nested_model(
|
572
|
+
self, from_field: FieldInfo | None, to_field: FieldInfo
|
573
|
+
) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
|
574
|
+
"""Try to extract model info from a simple (non-discriminated) nested field."""
|
331
575
|
to_model_type = self._get_model_type_from_field(to_field)
|
332
|
-
if not to_model_type
|
576
|
+
if not to_model_type:
|
333
577
|
return None
|
334
578
|
|
335
579
|
to_info = self.registry.get_model_info(to_model_type)
|
@@ -337,34 +581,175 @@ class MigrationManager:
|
|
337
581
|
return None
|
338
582
|
|
339
583
|
model_name, to_version = to_info
|
584
|
+
from_version = self._get_simple_source_version(from_field, model_name)
|
585
|
+
|
586
|
+
return (model_name, from_version or to_version, to_version)
|
587
|
+
|
588
|
+
def _get_simple_source_version(
|
589
|
+
self, from_field: FieldInfo | None, model_name: str
|
590
|
+
) -> ModelVersion | None:
|
591
|
+
"""Get the source version for a simple nested field."""
|
592
|
+
if not from_field:
|
593
|
+
return None
|
594
|
+
|
595
|
+
from_model_type = self._get_model_type_from_field(from_field)
|
596
|
+
if not from_model_type:
|
597
|
+
return None
|
598
|
+
|
599
|
+
from_info = self.registry.get_model_info(from_model_type)
|
600
|
+
if not from_info or from_info[0] != model_name:
|
601
|
+
return None
|
602
|
+
|
603
|
+
return from_info[1]
|
604
|
+
|
605
|
+
def _get_discriminator_key(self, field: FieldInfo) -> str | None:
|
606
|
+
"""Extract the discriminator key from a field."""
|
607
|
+
discriminator = field.discriminator
|
608
|
+
if discriminator is None:
|
609
|
+
return None
|
610
|
+
|
611
|
+
if isinstance(discriminator, str):
|
612
|
+
return discriminator
|
613
|
+
|
614
|
+
if hasattr(discriminator, "discriminator") and isinstance(
|
615
|
+
discriminator.discriminator, str
|
616
|
+
):
|
617
|
+
return discriminator.discriminator
|
340
618
|
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
619
|
+
return None
|
620
|
+
|
621
|
+
def _find_discriminator_value(
|
622
|
+
self, value: ModelData, discriminator_key: str, field: FieldInfo
|
623
|
+
) -> Any:
|
624
|
+
"""Find the discriminator value in data, checking field name and aliases."""
|
625
|
+
if discriminator_key in value:
|
626
|
+
return value[discriminator_key]
|
349
627
|
|
350
|
-
|
351
|
-
|
628
|
+
for model_type in self._get_union_members(field):
|
629
|
+
if discriminator_key not in model_type.model_fields:
|
630
|
+
continue
|
631
|
+
|
632
|
+
disc_field = model_type.model_fields[discriminator_key]
|
633
|
+
|
634
|
+
for alias_attr in ["alias", "serialization_alias"]:
|
635
|
+
alias = getattr(disc_field, alias_attr, None)
|
636
|
+
if alias and alias in value:
|
637
|
+
return value[alias]
|
638
|
+
|
639
|
+
val_alias = disc_field.validation_alias
|
640
|
+
if isinstance(val_alias, str) and val_alias in value:
|
641
|
+
return value[val_alias]
|
642
|
+
|
643
|
+
return None
|
352
644
|
|
353
|
-
def
|
354
|
-
self
|
645
|
+
def _get_discriminated_source_version(
|
646
|
+
self,
|
647
|
+
from_field: FieldInfo | None,
|
648
|
+
discriminator_key: str,
|
649
|
+
discriminator_value: Any,
|
650
|
+
model_name: str,
|
651
|
+
default_version: ModelVersion,
|
652
|
+
) -> ModelVersion:
|
653
|
+
"""Get the source version for a discriminated union field."""
|
654
|
+
if not from_field:
|
655
|
+
return default_version
|
656
|
+
|
657
|
+
from_model_type = self._find_discriminated_type(
|
658
|
+
from_field, discriminator_key, discriminator_value
|
659
|
+
)
|
660
|
+
if not from_model_type:
|
661
|
+
return default_version
|
662
|
+
|
663
|
+
from_info = self.registry.get_model_info(from_model_type)
|
664
|
+
if not from_info or from_info[0] != model_name:
|
665
|
+
return default_version
|
666
|
+
|
667
|
+
return from_info[1]
|
668
|
+
|
669
|
+
def _find_discriminated_type(
|
670
|
+
self, field: FieldInfo, discriminator_key: str, discriminator_value: Any
|
355
671
|
) -> type[BaseModel] | None:
|
356
|
-
"""
|
672
|
+
"""Find the right type in a discriminated union based on discriminator value."""
|
673
|
+
for model_type in self._get_union_members(field):
|
674
|
+
if self._model_matches_discriminator(
|
675
|
+
model_type, discriminator_key, discriminator_value
|
676
|
+
):
|
677
|
+
return model_type
|
678
|
+
return None
|
357
679
|
|
358
|
-
|
680
|
+
def _model_matches_discriminator(
|
681
|
+
self,
|
682
|
+
model_type: type[BaseModel],
|
683
|
+
discriminator_key: str,
|
684
|
+
discriminator_value: Any,
|
685
|
+
) -> bool:
|
686
|
+
"""Check if a model type matches a discriminator value."""
|
687
|
+
if discriminator_key not in model_type.model_fields:
|
688
|
+
return False
|
359
689
|
|
360
|
-
|
361
|
-
field: The field info to extract from.
|
690
|
+
field_info = model_type.model_fields[discriminator_key]
|
362
691
|
|
363
|
-
|
364
|
-
|
365
|
-
|
692
|
+
if self._literal_matches_value(field_info.annotation, discriminator_value):
|
693
|
+
return True
|
694
|
+
|
695
|
+
return (
|
696
|
+
field_info.default is not PydanticUndefined
|
697
|
+
and field_info.default == discriminator_value
|
698
|
+
)
|
699
|
+
|
700
|
+
def _literal_matches_value(self, annotation: Any, value: Any) -> bool:
|
701
|
+
"""Check if a Literal type annotation contains a specific value."""
|
702
|
+
if annotation is None:
|
703
|
+
return False
|
704
|
+
|
705
|
+
origin = get_origin(annotation)
|
706
|
+
if origin is not Literal:
|
707
|
+
return False
|
708
|
+
|
709
|
+
literal_values = get_args(annotation)
|
710
|
+
return value in literal_values
|
711
|
+
|
712
|
+
def _get_union_members(self, field: FieldInfo) -> list[type[BaseModel]]:
|
713
|
+
"""Extract all BaseModel types from a union field."""
|
366
714
|
annotation = field.annotation
|
715
|
+
if annotation is None:
|
716
|
+
return []
|
367
717
|
|
718
|
+
origin = get_origin(annotation)
|
719
|
+
|
720
|
+
if origin is None:
|
721
|
+
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
722
|
+
return [annotation]
|
723
|
+
return []
|
724
|
+
|
725
|
+
if not self._is_union_type(origin):
|
726
|
+
return []
|
727
|
+
|
728
|
+
args = get_args(annotation)
|
729
|
+
return [
|
730
|
+
arg
|
731
|
+
for arg in args
|
732
|
+
if arg is not type(None)
|
733
|
+
and isinstance(arg, type)
|
734
|
+
and issubclass(arg, BaseModel)
|
735
|
+
]
|
736
|
+
|
737
|
+
def _is_union_type(self, origin: Any) -> bool:
|
738
|
+
"""Check if an origin type represents a Union."""
|
739
|
+
if origin is Union:
|
740
|
+
return True
|
741
|
+
|
742
|
+
if hasattr(types, "UnionType"):
|
743
|
+
try:
|
744
|
+
return origin is types.UnionType
|
745
|
+
except (ImportError, AttributeError):
|
746
|
+
pass
|
747
|
+
|
748
|
+
return False
|
749
|
+
|
750
|
+
def _get_model_type_from_field(self, field: FieldInfo) -> type[BaseModel] | None:
|
751
|
+
"""Extract the Pydantic model type from a field."""
|
752
|
+
annotation = field.annotation
|
368
753
|
if annotation is None:
|
369
754
|
return None
|
370
755
|
|