pyrmute 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyrmute/_registry.py CHANGED
@@ -10,11 +10,9 @@ from .exceptions import ModelNotFoundError
10
10
  from .model_version import ModelVersion
11
11
  from .types import (
12
12
  DecoratedBaseModel,
13
- JsonSchemaGenerator,
14
13
  MigrationMap,
15
14
  ModelMetadata,
16
15
  ModelName,
17
- SchemaGenerators,
18
16
  VersionedModels,
19
17
  )
20
18
 
@@ -28,7 +26,6 @@ class Registry:
28
26
  Attributes:
29
27
  _models: Dictionary mapping model names to version-model mappings.
30
28
  _migrations: Dictionary storing migration functions between versions.
31
- _schema_generators: Dictionary storing custom schema generators.
32
29
  _model_metadata: Dictionary mapping model classes to (name, version).
33
30
  _ref_enabled: Dictionary tracking which models have enable_ref=True.
34
31
  """
@@ -37,7 +34,6 @@ class Registry:
37
34
  """Initialize the model registry."""
38
35
  self._models: dict[ModelName, VersionedModels] = defaultdict(dict)
39
36
  self._migrations: dict[ModelName, MigrationMap] = defaultdict(dict)
40
- self._schema_generators: dict[ModelName, SchemaGenerators] = defaultdict(dict)
41
37
  self._model_metadata: dict[type[BaseModel], ModelMetadata] = {}
42
38
  self._ref_enabled: dict[ModelName, set[ModelVersion]] = defaultdict(set)
43
39
  self._backward_compatible_enabled: dict[ModelName, set[ModelVersion]] = (
@@ -48,7 +44,6 @@ class Registry:
48
44
  self: Self,
49
45
  name: ModelName,
50
46
  version: str | ModelVersion,
51
- schema_generator: JsonSchemaGenerator | None = None,
52
47
  enable_ref: bool = False,
53
48
  backward_compatible: bool = False,
54
49
  ) -> Callable[[type[DecoratedBaseModel]], type[DecoratedBaseModel]]:
@@ -57,7 +52,6 @@ class Registry:
57
52
  Args:
58
53
  name: Name of the model.
59
54
  version: Semantic version string or ModelVersion instance.
60
- schema_generator: Optional custom schema generator function.
61
55
  enable_ref: If True, this model can be referenced via $ref in separate
62
56
  schema files. If False, it will always be inlined.
63
57
  backward_compatible: If True, this model does not need a migration function
@@ -78,8 +72,6 @@ class Registry:
78
72
  def decorator(cls: type[DecoratedBaseModel]) -> type[DecoratedBaseModel]:
79
73
  self._models[name][ver] = cls
80
74
  self._model_metadata[cls] = (name, ver)
81
- if schema_generator:
82
- self._schema_generators[name][ver] = schema_generator
83
75
  if enable_ref:
84
76
  self._ref_enabled[name].add(ver)
85
77
  if backward_compatible:
@@ -1,22 +1,27 @@
1
- """Schema manager."""
1
+ """Schema manager with customizable generation and transformers."""
2
2
 
3
3
  import json
4
+ from collections import defaultdict
4
5
  from pathlib import Path
5
- from typing import Any, Self, get_args, get_origin
6
+ from typing import Any, Self, cast, get_args, get_origin
6
7
 
7
8
  from pydantic import BaseModel
8
9
  from pydantic.fields import FieldInfo
10
+ from pydantic.json_schema import GenerateJsonSchema
9
11
 
10
12
  from ._registry import Registry
11
13
  from .exceptions import ModelNotFoundError
12
14
  from .model_version import ModelVersion
15
+ from .schema_config import SchemaConfig
13
16
  from .types import (
14
17
  JsonSchema,
15
18
  JsonSchemaDefinitions,
19
+ JsonSchemaGenerator,
16
20
  JsonValue,
17
21
  ModelMetadata,
18
22
  ModelName,
19
23
  NestedModelInfo,
24
+ SchemaTransformer,
20
25
  )
21
26
 
22
27
 
@@ -24,53 +29,217 @@ class SchemaManager:
24
29
  """Manager for JSON schema generation and export.
25
30
 
26
31
  Handles schema generation from Pydantic models with support for custom schema
27
- generators and sub-schema references.
32
+ generators, global configuration, per-call overrides, and schema transformers.
28
33
 
29
34
  Attributes:
30
35
  registry: Reference to the Registry.
36
+ default_config: Default schema generation configuration.
31
37
  """
32
38
 
33
- def __init__(self: Self, registry: Registry) -> None:
39
+ def __init__(
40
+ self: Self, registry: Registry, default_config: SchemaConfig | None = None
41
+ ) -> None:
34
42
  """Initialize the schema manager.
