pyrmute 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
pyrmute/model_manager.py CHANGED
@@ -1,4 +1,4 @@
1
- """Model manager."""
1
+ """ModelManager class."""
2
2
 
3
3
  from collections.abc import Callable, Iterable
4
4
  from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from typing import Any, Self
7
7
 
8
8
  from pydantic import BaseModel
9
+ from pydantic.json_schema import GenerateJsonSchema
9
10
 
10
11
  from ._migration_manager import MigrationManager
11
12
  from ._registry import Registry
@@ -13,11 +14,13 @@ from ._schema_manager import SchemaManager
13
14
  from .exceptions import MigrationError, ModelNotFoundError
14
15
  from .migration_testing import (
15
16
  MigrationTestCase,
17
+ MigrationTestCases,
16
18
  MigrationTestResult,
17
19
  MigrationTestResults,
18
20
  )
19
21
  from .model_diff import ModelDiff
20
22
  from .model_version import ModelVersion
23
+ from .schema_config import SchemaConfig
21
24
  from .types import (
22
25
  DecoratedBaseModel,
23
26
  JsonSchema,
@@ -25,20 +28,16 @@ from .types import (
25
28
  MigrationFunc,
26
29
  ModelData,
27
30
  NestedModelInfo,
31
+ SchemaTransformer,
28
32
  )
29
33
 
30
34
 
31
35
  class ModelManager:
32
- """High-level interface for versioned model management.
36
+ """High-level interface for versioned model management and schema generation.
33
37
 
34
38
  ModelManager provides a unified API for managing schema evolution across different
35
39
  versions of Pydantic models. It handles model registration, automatic migration
36
- between versions, schema generation, and batch processing operations.
37
-
38
- Attributes:
39
- registry: Registry instance managing all registered model versions.
40
- migration_manager: MigrationManager instance handling migration logic and paths.
41
- schema_manager: SchemaManager instance for JSON schema generation and export.
40
+ between versions, customizable schema generation, and batch processing operations.
42
41
 
43
42
  Basic Usage:
44
43
  >>> manager = ModelManager()
@@ -63,6 +62,56 @@ class ModelManager:
63
62
  >>> user = manager.migrate(old_data, "User", "1.0.0", "2.0.0")
64
63
  >>> # Result: UserV2(name="Alice", email="unknown@example.com")
65
64
 
65
+ Custom Schema Generation:
66
+ >>> from pydantic.json_schema import GenerateJsonSchema
67
+ >>>
68
+ >>> class CustomSchemaGenerator(GenerateJsonSchema):
69
+ ... '''Add custom metadata to all schemas.'''
70
+ ... def generate(
71
+ ... self,
72
+ ... schema: Mapping[str, Any],
73
+ ... mode: JsonSchemaMode = "validation"
74
+ ... ) -> JsonSchema:
75
+ ... json_schema = super().generate(schema, mode=mode)
76
+ ... json_schema["x-company"] = "Acme"
77
+ ... json_schema["$schema"] = self.schema_dialect
78
+ ... return json_schema
79
+ >>>
80
+ >>> # Set at manager level (applies to all schemas)
81
+ >>> manager = ModelManager(
82
+ ... default_schema_config=SchemaConfig(
83
+ ... schema_generator=CustomSchemaGenerator,
84
+ ... mode="validation",
85
+ ... by_alias=True
86
+ ... )
87
+ ... )
88
+ >>>
89
+ >>> @manager.model("User", "1.0.0")
90
+ ... class User(BaseModel):
91
+ ... name: str = Field(title="Full Name")
92
+ ... email: str
93
+ >>>
94
+ >>> # Get schema with default config
95
+ >>> schema = manager.get_schema("User", "1.0.0")
96
+ >>> # Will include x-company: 'Acme'
97
+
98
+ Schema Transformers:
99
+ >>> manager = ModelManager()
100
+ >>>
101
+ >>> @manager.model("Product", "1.0.0")
102
+ ... class Product(BaseModel):
103
+ ... name: str
104
+ ... price: float
105
+ >>>
106
+ >>> # Add transformer for specific model
107
+ >>> @manager.schema_transformer("Product", "1.0.0")
108
+ ... def add_examples(schema: JsonSchema) -> JsonSchema:
109
+ ... schema["examples"] = [{"name": "Widget", "price": 9.99}]
110
+ ... return schema
111
+ >>>
112
+ >>> schema = manager.get_schema("Product", "1.0.0")
113
+ >>> # Will include examples
114
+
66
115
  Advanced Features:
67
116
  >>> # Batch migration with parallel processing
68
117
  >>> users = manager.migrate_batch(
@@ -71,7 +120,9 @@ class ModelManager:
71
120
  ... )
72
121
  >>>
73
122
  >>> # Stream large datasets efficiently
74
- >>> for user in manager.migrate_batch_streaming(large_dataset, "User", "1.0.0", "2.0.0"):
123
+ >>> for user in manager.migrate_batch_streaming(
124
+ ... large_dataset, "User", "1.0.0", "2.0.0"
125
+ ... ):
75
126
  ... save_to_database(user)
76
127
  >>>
77
128
  >>> # Compare versions and export schemas
@@ -83,23 +134,32 @@ class ModelManager:
83
134
  >>> results = manager.test_migration(
84
135
  ... "User", "1.0.0", "2.0.0",
85
136
  ... test_cases=[
86
- ... ({"name": "Alice"}, {"name": "Alice", "email": "unknown@example.com"})
137
+ ... (
138
+ ... {"name": "Alice"},
139
+ ... {"name": "Alice", "email": "unknown@example.com"}
140
+ ... )
87
141
  ... ]
88
142
  ... )
89
143
  >>> results.assert_all_passed()
90
- """ # noqa: E501
144
+ """
145
+
146
+ def __init__(self: Self, default_schema_config: SchemaConfig | None = None) -> None:
147
+ """Initialize the versioned model manager.
91
148
 
92
- def __init__(self: Self) -> None:
93
- """Initialize the versioned model manager."""
149
+ Args:
150
+ default_schema_config: Default configuration for schema generation
151
+ applied to all schema operations unless overridden.
152
+ """
94
153
  self._registry = Registry()
95
154
  self._migration_manager = MigrationManager(self._registry)
96
- self._schema_manager = SchemaManager(self._registry)
155
+ self._schema_manager = SchemaManager(
156
+ self._registry, default_config=default_schema_config
157
+ )
97
158
 
98
159
  def model(
99
160
  self: Self,
100
161
  name: str,
101
162
  version: str | ModelVersion,
102
- schema_generator: JsonSchemaGenerator | None = None,
103
163
  enable_ref: bool = False,
104
164
  backward_compatible: bool = False,
105
165
  ) -> Callable[[type[DecoratedBaseModel]], type[DecoratedBaseModel]]:
@@ -108,7 +168,6 @@ class ModelManager:
108
168
  Args:
109
169
  name: Name of the model.
110
170
  version: Semantic version.
111
- schema_generator: Optional custom schema generator.
112
171
  enable_ref: If True, this model can be referenced via $ref in separate
113
172
  schema files. If False, it will always be inlined.
114
173
  backward_compatible: If True, this model does not need a migration function
@@ -129,9 +188,7 @@ class ModelManager:
129
188
  ... class CityV1(BaseModel):
130
189
  ... city: City
131
190
  """
132
- return self._registry.register(
133
- name, version, schema_generator, enable_ref, backward_compatible
134
- )
191
+ return self._registry.register(name, version, enable_ref, backward_compatible)
135
192
 
136
193
  def migration(
137
194
  self: Self,
@@ -308,27 +365,12 @@ class ModelManager:
308
365
  from_version: Source version.
309
366
  to_version: Target version.
310
367
  parallel: If True, use parallel processing.
311
- max_workers: Maximum number of workers for parallel processing. Defaults to
312
- None (uses executor default).
368
+ max_workers: Maximum number of workers for parallel processing.
313
369
  use_processes: If True, use ProcessPoolExecutor instead of
314
- ThreadPoolExecutor. Useful for CPU-intensive migrations.
370
+ ThreadPoolExecutor.
315
371
 
316
372
  Returns:
317
373
  List of migrated BaseModel instances.
318
-
319
- Example:
320
- >>> legacy_users = [
321
- ... {"name": "Alice"},
322
- ... {"name": "Bob"},
323
- ... {"name": "Charlie"}
324
- ... ]
325
- >>> users = manager.migrate_batch(
326
- ... legacy_users,
327
- ... "User",
328
- ... from_version="1.0.0",
329
- ... to_version="3.0.0",
330
- ... parallel=True
331
- ... )
332
374
  """
333
375
  data_list = list(data_list)
334
376
 
@@ -371,15 +413,6 @@ class ModelManager:
371
413
 
372
414
  Returns:
373
415
  List of raw migrated dictionaries.
374
-
375
- Example:
376
- >>> legacy_data = [{"name": "Alice"}, {"name": "Bob"}]
377
- >>> migrated_data = manager.migrate_batch_data(
378
- ... legacy_data,
379
- ... "User",
380
- ... from_version="1.0.0",
381
- ... to_version="2.0.0"
382
- ... )
383
416
  """
384
417
  data_list = list(data_list)
385
418
 
@@ -410,9 +443,6 @@ class ModelManager:
410
443
  ) -> Iterable[BaseModel]:
411
444
  """Migrate data in chunks, yielding results as they complete.
412
445
 
413
- Useful for large datasets where you want to start processing results before all
414
- migrations complete.
415
-
416
446
  Args:
417
447
  data_list: Iterable of data dictionaries to migrate.
418
448
  name: Name of the model.
@@ -422,17 +452,6 @@ class ModelManager:
422
452
 
423
453
  Yields:
424
454
  Migrated BaseModel instances.
425
-
426
- Example:
427
- >>> legacy_users = load_large_dataset()
428
- >>> for user in manager.migrate_batch_streaming(
429
- ... legacy_users,
430
- ... "User",
431
- ... from_version="1.0.0",
432
- ... to_version="3.0.0"
433
- ... ):
434
- ... # Process each user as it's migrated
435
- ... save_to_database(user)
436
455
  """
437
456
  chunk = []
438
457
 
@@ -456,9 +475,6 @@ class ModelManager:
456
475
  ) -> Iterable[ModelData]:
457
476
  """Migrate data in chunks, yielding raw dictionaries as they complete.
458
477
 
459
- Useful for large datasets where you want to start processing results before all
460
- migrations complete, without the validation overhead.
461
-
462
478
  Args:
463
479
  data_list: Iterable of data dictionaries to migrate.
464
480
  name: Name of the model.
@@ -468,17 +484,6 @@ class ModelManager:
468
484
 
469
485
  Yields:
470
486
  Raw migrated dictionaries.
471
-
472
- Example:
473
- >>> legacy_data = load_large_dataset()
474
- >>> for data in manager.migrate_batch_data_streaming(
475
- ... legacy_data,
476
- ... "User",
477
- ... from_version="1.0.0",
478
- ... to_version="3.0.0"
479
- ... ):
480
- ... # Process raw data as it's migrated
481
- ... bulk_insert_to_database(data)
482
487
  """
483
488
  chunk = []
484
489
 
@@ -502,9 +507,6 @@ class ModelManager:
502
507
  ) -> ModelDiff:
503
508
  """Get a detailed diff between two model versions.
504
509
 
505
- Compares field names, types, requirements, and default values to provide a
506
- comprehensive view of what changed between versions.
507
-
508
510
  Args:
509
511
  name: Name of the model.
510
512
  from_version: Source version.
@@ -512,12 +514,6 @@ class ModelManager:
512
514
 
513
515
  Returns:
514
516
  ModelDiff with detailed change information.
515
-
516
- Example:
517
- >>> diff = manager.diff("User", "1.0.0", "2.0.0")
518
- >>> print(diff.to_markdown())
519
- >>> print(f"Added: {diff.added_fields}")
520
- >>> print(f"Removed: {diff.removed_fields}")
521
517
  """
522
518
  from_ver_str = str(
523
519
  ModelVersion.parse(from_version)
@@ -541,10 +537,136 @@ class ModelManager:
541
537
  to_version=to_ver_str,
542
538
  )
543
539
 
540
+ def set_default_schema_generator(
541
+ self: Self, generator: JsonSchemaGenerator | type[GenerateJsonSchema]
542
+ ) -> None:
543
+ """Set the default schema generator for all schemas.
544
+
545
+ This is a convenience method that updates the default schema configuration.
546
+
547
+ Args:
548
+ generator: Custom schema generator - either a callable or GenerateJsonSchema
549
+ class.
550
+
551
+ Example (Callable):
552
+ >>> def my_generator(model: type[BaseModel]) -> JsonSchema:
553
+ ... schema = model.model_json_schema()
554
+ ... schema["x-custom"] = True
555
+ ... return schema
556
+ >>>
557
+ >>> manager = ModelManager()
558
+ >>> manager.set_default_schema_generator(my_generator)
559
+
560
+ Example (Class - Recommended):
561
+ >>> from pydantic.json_schema import GenerateJsonSchema
562
+ >>>
563
+ >>> class MyGenerator(GenerateJsonSchema):
564
+ ... def generate(
565
+ ... self,
566
+ ... schema: Mapping[str, Any],
567
+ ... mode: JsonSchemaMode = "validation"
568
+ ... ) -> JsonSchema:
569
+ ... json_schema = super().generate(schema, mode=mode)
570
+ ... json_schema["x-custom"] = True
571
+ ... json_schema["$schema"] = self.schema_dialect
572
+ ... return json_schema
573
+ >>>
574
+ >>> manager = ModelManager()
575
+ >>> manager.set_default_schema_generator(MyGenerator)
576
+ >>>
577
+ >>> # All subsequent schema calls will use MyGenerator
578
+ >>> schema = manager.get_schema("User", "1.0.0")
579
+ """
580
+ self._schema_manager.set_default_schema_generator(generator)
581
+
582
+ def schema_transformer(
583
+ self: Self,
584
+ name: str,
585
+ version: str | ModelVersion,
586
+ ) -> Callable[[SchemaTransformer], SchemaTransformer]:
587
+ """Decorator to register a schema transformer for a model version.
588
+
589
+ Transformers are simple functions that modify a schema after generation.
590
+ They're useful for model-specific customizations that don't require deep
591
+ integration with Pydantic's generation process.
592
+
593
+ Args:
594
+ name: Name of the model.
595
+ version: Model version.
596
+
597
+ Returns:
598
+ Decorator function.
599
+
600
+ Example:
601
+ >>> @manager.schema_transformer("User", "1.0.0")
602
+ ... def add_auth_metadata(schema: JsonSchema) -> JsonSchema:
603
+ ... schema["x-requires-auth"] = True
604
+ ... schema["x-auth-level"] = 'admin'
605
+ ... return schema
606
+ >>>
607
+ >>> @manager.schema_transformer("Product", "2.0.0")
608
+ ... def add_product_examples(schema: JsonSchema) -> JsonSchema:
609
+ ... schema["examples"] = [
610
+ ... {"name": "Widget", "price": 9.99},
611
+ ... {"name": "Gadget", "price": 19.99}
612
+ ... ]
613
+ ... return schema
614
+ """
615
+
616
+ def decorator(func: SchemaTransformer) -> SchemaTransformer:
617
+ self._schema_manager.register_transformer(name, version, func)
618
+ return func
619
+
620
+ return decorator
621
+
622
+ def get_schema_transformers(
623
+ self: Self,
624
+ name: str,
625
+ version: str | ModelVersion,
626
+ ) -> list[SchemaTransformer]:
627
+ """Get all transformers for a model version.
628
+
629
+ Args:
630
+ name: Name of the model.
631
+ version: Model version.
632
+
633
+ Returns:
634
+ List of transformer functions.
635
+
636
+ Example:
637
+ >>> transformers = manager.get_schema_transformers("User", "1.0.0")
638
+ >>> print(f"Found {len(transformers)} transformers")
639
+ """
640
+ return self._schema_manager.get_transformers(name, version)
641
+
642
+ def clear_schema_transformers(
643
+ self: Self,
644
+ name: str | None = None,
645
+ version: str | ModelVersion | None = None,
646
+ ) -> None:
647
+ """Clear schema transformers.
648
+
649
+ Args:
650
+ name: Optional model name. If None, clears all.
651
+ version: Optional version. If None, clears all versions of model.
652
+
653
+ Example:
654
+ >>> # Clear all transformers
655
+ >>> manager.clear_schema_transformers()
656
+ >>>
657
+ >>> # Clear User transformers
658
+ >>> manager.clear_schema_transformers("User")
659
+ >>>
660
+ >>> # Clear specific version
661
+ >>> manager.clear_schema_transformers("User", "1.0.0")
662
+ """
663
+ self._schema_manager.clear_transformers(name, version)
664
+
544
665
  def get_schema(
545
666
  self: Self,
546
667
  name: str,
547
668
  version: str | ModelVersion,
669
+ config: SchemaConfig | None = None,
548
670
  **kwargs: Any,
549
671
  ) -> JsonSchema:
550
672
  """Get JSON schema for a specific version.
@@ -552,12 +674,25 @@ class ModelManager:
552
674
  Args:
553
675
  name: Name of the model.
554
676
  version: Semantic version.
555
- **kwargs: Additional schema generation arguments.
677
+ config: Optional schema configuration (overrides default).
678
+ **kwargs: Additional schema generation arguments (e.g.,
679
+ mode="serialization").
556
680
 
557
681
  Returns:
558
682
  JSON schema dictionary.
683
+
684
+ Example:
685
+ >>> # Use default config
686
+ >>> schema = manager.get_schema("User", "1.0.0")
687
+ >>>
688
+ >>> # Override with custom config
689
+ >>> config = SchemaConfig(mode="serialization")
690
+ >>> schema = manager.get_schema("User", "1.0.0", config=config)
691
+ >>>
692
+ >>> # Quick override with kwargs
693
+ >>> schema = manager.get_schema("User", "1.0.0", mode="serialization")
559
694
  """
560
- return self._schema_manager.get_schema(name, version, **kwargs)
695
+ return self._schema_manager.get_schema(name, version, config=config, **kwargs)
561
696
 
562
697
  def list_models(self: Self) -> list[str]:
563
698
  """Get list of all registered models.
@@ -584,6 +719,7 @@ class ModelManager:
584
719
  indent: int = 2,
585
720
  separate_definitions: bool = False,
586
721
  ref_template: str | None = None,
722
+ config: SchemaConfig | None = None,
587
723
  ) -> None:
