pyrmute-0.3.0-py3-none-any.whl → pyrmute-0.4.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
pyrmute/__init__.py CHANGED
@@ -12,6 +12,7 @@ from .exceptions import (
 )
 from .migration_testing import (
     MigrationTestCase,
+    MigrationTestCases,
     MigrationTestResult,
     MigrationTestResults,
 )
@@ -32,6 +33,7 @@ __all__ = [
     "MigrationFunc",
     "MigrationManager",
     "MigrationTestCase",
+    "MigrationTestCases",
     "MigrationTestResult",
     "MigrationTestResults",
     "ModelData",
pyrmute/_migration_manager.py CHANGED
@@ -1,8 +1,9 @@
 """Migrations manager."""
 
 import contextlib
+import types
 from collections.abc import Callable
-from typing import Any, Self, get_args, get_origin
+from typing import Annotated, Any, Literal, Self, Union, get_args, get_origin
 
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
@@ -54,16 +55,8 @@ class MigrationManager:
             ... def migrate_v1_to_v2(data: dict[str, Any]) -> dict[str, Any]:
             ...     return {**data, "email": "unknown@example.com"}
         """
-        from_ver = (
-            ModelVersion.parse(from_version)
-            if isinstance(from_version, str)
-            else from_version
-        )
-        to_ver = (
-            ModelVersion.parse(to_version)
-            if isinstance(to_version, str)
-            else to_version
-        )
+        from_ver = self._parse_version(from_version)
+        to_ver = self._parse_version(to_version)
 
         def decorator(func: MigrationFunc) -> MigrationFunc:
             self.registry._migrations[name][(from_ver, to_ver)] = func
@@ -93,61 +86,14 @@ class MigrationManager:
             ModelNotFoundError: If model or versions don't exist.
             MigrationError: If migration path cannot be found.
         """
-        from_ver = (
-            ModelVersion.parse(from_version)
-            if isinstance(from_version, str)
-            else from_version
-        )
-        to_ver = (
-            ModelVersion.parse(to_version)
-            if isinstance(to_version, str)
-            else to_version
-        )
+        from_ver = self._parse_version(from_version)
+        to_ver = self._parse_version(to_version)
 
         if from_ver == to_ver:
             return data
 
         path = self.find_migration_path(name, from_ver, to_ver)
-
-        current_data = data
-        for i in range(len(path) - 1):
-            migration_key = (path[i], path[i + 1])
-
-            if migration_key in self.registry._migrations[name]:
-                migration_func = self.registry._migrations[name][migration_key]
-                try:
-                    current_data = migration_func(current_data)
-                except Exception as e:
-                    raise MigrationError(
-                        name,
-                        str(path[i]),
-                        str(path[i + 1]),
-                        f"Migration function raised: {type(e).__name__}: {e}",
-                    ) from e
-            elif path[i + 1] in self.registry._backward_compatible_enabled[name]:
-                try:
-                    current_data = self._auto_migrate(
-                        current_data, name, path[i], path[i + 1]
-                    )
-                except Exception as e:
-                    raise MigrationError(
-                        name,
-                        str(path[i]),
-                        str(path[i + 1]),
-                        f"Auto-migration failed: {type(e).__name__}: {e}",
-                    ) from e
-            else:
-                raise MigrationError(
-                    name,
-                    str(path[i]),
-                    str(path[i + 1]),
-                    (
-                        "No migration path found. Define a migration function or mark "
-                        "the target version as backward_compatible."
-                    ),
-                )
-
-        return current_data
+        return self._execute_migration_path(data, name, path)
 
     def find_migration_path(
         self: Self,
@@ -168,6 +114,9 @@ class MigrationManager:
         Raises:
             ModelNotFoundError: If the model or versions don't exist.
         """
+        if (from_ver, to_ver) in self.registry._migrations.get(name, {}):
+            return [from_ver, to_ver]
+
         versions = sorted(self.registry.get_versions(name))
 
         if from_ver not in versions:
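The early return added above makes `find_migration_path` prefer an explicitly registered direct migration over chaining through every intermediate version. A hedged sketch of what this enables at the public API level, using the decorators shown in the README; the model name and fields are illustrative, not taken from the package:

```python
# Sketch (assumed usage): a direct 1.0.0 -> 3.0.0 migration is used as a single
# step because its (from, to) key is found before any path search happens.
from typing import Any

from pydantic import BaseModel
from pyrmute import ModelManager

manager = ModelManager()


@manager.model("Config", "1.0.0")
class ConfigV1(BaseModel):
    timeout: int


@manager.model("Config", "2.0.0")
class ConfigV2(BaseModel):
    timeout: int
    retries: int


@manager.model("Config", "3.0.0")
class ConfigV3(BaseModel):
    timeout: int
    retries: int
    backoff: float


@manager.migration("Config", "1.0.0", "3.0.0")
def jump_to_v3(data: dict[str, Any]) -> dict[str, Any]:
    # Runs as one step; no 1.0.0 -> 2.0.0 or 2.0.0 -> 3.0.0 migration is needed.
    return {**data, "retries": 3, "backoff": 1.5}


migrated = manager.migrate({"timeout": 60}, "Config", "1.0.0", "3.0.0")
```

Without the direct function, the 1.0.0 → 2.0.0 → 3.0.0 chain would need its own per-step migrations (or backward-compatible targets).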
@@ -204,14 +153,8 @@ class MigrationManager:
         for i in range(len(path) - 1):
             current_ver = path[i]
             next_ver = path[i + 1]
-            migration_key = (current_ver, next_ver)
-
-            has_explicit = migration_key in self.registry._migrations.get(name, {})
-            has_auto = next_ver in self.registry._backward_compatible_enabled.get(
-                name, set()
-            )
 
-            if not has_explicit and not has_auto:
+            if not self._has_migration_step(name, current_ver, next_ver):
                 raise MigrationError(
                     name,
                     str(current_ver),
@@ -222,6 +165,85 @@ class MigrationManager:
                     ),
                 )
 
+    def _parse_version(self: Self, version: str | ModelVersion) -> ModelVersion:
+        """Parse version string or return ModelVersion as-is."""
+        return ModelVersion.parse(version) if isinstance(version, str) else version
+
+    def _has_migration_step(
+        self: Self, name: ModelName, from_ver: ModelVersion, to_ver: ModelVersion
+    ) -> bool:
+        """Check if a migration step exists (explicit or auto)."""
+        migration_key = (from_ver, to_ver)
+        has_explicit = migration_key in self.registry._migrations.get(name, {})
+        has_auto = to_ver in self.registry._backward_compatible_enabled.get(name, set())
+        return has_explicit or has_auto
+
+    def _execute_migration_path(
+        self: Self, data: ModelData, name: ModelName, path: list[ModelVersion]
+    ) -> ModelData:
+        """Execute migration through a path of versions."""
+        current_data = data
+
+        for i in range(len(path) - 1):
+            try:
+                current_data = self._execute_single_step(
+                    current_data, name, path[i], path[i + 1]
+                )
+            except Exception as e:
+                if isinstance(e, MigrationError):
+                    raise
+                raise MigrationError(
+                    name,
+                    str(path[i]),
+                    str(path[i + 1]),
+                    f"Migration failed: {type(e).__name__}: {e}",
+                ) from e
+
+        return current_data
+
+    def _execute_single_step(
+        self: Self,
+        data: ModelData,
+        name: ModelName,
+        from_ver: ModelVersion,
+        to_ver: ModelVersion,
+    ) -> ModelData:
+        """Execute a single migration step."""
+        migration_key = (from_ver, to_ver)
+
+        if migration_key in self.registry._migrations[name]:
+            migration_func = self.registry._migrations[name][migration_key]
+            try:
+                return migration_func(data)
+            except Exception as e:
+                raise MigrationError(
+                    name,
+                    str(from_ver),
+                    str(to_ver),
+                    f"Migration function raised: {type(e).__name__}: {e}",
+                ) from e
+
+        if to_ver in self.registry._backward_compatible_enabled[name]:
+            try:
+                return self._auto_migrate(data, name, from_ver, to_ver)
+            except Exception as e:
+                raise MigrationError(
+                    name,
+                    str(from_ver),
+                    str(to_ver),
+                    f"Auto-migration failed: {type(e).__name__}: {e}",
+                ) from e
+
+        raise MigrationError(
+            name,
+            str(from_ver),
+            str(to_ver),
+            (
+                "No migration path found. Define a migration function or mark "
+                "the target version as backward_compatible."
+            ),
+        )
+
     def _auto_migrate(
         self: Self,
         data: ModelData,
@@ -232,7 +254,7 @@ class MigrationManager:
         """Automatically migrate data when no explicit migration exists.
 
         This method handles nested Pydantic models recursively, migrating them to their
-        corresponding versions.
+        corresponding versions. Handles field aliases by building a lookup map.
 
         Args:
            data: Data dictionary to migrate.
@@ -249,36 +271,142 @@ class MigrationManager:
         from_fields = from_model.model_fields
         to_fields = to_model.model_fields
 
+        key_to_field_name = self._build_alias_map(from_fields)
+
         result: ModelData = {}
+        processed_keys: set[str] = set()
 
         for field_name, to_field_info in to_fields.items():
-            # Field exists in data, migrate it
-            if field_name in data:
-                value = data[field_name]
-                from_field_info = from_fields.get(field_name)
-                result[field_name] = self._migrate_field_value(
-                    value, from_field_info, to_field_info
-                )
+            field_result = self._migrate_single_field(
+                data, field_name, to_field_info, from_fields, key_to_field_name
+            )
 
-            # Field missing from data, use default if available
-            elif to_field_info.default is not PydanticUndefined:
-                result[field_name] = to_field_info.default
-            elif to_field_info.default_factory is not None:
-                with contextlib.suppress(Exception):
-                    result[field_name] = to_field_info.default_factory()  # type: ignore
+            if field_result is not None:
+                output_key, migrated_value, source_keys = field_result
+                result[output_key] = migrated_value
+                processed_keys.update(source_keys)
 
-        # Migrate all extra data not in the field, too
-        for field_name, value in data.items():
-            if field_name not in to_fields:
-                result[field_name] = value
+        # Preserve unprocessed extra fields
+        for data_key, value in data.items():
+            if data_key not in processed_keys:
+                result[data_key] = value
 
         return result
 
-    def _migrate_field_value(
+    def _build_alias_map(self: Self, fields: dict[str, FieldInfo]) -> dict[str, str]:
+        """Build a mapping from all possible input keys to canonical field names.
+
+        Args:
+            fields: Model fields to extract aliases from.
+
+        Returns:
+            Dictionary mapping data keys (field names and aliases) to field names.
+        """
+        key_to_field_name: dict[str, str] = {}
+
+        for field_name, field_info in fields.items():
+            key_to_field_name[field_name] = field_name
+
+            if field_info.alias:
+                key_to_field_name[field_info.alias] = field_name
+            if field_info.serialization_alias:
+                key_to_field_name[field_info.serialization_alias] = field_name
+            if isinstance(field_info.validation_alias, str):
+                key_to_field_name[field_info.validation_alias] = field_name
+
+        return key_to_field_name
+
+    def _migrate_single_field(
         self: Self,
-        value: Any,
-        from_field: FieldInfo | None,
-        to_field: FieldInfo,
+        data: ModelData,
+        field_name: str,
+        to_field_info: FieldInfo,
+        from_fields: dict[str, FieldInfo],
+        key_to_field_name: dict[str, str],
+    ) -> tuple[str, Any, set[str]] | None:
+        """Migrate a single field, handling aliases and defaults.
+
+        Args:
+            data: Source data dictionary.
+            field_name: Target field name.
+            to_field_info: Target field info.
+            from_fields: Source model fields.
+            key_to_field_name: Mapping from data keys to field names.
+
+        Returns:
+            Tuple of (output_key, migrated_value, source_keys_to_mark_processed) if
+            field should be included, None if field should be skipped.
+        """
+        value, data_key_used = self._find_field_value(
+            data, field_name, key_to_field_name
+        )
+
+        if data_key_used is not None:
+            from_field_info = from_fields.get(field_name)
+            migrated_value = self._migrate_field_value(
+                value, from_field_info, to_field_info
+            )
+            keys_to_process = {
+                k for k, v in key_to_field_name.items() if v == field_name
+            }
+            return (data_key_used, migrated_value, keys_to_process)
+
+        _NO_DEFAULT = object()
+        default_value = self._get_field_default(to_field_info, _NO_DEFAULT)
+
+        if default_value is not _NO_DEFAULT:
+            output_key = (
+                to_field_info.serialization_alias or to_field_info.alias or field_name
+            )
+            return (output_key, default_value, set())
+
+        return None
+
+    def _find_field_value(
+        self: Self, data: ModelData, field_name: str, key_to_field_name: dict[str, str]
+    ) -> tuple[Any, str | None]:
+        """Find a field's value in data, checking both field name and aliases.
+
+        Args:
+            data: Source data dictionary.
+            field_name: Target field name to find.
+            key_to_field_name: Mapping from data keys to field names.
+
+        Returns:
+            Tuple of (value, data_key) if found, (None, None) otherwise.
+        """
+        if field_name in data:
+            return (data[field_name], field_name)
+
+        for data_key in data:
+            if key_to_field_name.get(data_key) == field_name:
+                return (data[data_key], data_key)
+
+        return (None, None)
+
+    def _get_field_default(
+        self: Self, field_info: FieldInfo, sentinel: Any = None
+    ) -> Any:
+        """Get the default value for a field.
+
+        Args:
+            field_info: Field info to extract default from.
+            sentinel: Sentinel value to return if no default is available.
+
+        Returns:
+            Default value if available, sentinel otherwise.
+        """
+        if field_info.default is not PydanticUndefined:
+            return field_info.default
+
+        if field_info.default_factory is not None:
+            with contextlib.suppress(Exception):
+                return field_info.default_factory()  # type: ignore
+
+        return sentinel
+
+    def _migrate_field_value(
+        self: Self, value: Any, from_field: FieldInfo | None, to_field: FieldInfo
     ) -> Any:
         """Migrate a single field value, handling nested models.
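The alias map above lets auto-migration match incoming keys against a field's alias, serialization alias, or string validation alias rather than only its Python name. A hedged sketch of the kind of payload this targets; the `backward_compatible` registration flag is inferred from the error message earlier in this file and is an assumption about the public API:

```python
# Sketch: data keyed by an alias ("fullName") is mapped back to the canonical
# field name when version 2.0.0 is auto-migrated without an explicit function.
from pydantic import BaseModel, Field
from pyrmute import ModelManager

manager = ModelManager()


@manager.model("Person", "1.0.0")
class PersonV1(BaseModel):
    full_name: str = Field(alias="fullName")


@manager.model("Person", "2.0.0", backward_compatible=True)  # assumed flag name
class PersonV2(BaseModel):
    full_name: str = Field(alias="fullName")
    nickname: str = "n/a"


# No explicit migration registered: the alias-aware auto-migration keeps the
# "fullName" key it found in the data and fills the new field from its default.
person = manager.migrate({"fullName": "Ada Lovelace"}, "Person", "1.0.0", "2.0.0")
```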
 
@@ -294,31 +422,104 @@ class MigrationManager:
             return None
 
         if isinstance(value, dict):
-            nested_info = self._extract_nested_model_info(value, from_field, to_field)
-            if nested_info:
-                nested_name, nested_from_ver, nested_to_ver = nested_info
-                return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
-
-            return {
-                k: self._migrate_field_value(v, from_field, to_field)
-                for k, v in value.items()
-            }
+            return self._migrate_dict_value(value, from_field, to_field)
 
         if isinstance(value, list):
-            return [
-                self._migrate_field_value(item, from_field, to_field) for item in value
-            ]
+            return self._migrate_list_value(value, from_field, to_field)
 
         return value
 
+    def _migrate_dict_value(
+        self, value: dict[str, Any], from_field: FieldInfo | None, to_field: FieldInfo
+    ) -> dict[str, Any]:
+        """Migrate a dictionary value (might be a nested model)."""
+        nested_info = self._extract_nested_model_info(value, from_field, to_field)
+        if nested_info:
+            nested_name, nested_from_ver, nested_to_ver = nested_info
+            return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
+
+        return {
+            k: self._migrate_field_value(v, from_field, to_field)
+            for k, v in value.items()
+        }
+
+    def _migrate_list_value(
+        self, value: list[Any], from_field: FieldInfo | None, to_field: FieldInfo
+    ) -> list[Any]:
+        """Migrate a list value (might contain nested models)."""
+        from_item_field, to_item_field = self._get_list_item_fields(
+            from_field, to_field
+        )
+
+        return [
+            self._migrate_list_item(item, from_item_field, to_item_field)
+            for item in value
+        ]
+
+    def _migrate_list_item(
+        self, item: Any, from_field: FieldInfo | None, to_field: FieldInfo | None
+    ) -> Any:
+        """Migrate a single item from a list."""
+        if not isinstance(item, dict) or to_field is None:
+            return item
+
+        nested_info = self._extract_nested_model_info(item, from_field, to_field)
+        if nested_info:
+            nested_name, nested_from_ver, nested_to_ver = nested_info
+            return self.migrate(item, nested_name, nested_from_ver, nested_to_ver)
+
+        return item
+
+    def _get_list_item_fields(
+        self, from_field: FieldInfo | None, to_field: FieldInfo
+    ) -> tuple[FieldInfo | None, FieldInfo | None]:
+        """Extract field info for items in a list field."""
+        to_item_field = self._extract_list_item_field(to_field)
+        from_item_field = (
+            self._extract_list_item_field(from_field) if from_field else None
+        )
+        return from_item_field, to_item_field
+
+    def _extract_list_item_field(self, field: FieldInfo | None) -> FieldInfo | None:
+        """Extract field info for the items of a list field."""
+        if field is None or field.annotation is None:
+            return None
+
+        origin = get_origin(field.annotation)
+        if origin is not list:
+            return None
+
+        args = get_args(field.annotation)
+        if not args:
+            return None
+
+        item_annotation = args[0]
+
+        discriminator = None
+        if get_origin(item_annotation) is Annotated:
+            annotated_args = get_args(item_annotation)
+            item_annotation = annotated_args[0]
+            for metadata in annotated_args[1:]:
+                if hasattr(metadata, "discriminator"):
+                    discriminator = metadata.discriminator
+
+        synthetic_field = FieldInfo(
+            annotation=item_annotation, default=PydanticUndefined
+        )
+        if discriminator:
+            synthetic_field.discriminator = discriminator
+
+        return synthetic_field
+
     def _extract_nested_model_info(
-        self: Self,
-        value: ModelData,
-        from_field: FieldInfo | None,
-        to_field: FieldInfo,
+        self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
     ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
         """Extract nested model migration information.
 
+        Handles discriminated unions by using the discriminator field to determine which
+        model type to migrate.
+
+
         Args:
             value: The nested model data.
             from_field: Source field info.
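The discriminated-union handling resolves which nested model class a dict represents by reading the discriminator key and comparing it against each union member's `Literal` annotation (or default). A standalone sketch of that matching idea, independent of pyrmute's internals and not the library's code:

```python
# Illustration only: pick the union member whose Literal discriminator matches.
from typing import Any, Literal, get_args, get_origin

from pydantic import BaseModel


class CreditCard(BaseModel):
    type: Literal["credit_card"] = "credit_card"
    card_number: str


class PayPal(BaseModel):
    type: Literal["paypal"] = "paypal"
    email: str


def pick_union_member(
    members: list[type[BaseModel]], key: str, value: Any
) -> type[BaseModel] | None:
    """Return the member whose Literal discriminator contains ``value``."""
    for member in members:
        annotation = member.model_fields[key].annotation
        if get_origin(annotation) is Literal and value in get_args(annotation):
            return member
    return None


assert pick_union_member([CreditCard, PayPal], "type", "paypal") is PayPal
```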
@@ -328,8 +529,51 @@ class MigrationManager:
             Tuple of (model_name, from_version, to_version) if this is a
             versioned nested model, None otherwise.
         """
+        discriminated_info = self._try_extract_discriminated_model(
+            value, from_field, to_field
+        )
+        if discriminated_info:
+            return discriminated_info
+
+        return self._try_extract_simple_nested_model(from_field, to_field)
+
+    def _try_extract_discriminated_model(
+        self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
+    ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
+        """Try to extract model info from a discriminated union field."""
+        discriminator_key = self._get_discriminator_key(to_field)
+        if not discriminator_key:
+            return None
+
+        discriminator_value = self._find_discriminator_value(
+            value, discriminator_key, to_field
+        )
+        if discriminator_value is None:
+            return None
+
+        to_model_type = self._find_discriminated_type(
+            to_field, discriminator_key, discriminator_value
+        )
+        if not to_model_type:
+            return None
+
+        to_info = self.registry.get_model_info(to_model_type)
+        if not to_info:
+            return None
+
+        model_name, to_version = to_info
+        from_version = self._get_discriminated_source_version(
+            from_field, discriminator_key, discriminator_value, model_name, to_version
+        )
+
+        return (model_name, from_version, to_version)
+
+    def _try_extract_simple_nested_model(
+        self, from_field: FieldInfo | None, to_field: FieldInfo
+    ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
+        """Try to extract model info from a simple (non-discriminated) nested field."""
         to_model_type = self._get_model_type_from_field(to_field)
-        if not to_model_type or not issubclass(to_model_type, BaseModel):
+        if not to_model_type:
             return None
 
         to_info = self.registry.get_model_info(to_model_type)
@@ -337,34 +581,175 @@ class MigrationManager:
             return None
 
         model_name, to_version = to_info
+        from_version = self._get_simple_source_version(from_field, model_name)
+
+        return (model_name, from_version or to_version, to_version)
+
+    def _get_simple_source_version(
+        self, from_field: FieldInfo | None, model_name: str
+    ) -> ModelVersion | None:
+        """Get the source version for a simple nested field."""
+        if not from_field:
+            return None
+
+        from_model_type = self._get_model_type_from_field(from_field)
+        if not from_model_type:
+            return None
+
+        from_info = self.registry.get_model_info(from_model_type)
+        if not from_info or from_info[0] != model_name:
+            return None
+
+        return from_info[1]
+
+    def _get_discriminator_key(self, field: FieldInfo) -> str | None:
+        """Extract the discriminator key from a field."""
+        discriminator = field.discriminator
+        if discriminator is None:
+            return None
+
+        if isinstance(discriminator, str):
+            return discriminator
+
+        if hasattr(discriminator, "discriminator") and isinstance(
+            discriminator.discriminator, str
+        ):
+            return discriminator.discriminator
+
+        return None
+
+    def _find_discriminator_value(
+        self, value: ModelData, discriminator_key: str, field: FieldInfo
+    ) -> Any:
+        """Find the discriminator value in data, checking field name and aliases."""
+        if discriminator_key in value:
+            return value[discriminator_key]
+
+        for model_type in self._get_union_members(field):
+            if discriminator_key not in model_type.model_fields:
+                continue
+
+            disc_field = model_type.model_fields[discriminator_key]
+
+            for alias_attr in ["alias", "serialization_alias"]:
+                alias = getattr(disc_field, alias_attr, None)
+                if alias and alias in value:
+                    return value[alias]
 
-        # Get the source version
-        if from_field:
-            from_model_type = self._get_model_type_from_field(from_field)
-            if from_model_type and issubclass(from_model_type, BaseModel):
-                from_info = self.registry.get_model_info(from_model_type)
-                if from_info and from_info[0] == model_name:
-                    from_version = from_info[1]
-                    return (model_name, from_version, to_version)
+            val_alias = disc_field.validation_alias
+            if isinstance(val_alias, str) and val_alias in value:
+                return value[val_alias]
 
-        # If we can't determine the source version, assume it's the same as target
-        return (model_name, to_version, to_version)
+        return None
+
+    def _get_discriminated_source_version(
+        self,
+        from_field: FieldInfo | None,
+        discriminator_key: str,
+        discriminator_value: Any,
+        model_name: str,
+        default_version: ModelVersion,
+    ) -> ModelVersion:
+        """Get the source version for a discriminated union field."""
+        if not from_field:
+            return default_version
+
+        from_model_type = self._find_discriminated_type(
+            from_field, discriminator_key, discriminator_value
+        )
+        if not from_model_type:
+            return default_version
+
+        from_info = self.registry.get_model_info(from_model_type)
+        if not from_info or from_info[0] != model_name:
+            return default_version
+
+        return from_info[1]
 
-    def _get_model_type_from_field(
-        self: Self, field: FieldInfo
+    def _find_discriminated_type(
+        self, field: FieldInfo, discriminator_key: str, discriminator_value: Any
     ) -> type[BaseModel] | None:
-        """Extract the Pydantic model type from a field.
+        """Find the right type in a discriminated union based on discriminator value."""
+        for model_type in self._get_union_members(field):
+            if self._model_matches_discriminator(
+                model_type, discriminator_key, discriminator_value
+            ):
+                return model_type
+        return None
 
-        Handles Optional, List, and other generic types.
+    def _model_matches_discriminator(
+        self,
+        model_type: type[BaseModel],
+        discriminator_key: str,
+        discriminator_value: Any,
+    ) -> bool:
+        """Check if a model type matches a discriminator value."""
+        if discriminator_key not in model_type.model_fields:
+            return False
 
-        Args:
-            field: The field info to extract from.
+        field_info = model_type.model_fields[discriminator_key]
 
-        Returns:
-            The model type if found, None otherwise.
-        """
+        if self._literal_matches_value(field_info.annotation, discriminator_value):
+            return True
+
+        return (
+            field_info.default is not PydanticUndefined
+            and field_info.default == discriminator_value
+        )
+
+    def _literal_matches_value(self, annotation: Any, value: Any) -> bool:
+        """Check if a Literal type annotation contains a specific value."""
+        if annotation is None:
+            return False
+
+        origin = get_origin(annotation)
+        if origin is not Literal:
+            return False
+
+        literal_values = get_args(annotation)
+        return value in literal_values
+
+    def _get_union_members(self, field: FieldInfo) -> list[type[BaseModel]]:
+        """Extract all BaseModel types from a union field."""
         annotation = field.annotation
+        if annotation is None:
+            return []
 
+        origin = get_origin(annotation)
+
+        if origin is None:
+            if isinstance(annotation, type) and issubclass(annotation, BaseModel):
+                return [annotation]
+            return []
+
+        if not self._is_union_type(origin):
+            return []
+
+        args = get_args(annotation)
+        return [
+            arg
+            for arg in args
+            if arg is not type(None)
+            and isinstance(arg, type)
+            and issubclass(arg, BaseModel)
+        ]
+
+    def _is_union_type(self, origin: Any) -> bool:
+        """Check if an origin type represents a Union."""
+        if origin is Union:
+            return True
+
+        if hasattr(types, "UnionType"):
+            try:
+                return origin is types.UnionType
+            except (ImportError, AttributeError):
+                pass
+
+        return False
+
+    def _get_model_type_from_field(self, field: FieldInfo) -> type[BaseModel] | None:
+        """Extract the Pydantic model type from a field."""
+        annotation = field.annotation
         if annotation is None:
             return None
pyrmute/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.0'
-__version_tuple__ = version_tuple = (0, 3, 0)
+__version__ = version = '0.4.0'
+__version_tuple__ = version_tuple = (0, 4, 0)
 
 __commit_id__ = commit_id = None
pyrmute/migration_testing.py CHANGED
@@ -2,7 +2,7 @@
 
 from collections.abc import Iterator
 from dataclasses import dataclass
-from typing import Self
+from typing import Self, TypeAlias
 
 from .types import ModelData
 
@@ -159,3 +159,6 @@ class MigrationTestResults:
             f"✗ {len(self.failures)} of {total_count} test(s) failed "
             f"({passed_count} passed)"
         )
+
+
+MigrationTestCases: TypeAlias = list[tuple[ModelData, ModelData] | MigrationTestCase]
pyrmute/model_manager.py CHANGED
@@ -13,6 +13,7 @@ from ._schema_manager import SchemaManager
 from .exceptions import MigrationError, ModelNotFoundError
 from .migration_testing import (
     MigrationTestCase,
+    MigrationTestCases,
     MigrationTestResult,
     MigrationTestResults,
 )
@@ -635,7 +636,7 @@ class ModelManager:
         name: str,
         from_version: str | ModelVersion,
         to_version: str | ModelVersion,
-        test_cases: list[tuple[ModelData, ModelData] | MigrationTestCase],
+        test_cases: MigrationTestCases,
    ) -> MigrationTestResults:
         """Test a migration with multiple test cases.
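`ModelManager.test_migration` now annotates its `test_cases` parameter with the new `MigrationTestCases` alias; plain `(input, expected)` tuples remain valid entries. A hedged sketch of a call site, with illustrative model and field names:

```python
# Sketch: exercising a registered 1.0.0 -> 2.0.0 migration with tuple cases.
from typing import Any

from pydantic import BaseModel
from pyrmute import MigrationTestCases, ModelManager

manager = ModelManager()


@manager.model("User", "1.0.0")
class UserV1(BaseModel):
    name: str


@manager.model("User", "2.0.0")
class UserV2(BaseModel):
    name: str
    email: str


@manager.migration("User", "1.0.0", "2.0.0")
def add_email(data: dict[str, Any]) -> dict[str, Any]:
    return {**data, "email": "unknown@example.com"}


cases: MigrationTestCases = [
    ({"name": "Ada"}, {"name": "Ada", "email": "unknown@example.com"}),
    ({"name": "Grace"}, {"name": "Grace", "email": "unknown@example.com"}),
]

results = manager.test_migration("User", "1.0.0", "2.0.0", cases)
assert results.all_passed, f"Migration failed: {results.failures}"
```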
 
{pyrmute-0.3.0.dist-info → pyrmute-0.4.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyrmute
-Version: 0.3.0
+Version: 0.4.0
 Summary: Pydantic model migrations and schemas
 Author-email: Matt Ferrera <mattferrera@gmail.com>
 License: MIT
@@ -56,6 +56,28 @@ through multiple versions.
   support for large datasets
 - **Only one dependency** - Pydantic
 
+## When to Use pyrmute
+
+pyrmute excels at handling schema evolution in production systems:
+
+- **Configuration files** - Upgrade user config files as your CLI/desktop app
+  evolves (`.apprc`, `config.json`, `settings.yaml`)
+- **Message queues & event streams** - Handle messages from multiple service
+  versions publishing different schemas (Kafka, RabbitMQ, SQS)
+- **ETL & data imports** - Import CSV/JSON/Excel files exported over years
+  with evolving structures
+- **ML model serving** - Manage feature schema evolution across model versions
+  and A/B tests
+- **API versioning** - Support multiple API versions with automatic
+  request/response migration
+- **Database migrations** - Transparently migrate legacy data on read without
+  downtime
+- **Data archival** - Process historical data dumps with various schema
+  versions
+
+See the [examples/](examples/) directory for complete, runnable code
+demonstrating these patterns.
+
 ## Help
 
 See [documentation](https://mferrera.github.io/pyrmute/) for complete guides
@@ -232,6 +254,71 @@ results = manager.test_migration(
 assert results.all_passed, f"Migration failed: {results.failures}"
 ```
 
+### Bidirectional Migrations
+
+```python
+# Support both upgrades and downgrades
+@manager.migration("Config", "2.0.0", "1.0.0")
+def downgrade_config(data: ModelData) -> ModelData:
+    """Rollback to v1 format."""
+    return {k: v for k, v in data.items() if k in ["setting1", "setting2"]}
+
+# Useful for:
+# - Rolling back deployments
+# - Normalizing outputs from multiple model versions
+# - Supporting legacy systems during transitions
+```
+
+### Nested Model Migrations
+
+```python
+# Automatically migrates nested Pydantic models
+@manager.model("Address", "1.0.0")
+class AddressV1(BaseModel):
+    street: str
+    city: str
+
+@manager.model("Address", "2.0.0")
+class AddressV2(BaseModel):
+    street: str
+    city: str
+    postal_code: str
+
+@manager.model("User", "2.0.0")
+class UserV2(BaseModel):
+    name: str
+    address: AddressV2  # Nested model
+
+# When migrating User, Address is automatically migrated too
+@manager.migration("Address", "1.0.0", "2.0.0")
+def add_postal_code(data: ModelData) -> ModelData:
+    return {**data, "postal_code": "00000"}
+```
+
+### Discriminated Unions
+
+```python
+from typing import Literal, Union
+from pydantic import Field
+
+# Handle complex type hierarchies
+@manager.model("CreditCard", "1.0.0")
+class CreditCardV1(BaseModel):
+    type: Literal["credit_card"] = "credit_card"
+    card_number: str
+
+@manager.model("PayPal", "1.0.0")
+class PayPalV1(BaseModel):
+    type: Literal["paypal"] = "paypal"
+    email: str
+
+@manager.model("Payment", "1.0.0")
+class PaymentV1(BaseModel):
+    method: Union[CreditCardV1, PayPalV1] = Field(discriminator="type")
+
+# Migrations respect discriminated unions
+```
+
 ### Export JSON Schemas
 
 ```python
@@ -267,79 +354,146 @@ config = manager.migrate({"timeout": 60}, "Config", "1.0.0", "2.0.0")
267
354
  # ConfigV2(timeout=60, retries=3)
268
355
  ```
269
356
 
270
- ## Real-World Example
357
+ ## Real-World Examples
358
+
359
+ ### Configuration File Evolution
271
360
 
272
361
  ```python
273
- from datetime import datetime
274
- from pydantic import BaseModel, EmailStr
275
- from pyrmute import ModelManager, ModelData
362
+ # Your CLI tool evolves over time
363
+ @manager.model("AppConfig", "1.0.0")
364
+ class AppConfigV1(BaseModel):
365
+ api_key: str
366
+ debug: bool = False
367
+
368
+ @manager.model("AppConfig", "2.0.0")
369
+ class AppConfigV2(BaseModel):
370
+ api_key: str
371
+ api_endpoint: str = "https://api.example.com"
372
+ log_level: Literal["DEBUG", "INFO", "ERROR"] = "INFO"
373
+
374
+ @manager.migration("AppConfig", "1.0.0", "2.0.0")
375
+ def upgrade_config(data: dict) -> dict:
376
+ return {
377
+ "api_key": data["api_key"],
378
+ "api_endpoint": "https://api.example.com",
379
+ "log_level": "DEBUG" if data.get("debug") else "INFO",
380
+ }
276
381
 
277
- manager = ModelManager()
382
+ # Load and auto-upgrade user's config file
383
+ def load_config(config_path: Path) -> AppConfigV2:
384
+ with open(config_path) as f:
385
+ data = json.load(f)
278
386
 
387
+ version = data.get("_version", "1.0.0")
279
388
 
280
- # API v1: Basic order
281
- @manager.model("Order", "1.0.0")
282
- class OrderV1(BaseModel):
283
- id: str
284
- items: list[str]
285
- total: float
286
-
287
-
288
- # API v2: Add customer info
289
- @manager.model("Order", "2.0.0")
290
- class OrderV2(BaseModel):
291
- id: str
292
- items: list[str]
293
- total: float
294
- customer_email: EmailStr
295
-
296
-
297
- # API v3: Structured items and timestamps
298
- @manager.model("Order", "3.0.0")
299
- class OrderItemV3(BaseModel):
300
- product_id: str
301
- quantity: int
302
- price: float
303
-
304
-
305
- @manager.model("Order", "3.0.0")
306
- class OrderV3(BaseModel):
307
- id: str
308
- items: list[OrderItemV3]
309
- total: float
310
- customer_email: EmailStr
311
- created_at: datetime
312
-
313
-
314
- # Define migrations
315
- @manager.migration("Order", "1.0.0", "2.0.0")
316
- def add_customer_email(data: ModelData) -> ModelData:
317
- return {**data, "customer_email": "customer@example.com"}
318
-
319
-
320
- @manager.migration("Order", "2.0.0", "3.0.0")
321
- def structure_items(data: ModelData) -> ModelData:
322
- # Convert simple strings to structured items
323
- structured_items = [
324
- {
325
- "product_id": item,
326
- "quantity": 1,
327
- "price": data["total"] / len(data["items"])
328
- }
329
- for item in data["items"]
330
- ]
331
- return {
332
- **data,
333
- "items": structured_items,
334
- "created_at": datetime.now().isoformat()
335
- }
389
+ # Migrate to current version
390
+ config = manager.migrate(
391
+ data,
392
+ "AppConfig",
393
+ from_version=version,
394
+ to_version="2.0.0"
395
+ )
336
396
 
337
- # Migrate old orders from your database
338
- old_order = {"id": "123", "items": ["widget", "gadget"], "total": 29.99}
339
- new_order = manager.migrate(old_order, "Order", "1.0.0", "3.0.0")
340
- database.save(new_order)
397
+ # Save upgraded config with version tag
398
+ with open(config_path, "w") as f:
399
+ json.dump({**config.model_dump(), "_version": "2.0.0"}, f, indent=2)
400
+
401
+ return config
341
402
  ```
342
403
 
404
+ ### Message Queue Consumer
405
+
406
+ ```python
407
+ # Handle messages from multiple service versions
408
+ @manager.model("OrderEvent", "1.0.0")
409
+ class OrderEventV1(BaseModel):
410
+ order_id: str
411
+ customer_email: str
412
+ items: list[dict] # Unstructured
413
+
414
+ @manager.model("OrderEvent", "2.0.0")
415
+ class OrderEventV2(BaseModel):
416
+ order_id: str
417
+ customer_email: str
418
+ items: list[OrderItem] # Structured
419
+ total: Decimal
420
+
421
+ def process_message(message: dict, schema_version: str) -> None:
422
+ # Migrate to current schema regardless of source version
423
+ event = manager.migrate(
424
+ message,
425
+ "OrderEvent",
426
+ from_version=schema_version,
427
+ to_version="2.0.0"
428
+ )
429
+ # Process with current schema only
430
+ fulfill_order(event)
431
+ ```
432
+
433
+ ### ETL Data Import
434
+
435
+ ```python
436
+ # Import historical exports with evolving schemas
437
+ import csv
438
+
439
+ def import_customers(file_path: Path, file_version: str) -> None:
440
+ with open(file_path) as f:
441
+ reader = csv.DictReader(f)
442
+
443
+ # Stream migration for memory efficiency
444
+ for customer in manager.migrate_batch_streaming(
445
+ reader,
446
+ "Customer",
447
+ from_version=file_version,
448
+ to_version="3.0.0",
449
+ chunk_size=1000
450
+ ):
451
+ database.save(customer)
452
+
453
+ # Handle files from different years
454
+ import_customers("exports/2022_customers.csv", "1.0.0")
455
+ import_customers("exports/2023_customers.csv", "2.0.0")
456
+ import_customers("exports/2024_customers.csv", "3.0.0")
457
+ ```
458
+
459
+ ### ML Model Serving
460
+
461
+ ```python
462
+ # Route requests to appropriate model versions
463
+ class InferenceService:
464
+ def predict(self, features: dict, request_version: str) -> BaseModel:
465
+ # Determine target model version (A/B testing, gradual rollout, etc.)
466
+ model_version = self.get_model_version(features["user_id"])
467
+
468
+ # Migrate request to model's expected format
469
+ model_input = manager.migrate(
470
+ features,
471
+ "PredictionRequest",
472
+ from_version=request_version,
473
+ to_version=model_version
474
+ )
475
+
476
+ # Run inference
477
+ prediction = self.models[model_version].predict(model_input)
478
+
479
+ # Normalize output for logging/analytics
480
+ return manager.migrate(
481
+ prediction,
482
+ "PredictionResponse",
483
+ from_version=model_version,
484
+ to_version="3.0.0"
485
+ )
486
+ ```
487
+
488
+ See [examples/](examples/) for complete runnable code:
489
+ - `config_file_migration.py` - CLI/desktop app config file evolution
490
+ - `message_queue_consumer.py` - Kafka/RabbitMQ/SQS consumer handling multiple
491
+ schemas
492
+ - `etl_data_import.py` - CSV/JSON/Excel import pipeline with historical data
493
+ - `ml_inference_pipeline.py` - ML model serving with feature evolution
494
+ - `advanced_features.py` - Complex Pydantic features (unions, nested models,
495
+ validators)
496
+
343
497
  ## Contributing
344
498
 
345
499
  For guidance on setting up a development environment and how to make a
pyrmute-0.4.0.dist-info/RECORD ADDED
@@ -0,0 +1,17 @@
+pyrmute/__init__.py,sha256=j2pbMYswL0xR8FZwZDg-qAw6HwFpA3KPhY06_2yc_U0,1132
+pyrmute/_migration_manager.py,sha256=TFws66RsEdKLpvjDDQB1pKgeeyUe5WoutacTaeDsZoE,26154
+pyrmute/_registry.py,sha256=iUjMPd6CYgyvWT8PxZqHWBZnsHrX25fOPDi_-k_QDJs,6124
+pyrmute/_schema_manager.py,sha256=eun8PTL9Gv1XAMVKmE3tYmjdrcf701-IapUXjb6WDL0,12122
+pyrmute/_version.py,sha256=2_0GUP7yBCXRus-qiJKxQD62z172WSs1sQ6DVpPsbmM,704
+pyrmute/exceptions.py,sha256=Q57cUuzzMdkIl5Q0_VyLobpdB0WcrE0ggfC-LBoX2Uo,1681
+pyrmute/migration_testing.py,sha256=fpKT2u7pgPRpswb4PUvbd-fQ3W76svNWvEVYVDmb3Dg,5066
+pyrmute/model_diff.py,sha256=vMa2NTYFqt9E7UYDZH4PQmLcoxQw5Sj-nPpUHB_53Ig,9594
+pyrmute/model_manager.py,sha256=a6ecd-lZ3iliP3lqgCAi7xLeFlBh50kBA-m6gLGKRx4,24585
+pyrmute/model_version.py,sha256=ftNDuJlN3S5ZKQK8DKqqwfBDRiz4rGCYn-aJ3n6Zmqk,2025
+pyrmute/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrmute/types.py,sha256=56IH8Rl9AmVh_w3V6PbSSEwaPrBSfc4pYrtcxodvlT0,1187
+pyrmute-0.4.0.dist-info/licenses/LICENSE,sha256=otWInySiZeGwhHqQQ7n7nxM5QBSBe2CzeGEmQDZEz8Q,1119
+pyrmute-0.4.0.dist-info/METADATA,sha256=-KbIQN_INu7ZGfyEO65qmSGsWAXiTHQXR_7F5f-enOQ,14480
+pyrmute-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrmute-0.4.0.dist-info/top_level.txt,sha256=C8QtzqE6yBHkeewSp1QewvsyeHj_VQLYjSa5HLtMiow,8
+pyrmute-0.4.0.dist-info/RECORD,,
pyrmute-0.3.0.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
-pyrmute/__init__.py,sha256=vB0WBe3CukMBDnK0XP4qqehbMM4z_TUQMcTVPCUyt6Q,1082
-pyrmute/_migration_manager.py,sha256=KregnRUKqF1TC9XIpGAHpQvFlnRTEnp2X6Q2sAay8D4,12489
-pyrmute/_registry.py,sha256=iUjMPd6CYgyvWT8PxZqHWBZnsHrX25fOPDi_-k_QDJs,6124
-pyrmute/_schema_manager.py,sha256=eun8PTL9Gv1XAMVKmE3tYmjdrcf701-IapUXjb6WDL0,12122
-pyrmute/_version.py,sha256=5zTqm8rgXsWYBpB2M3Zw_K1D-aV8wP7NsBLrmMKkrAQ,704
-pyrmute/exceptions.py,sha256=Q57cUuzzMdkIl5Q0_VyLobpdB0WcrE0ggfC-LBoX2Uo,1681
-pyrmute/migration_testing.py,sha256=dOR8BDzmz4mFAI4hFtDUCEMS8Qc8qqD_iOV0qRai-qM,4967
-pyrmute/model_diff.py,sha256=vMa2NTYFqt9E7UYDZH4PQmLcoxQw5Sj-nPpUHB_53Ig,9594
-pyrmute/model_manager.py,sha256=e3UKFo79pkseCUFXIzW2_onu3GYjAnY1FR4JR_QF-Gc,24596
-pyrmute/model_version.py,sha256=ftNDuJlN3S5ZKQK8DKqqwfBDRiz4rGCYn-aJ3n6Zmqk,2025
-pyrmute/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyrmute/types.py,sha256=56IH8Rl9AmVh_w3V6PbSSEwaPrBSfc4pYrtcxodvlT0,1187
-pyrmute-0.3.0.dist-info/licenses/LICENSE,sha256=otWInySiZeGwhHqQQ7n7nxM5QBSBe2CzeGEmQDZEz8Q,1119
-pyrmute-0.3.0.dist-info/METADATA,sha256=jnRgO76ovFaktYKqq5BMf4tpUc_DYO_drET7c1hPVk0,9580
-pyrmute-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-pyrmute-0.3.0.dist-info/top_level.txt,sha256=C8QtzqE6yBHkeewSp1QewvsyeHj_VQLYjSa5HLtMiow,8
-pyrmute-0.3.0.dist-info/RECORD,,