35
43
 
36
44
  Args:
37
45
  registry: Registry instance to use.
46
+ default_config: Default configuration for schema generation.
38
47
  """
39
48
  self.registry = registry
49
+ self.default_config = default_config or SchemaConfig()
50
+ self._transformers: dict[
51
+ tuple[ModelName, ModelVersion], list[SchemaTransformer]
52
+ ] = defaultdict(list)
53
+
54
+ def set_default_schema_generator(
55
+ self: Self, generator: JsonSchemaGenerator | type[GenerateJsonSchema]
56
+ ) -> None:
57
+ """Set the default schema generator for all schemas.
58
+
59
+ Args:
60
+ generator: Custom schema generator - either a callable or GenerateJsonSchema
61
+ class.
62
+
63
+ Example (Callable):
64
+ >>> def custom_gen(model: type[BaseModel]) -> JsonSchema:
65
+ ... schema = model.model_json_schema()
66
+ ... schema["x-custom"] = True
67
+ ... return schema
68
+ >>>
69
+ >>> manager.set_default_schema_generator(custom_gen)
70
+
71
+ Example (Class):
72
+ >>> from pydantic.json_schema import GenerateJsonSchema
73
+ >>>
74
+ >>> class CustomGenerator(GenerateJsonSchema):
75
+ ... def generate(
76
+ ... self,
77
+ ... schema: Mapping[str, Any],
78
+ ... mode: JsonSchemaMode = "validation"
79
+ ... ) -> JsonSchema:
80
+ ... json_schema = super().generate(schema, mode=mode)
81
+ ... json_schema["x-custom"] = True
82
+ ... return json_schema
83
+ >>>
84
+ >>> manager.set_default_schema_generator(CustomGenerator)
85
+ """
86
+ self.default_config.schema_generator = generator
87
+
88
+ def register_transformer(
89
+ self: Self,
90
+ name: ModelName,
91
+ version: str | ModelVersion,
92
+ transformer: SchemaTransformer,
93
+ ) -> None:
94
+ """Register a schema transformer for a specific model version.
95
+
96
+ Transformers are applied after schema generation, allowing simple
97
+ post-processing of schemas without needing to customize the generation process
98
+ itself.
99
+
100
+ Args:
101
+ name: Name of the model.
102
+ version: Model version.
103
+ transformer: Function that takes and returns a JsonSchema.
104
+
105
+ Example:
106
+ >>> def add_examples(schema: JsonSchema) -> JsonSchema:
107
+ ... schema["examples"] = [{"name": "John", "age": 30}]
108
+ ... return schema
109
+ >>>
110
+ >>> manager.register_transformer("User", "1.0.0", add_examples)
111
+ """
112
+ ver = ModelVersion.parse(version) if isinstance(version, str) else version
113
+ key = (name, ver)
114
+ self._transformers[key].append(transformer)
40
115
 
41
116
  def get_schema(
42
117
  self: Self,
43
118
  name: ModelName,
44
119
  version: str | ModelVersion,
120
+ config: SchemaConfig | None = None,
121
+ apply_transformers: bool = True,
45
122
  **schema_kwargs: Any,
46
123
  ) -> JsonSchema:
47
124
  """Get JSON schema for a specific model version.
48
125
 
126
+ Execution order:
127
+ 1. Generate base schema using Pydantic
128
+ 2. Apply custom generator (if configured)
129
+ 3. Apply registered transformers (if any)
130
+
49
131
  Args:
50
132
  name: Name of the model.
51
133
  version: Semantic version.
52
- **schema_kwargs: Additional arguments for schema generation.
134
+ config: Optional schema configuration (overrides defaults).
135
+ apply_transformers: If False, skip transformer application.
136
+ **schema_kwargs: Additional arguments for schema generation (overrides
137
+ config).
53
138
 
54
139
  Returns:
55
140
  JSON schema dictionary.