588
724
  """Export all schemas to JSON files.
589
725
 
@@ -591,27 +727,30 @@ class ModelManager:
591
727
  output_dir: Directory path for output.
592
728
  indent: JSON indentation level.
593
729
  separate_definitions: If True, create separate schema files for nested
594
- models and use $ref to reference them. Only applies to models with
595
- 'enable_ref=True'.
730
+ models and use $ref to reference them.
596
731
  ref_template: Template for $ref URLs when separate_definitions=True.
597
- Defaults to relative file references if not provided.
732
+ config: Optional schema configuration for all exported schemas.
598
733
 
599
734
  Example:
600
- >>> # Inline definitions (default)
601
- >>> manager.dump_schemas("schemas/")
602
- >>>
603
- >>> # Separate sub-schemas with relative refs
604
- >>> manager.dump_schemas("schemas/", separate_definitions=True)
735
+ >>> # Export with custom generator
736
+ >>> config = SchemaConfig(
737
+ ... schema_generator=CustomGenerator,
738
+ ... mode="validation"
739
+ ... )
740
+ >>> manager.dump_schemas("schemas/", config=config)
605
741
  >>>
606
- >>> # Separate sub-schemas with absolute URLs
742
+ >>> # Export validation and serialization schemas separately
607
743
  >>> manager.dump_schemas(
608
- ... "schemas/",
609
- ... separate_definitions=True,
610
- ... ref_template="https://example.com/schemas/{model}_v{version}.json"
744
+ ... "schemas/validation/",
745
+ ... config=SchemaConfig(mode="validation")
746
+ ... )
747
+ >>> manager.dump_schemas(
748
+ ... "schemas/serialization/",
749
+ ... config=SchemaConfig(mode="serialization")
611
750
  ... )
612
751
  """