141
+
142
+ Example:
143
+ >>> # Use default config
144
+ >>> schema = manager.get_schema("User", "1.0.0")
145
+ >>>
146
+ >>> # Override with custom config
147
+ >>> config = SchemaConfig(mode="serialization", by_alias=False)
148
+ >>> schema = manager.get_schema("User", "1.0.0", config=config)
149
+ >>>
150
+ >>> # Quick override with kwargs
151
+ >>> schema = manager.get_schema("User", "1.0.0", mode="serialization")
152
+ >>>
153
+ >>> # Get base schema without transformers
154
+ >>> base_schema = manager.get_schema(
155
+ ... "User", "1.0.0",
156
+ ... apply_transformers=False
157
+ ... )
56
158
  """
57
159
  ver = ModelVersion.parse(version) if isinstance(version, str) else version
58
160
  model = self.registry.get_model(name, ver)
59
161
 
60
- if (
61
- name in self.registry._schema_generators
62
- and ver in self.registry._schema_generators[name]
63
- ):
64
- generator = self.registry._schema_generators[name][ver]
65
- return generator(model)
162
+ # Always use the config-based approach
163
+ final_config = self.default_config
164
+ if config is not None:
165
+ final_config = final_config.merge_with(config)
166
+
167
+ if schema_kwargs:
168
+ kwargs_config = SchemaConfig(extra_kwargs=schema_kwargs)
169
+ final_config = final_config.merge_with(kwargs_config)
170
+
171
+ schema: JsonSchema
172
+ if final_config.is_callable_generator():
173
+ schema = final_config.schema_generator(model) # type: ignore
174
+ else:
175
+ schema = model.model_json_schema(**final_config.to_kwargs())
176
+
177
+ if apply_transformers:
178
+ key = (name, ver)
179
+ if key in self._transformers:
180
+ for transformer in self._transformers[key]:
181
+ schema = transformer(schema)
182
+
183
+ return schema
184
+
185
+ def get_transformers(
186
+ self: Self,
187
+ name: ModelName,
188
+ version: str | ModelVersion,
189
+ ) -> list[SchemaTransformer]:
190
+ """Get all transformers registered for a model version.
191
+
192
+ Args:
193
+ name: Name of the model.
194
+ version: Model version.
195
+
196
+ Returns:
197
+ List of transformer functions.
198
+ """
199
+ ver = ModelVersion.parse(version) if isinstance(version, str) else version
200
+ key = (name, ver)
201
+ return self._transformers.get(key, [])
202
+
203
+ def clear_transformers(
204
+ self: Self,
205
+ name: ModelName | None = None,
206
+ version: str | ModelVersion | None = None,
207
+ ) -> None:
208
+ """Clear registered transformers.
66
209
 
67
- return model.model_json_schema(**schema_kwargs)
210
+ Args:
211
+ name: Optional model name. If None, clears all transformers.
212
+ version: Optional version. If None (but name provided), clears all versions
213
+ of that model.
214
+
215
+ Example:
216
+ >>> # Clear all transformers
217
+ >>> manager.clear_transformers()
218
+ >>>
219
+ >>> # Clear all User transformers
220
+ >>> manager.clear_transformers("User")
221
+ >>>
222
+ >>> # Clear specific version
223
+ >>> manager.clear_transformers("User", "1.0.0")
224
+ """
225
+ if name is None:
226
+ self._transformers.clear()
227
+ elif version is None:
228
+ keys_to_remove = [key for key in self._transformers if key[0] == name]
229
+ for key in keys_to_remove:
230
+ del self._transformers[key]
231
+ else:
232
+ ver = ModelVersion.parse(version) if isinstance(version, str) else version
233
+ key = (name, ver)
234
+ if key in self._transformers:
235
+ del self._transformers[key]
68
236
 
69
237
  def get_schema_with_separate_defs(
70
238
  self: Self,
71
239
  name: ModelName,
72
240
  version: str | ModelVersion,
73
241
  ref_template: str = "{model}_v{version}.json",
242
+ config: SchemaConfig | None = None,
74
243
  **schema_kwargs: Any,
75
244
  ) -> JsonSchema:
76
245
  """Get JSON schema with separate definition files for nested models.
@@ -83,6 +252,7 @@ class SchemaManager:
83
252
  version: Semantic version.
84
253
  ref_template: Template for generating $ref URLs. Supports {model} and
85
254
  {version} placeholders.
255
+ config: Optional schema configuration.
86
256
  **schema_kwargs: Additional arguments for schema generation.
87
257
 
88
258
  Returns:
@@ -91,16 +261,16 @@ class SchemaManager:
91
261
  Example:
92
262
  >>> schema = manager.get_schema_with_separate_defs(
93
263
  ... "User", "2.0.0",
94
- ... ref_template="https://example.com/schemas/{model}_v{version}.json"
264
+ ... ref_template="https://example.com/schemas/{model}_v{version}.json",
265
+ ... mode="serialization"
95
266
  ... )
96
267
  """
97
268
  ver = ModelVersion.parse(version) if isinstance(version, str) else version
98
- schema = self.get_schema(name, ver, **schema_kwargs)
269
+ schema = self.get_schema(name, ver, config=config, **schema_kwargs)
99
270
 
100
- # Extract and replace definitions with external references
101
271
  if "$defs" in schema or "definitions" in schema:
102
272
  defs_key = "$defs" if "$defs" in schema else "definitions"
103
- definitions: JsonSchemaDefinitions = schema.pop(defs_key, {}) # type: ignore[assignment]
273
+ definitions = cast("JsonSchemaDefinitions", schema.pop(defs_key, {}))
104
274
 
105
275
  # Update all $ref in the schema to point to external files
106
276
  schema = self._replace_refs_with_external(schema, definitions, ref_template)
@@ -209,11 +379,14 @@ class SchemaManager:
209
379
  return (name, version)
210
380
  return None
211
381
 
212
- def get_all_schemas(self: Self, name: ModelName) -> dict[ModelVersion, JsonSchema]:
382
+ def get_all_schemas(
383
+ self: Self, name: ModelName, config: SchemaConfig | None = None
384
+ ) -> dict[ModelVersion, JsonSchema]:
213
385
  """Get all schemas for a model across all versions.
214
386
 
215
387
  Args:
216
388
  name: Name of the model.
389
+ config: Optional schema configuration.
217
390
 
218
391
  Returns:
219
392
  Dictionary mapping versions to their schemas.
@@ -225,7 +398,7 @@ class SchemaManager:
225
398
  raise ModelNotFoundError(name)
226
399
 
227
400
  return {
228
- version: self.get_schema(name, version)
401
+ version: self.get_schema(name, version, config=config)
229
402
  for version in self.registry._models[name]
230
403
  }
231
404
 
@@ -235,6 +408,7 @@ class SchemaManager:
235
408
  indent: int = 2,
236
409
  separate_definitions: bool = False,
237
410
  ref_template: str | None = None,
411
+ config: SchemaConfig | None = None,
238
412
  ) -> None:
239
413
  """Dump all schemas to JSON files.
240
414
 
@@ -245,27 +419,24 @@ class SchemaManager:
245
419
  models that have enable_ref=True.
246
420
  ref_template: Template for $ref URLs when separate_definitions=True.
247
421
  Defaults to relative file references if not provided.
422
+ config: Optional schema configuration for all exported schemas.
248
423
 
249
424
  Example:
250
- >>> # Inline definitions (default)
251
- >>> manager.dump_schemas("schemas/")
252
- >>>
253
- >>> # Separate sub-schemas with relative refs (when enable_ref=True models)
254
- >>> manager.dump_schemas("schemas/", separate_definitions=True)
255
- >>>
256
- >>> # Separate sub-schemas with absolute URLs
257
- >>> manager.dump_schemas(
258
- ... "schemas/",
259
- ... separate_definitions=True,
260
- ... ref_template="https://example.com/schemas/{model}_v{version}.json"
425
+ >>> # Export with custom schema generator
426
+ >>> config = SchemaConfig(
427
+ ... schema_generator=CustomGenerator,
428
+ ... mode="serialization"
261
429
  ... )
430
+ >>> manager.dump_schemas("schemas/", config=config)
262
431
  """
263
432
  output_path = Path(output_dir)
264
433
  output_path.mkdir(parents=True, exist_ok=True)
265
434
 
266
435
  if not separate_definitions:
267
436
  for name in self.registry._models:
268
- for version, schema in self.get_all_schemas(name).items():
437
+ for version, schema in self.get_all_schemas(
438
+ name, config=config
439
+ ).items():
269
440
  file_path = output_path / f"{name}_v{version}.json"
270
441
  with open(file_path, "w", encoding="utf-8") as f:
271
442
  json.dump(schema, f, indent=indent)
@@ -276,7 +447,7 @@ class SchemaManager:
276
447
  for name in self.registry._models:
277
448
  for version in self.registry._models[name]:
278
449
  schema = self.get_schema_with_separate_defs(
279
- name, version, ref_template
450
+ name, version, ref_template, config=config
280
451
  )
281
452
  file_path = output_path / f"{name}_v{version}.json"
282
453
  with open(file_path, "w", encoding="utf-8") as f:
pyrmute/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.0'
32
- __version_tuple__ = version_tuple = (0, 3, 0)
31
+ __version__ = version = '0.5.0'
32
+ __version_tuple__ = version_tuple = (0, 5, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -2,7 +2,7 @@
2
2
 
3
3
  from collections.abc import Iterator
4
4
  from dataclasses import dataclass
5
- from typing import Self
5
+ from typing import Self, TypeAlias
6
6
 
7
7
  from .types import ModelData
8
8
 
@@ -159,3 +159,6 @@ class MigrationTestResults:
159
159
  f"✗ {len(self.failures)} of {total_count} test(s) failed "
160
160
  f"({passed_count} passed)"
161
161
  )
162
+
163
+
164
+ MigrationTestCases: TypeAlias = list[tuple[ModelData, ModelData] | MigrationTestCase]