613
752
  self._schema_manager.dump_schemas(
614
- output_dir, indent, separate_definitions, ref_template
753
+ output_dir, indent, separate_definitions, ref_template, config=config
615
754
  )
616
755
 
617
756
  def get_nested_models(
@@ -635,58 +774,19 @@ class ModelManager:
635
774
  name: str,
636
775
  from_version: str | ModelVersion,
637
776
  to_version: str | ModelVersion,
638
- test_cases: list[tuple[ModelData, ModelData] | MigrationTestCase],
777
+ test_cases: MigrationTestCases,
639
778
  ) -> MigrationTestResults:
640
779
  """Test a migration with multiple test cases.
641
780
 
642
- Executes a migration on multiple test inputs and validates the outputs match
643
- expected values. Useful for regression testing and validating migration logic.
644
-
645
781
  Args:
646
782
  name: Name of the model.
647
783
  from_version: Source version to migrate from.
648
784
  to_version: Target version to migrate to.
649
- test_cases: List of test cases, either as (source, target) tuples or
650
- MigrationTestCase objects. If target is None, only verifies the
651
- migration completes without errors.
785
+ test_cases: List of test cases.
652
786
 
653
787
  Returns:
654
788
  MigrationTestResults containing individual results for each test case.
655
-
656
- Example:
657
- >>> # Using tuples (source, target)
658
- >>> results = manager.test_migration(
659
- ... "User", "1.0.0", "2.0.0",
660
- ... test_cases=[
661
- ... ({"name": "Alice"}, {"name": "Alice", "email": "alice@example.com"}),
662
- ... ({"name": "Bob"}, {"name": "Bob", "email": "bob@example.com"})
663
- ... ]
664
- ... )
665
- >>> assert results.all_passed
666
- >>>
667
- >>> # Using MigrationTestCase objects
668
- >>> results = manager.test_migration(
669
- ... "User", "1.0.0", "2.0.0",
670
- ... test_cases=[
671
- ... MigrationTestCase(
672
- ... source={"name": "Alice"},
673
- ... target={"name": "Alice", "email": "alice@example.com"},
674
- ... description="Standard user migration"
675
- ... )
676
- ... ]
677
- ... )
678
- >>>
679
- >>> # Use in pytest
680
- >>> def test_user_migration():
681
- ... results = manager.test_migration("User", "1.0.0", "2.0.0", test_cases)
682
- ... results.assert_all_passed() # Raises AssertionError with details if failed
683
- >>>
684
- >>> # Inspect failures
685
- >>> if not results.all_passed:
686
- ... for failure in results.failures:
687
- ... print(f"Failed: {failure.test_case.description}")
688
- ... print(f" Error: {failure.error}")
689
- """ # noqa: E501
789
+ """
690
790
  results = []
691
791
 
692
792
  for test_case_input in test_cases: