pydantic-marshmallow 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1187 @@
1
+ """
2
+ Bridge between Pydantic models and Marshmallow schemas.
3
+
4
+ Pydantic's Rust-based validation with full Marshmallow compatibility.
5
+ Flow: Input → Marshmallow pre_load → PYDANTIC VALIDATES → Marshmallow post_load → Output
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import Callable, Sequence, Set as AbstractSet
11
+ from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
12
+
13
+ from marshmallow import EXCLUDE, INCLUDE, RAISE, Schema, fields as ma_fields
14
+ from marshmallow.decorators import VALIDATES, VALIDATES_SCHEMA
15
+ from marshmallow.error_store import ErrorStore
16
+ from marshmallow.exceptions import ValidationError as MarshmallowValidationError
17
+ from marshmallow.schema import SchemaMeta
18
+ from pydantic import BaseModel, ConfigDict, ValidationError as PydanticValidationError
19
+
20
+ from .errors import BridgeValidationError, convert_pydantic_errors, format_pydantic_error
21
+ from .field_conversion import convert_model_fields, convert_pydantic_field
22
+ from .validators import cache_validators
23
+
24
# Type variable for the Pydantic model a PydanticSchema is parametrised with.
M = TypeVar("M", bound=BaseModel)

# Module-level cache for HybridModel schemas
# NOTE(review): not referenced in the PydanticSchema code shown here; presumably
# used by HybridModel elsewhere in this module - confirm.
_hybrid_schema_cache: dict[type[Any], type[PydanticSchema[Any]]] = {}

# Field validator registry: maps (schema_class, field_name) -> list of validator functions
# NOTE(review): the PydanticSchema methods shown here read the per-class
# _field_validators_cache / _schema_validators_cache instead of these registries;
# confirm whether these module-level dicts are still used.
_field_validators: dict[tuple[type[Any], str], list[Callable[..., Any]]] = {}
_schema_validators: dict[type[Any], list[Callable[..., Any]]] = {}
32
+
33
+
34
class PydanticSchemaMeta(SchemaMeta):
    """
    Custom metaclass that adds Pydantic model fields BEFORE Marshmallow processes them.

    This ensures Meta.fields and Meta.exclude work correctly with dynamically
    generated fields from Pydantic models.

    The metaclass handles:
    - Extracting the Pydantic model from Meta.model or generic parameters, both
      for ``class S(PydanticSchema[Model])`` and for subclasses of such a schema
    - Converting Pydantic model fields to Marshmallow fields at class creation
    - Respecting Meta.fields (whitelist) filtering during field generation
    - Converting @computed_field properties to dump-only Marshmallow fields

    Note:
        Meta.exclude is NOT applied here - it's handled by Marshmallow's standard
        metaclass after all fields are declared. This ensures proper inheritance
        behavior.
    """

    @staticmethod
    def _model_from_generic_aliases(aliases: Sequence[Any]) -> type[BaseModel] | None:
        """
        Return the model bound in the first ``PydanticSchema[Model]`` alias of *aliases*.

        Entries that are not parametrised ``PydanticSchema`` aliases, or whose first
        type argument is not a ``BaseModel`` subclass, are ignored. Returns None if
        nothing matches.
        """
        for alias in aliases:
            origin = get_origin(alias)
            # Match by name: PydanticSchema is itself created by this metaclass, so
            # the class object cannot be referenced while it is being built.
            if not origin or "PydanticSchema" not in getattr(origin, "__name__", ""):
                continue
            args = get_args(alias)
            if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
                return args[0]
        return None

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
    ) -> PydanticSchemaMeta:
        """
        Create the schema class after injecting Marshmallow fields for its model.

        Args:
            name: Name of the class being created.
            bases: Resolved base classes.
            attrs: Class namespace; model-derived fields are added to it in place.

        Returns:
            The newly created schema class.
        """
        meta = attrs.get("Meta")

        # 1. An explicit Meta.model wins.
        model_class = meta.model if meta and getattr(meta, "model", None) else None

        # 2. ``class S(PydanticSchema[Model])``: Python resolves the alias in
        #    ``bases`` to the plain PydanticSchema class and records the
        #    parametrised form in the namespace as ``__orig_bases__``.
        if not model_class:
            model_class = mcs._model_from_generic_aliases(attrs.get("__orig_bases__", ()))

        # 3. Subclass of an already-parametrised schema: inspect each base's own
        #    ``__orig_bases__`` (and the base itself). The first match wins.
        if not model_class:
            for base in bases:
                model_class = mcs._model_from_generic_aliases(
                    (*getattr(base, "__orig_bases__", ()), base)
                )
                if model_class:
                    break

        # Add Pydantic fields to attrs BEFORE Marshmallow processes them
        if model_class:
            # Fields already present in the namespace are treated as pre-filtered
            # (the from_model() case), and then no model fields are added at all.
            # NOTE(review): a single hand-written field override in a class body
            # also triggers this and suppresses automatic field generation -
            # confirm this is intended.
            has_prefiltered_fields = any(
                isinstance(value, ma_fields.Field) for value in attrs.values()
            )

            # Meta.fields (whitelist) is the only filter applied here;
            # Meta.exclude is handled by Marshmallow after fields are declared.
            meta_fields = getattr(meta, "fields", None) if meta else None
            include_set = set(meta_fields) if meta_fields else None

            def _wanted(field_name: str) -> bool:
                # Explicitly declared attributes are never overwritten.
                if field_name in attrs:
                    return False
                return include_set is None or field_name in include_set

            if not has_prefiltered_fields:
                for field_name, field_info in model_class.model_fields.items():
                    if _wanted(field_name):
                        # Use centralized field conversion
                        attrs[field_name] = convert_pydantic_field(field_name, field_info)

                # Add computed fields as dump_only
                if hasattr(model_class, "model_computed_fields"):
                    from .field_conversion import convert_computed_field

                    for field_name, computed_info in model_class.model_computed_fields.items():
                        if _wanted(field_name):
                            attrs[field_name] = convert_computed_field(field_name, computed_info)

        # Cast to satisfy type checker - SchemaMeta.__new__ returns SchemaMeta
        return cast(PydanticSchemaMeta, super().__new__(mcs, name, bases, attrs))
150
+
151
+
152
+ class PydanticSchema(Schema, Generic[M], metaclass=PydanticSchemaMeta):
153
+ """
154
+ A Marshmallow schema backed by a Pydantic model.
155
+
156
+ This gives you:
157
+ - Pydantic's validation, coercion, and error messages
158
+ - Marshmallow's serialization and ecosystem integration
159
+ - No drift - Pydantic does the heavy lifting
160
+
161
+ Example:
162
+ from pydantic import BaseModel, EmailStr, Field
163
+
164
+ class User(BaseModel):
165
+ name: str = Field(min_length=1)
166
+ email: EmailStr
167
+ age: int = Field(ge=0)
168
+
169
+ class UserSchema(PydanticSchema[User]):
170
+ class Meta:
171
+ model = User
172
+
173
+ # Or use the shortcut:
174
+ UserSchema = PydanticSchema.from_model(User)
175
+
176
+ # Now use like any Marshmallow schema
177
+ schema = UserSchema()
178
+ user = schema.load({"name": "Alice", "email": "alice@example.com", "age": 30})
179
+ # user is a User instance!
180
+
181
+ Supports:
182
+ - `partial=True` or `partial=('field1', 'field2')` for partial loading
183
+ - `unknown=EXCLUDE` or `unknown=INCLUDE` for unknown field handling
184
+ - `only=('field1',)` and `exclude=('field2',)` for field filtering
185
+ - `load_only=('field',)` and `dump_only=('field',)` for directional fields
186
+ - `@validates("field")` decorator for field validators
187
+ - `@validates_schema` decorator for schema validators
188
+ - `validate(data)` method that returns errors dict without raising
189
+ """
190
+
191
    # Validator caches - populated at class creation, not every load().
    # Presumably refilled per subclass by cache_validators() in __init_subclass__
    # (TODO confirm in .validators). As read by _run_field_validators and
    # _run_schema_validators: {field_name: [validator method names]} and
    # [validator method names].
    _field_validators_cache: ClassVar[dict[str, list[str]]] = {}
    _schema_validators_cache: ClassVar[list[str]] = []

    class Meta:
        # Pydantic model backing the schema; None here, so subclasses set it
        # (or use the PydanticSchema[Model] generic form).
        model: type[BaseModel] | None = None
        # Reject unknown input keys by default. Note this is stricter than
        # Pydantic's own default (extra='ignore').
        unknown = RAISE
198
+
199
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Cache validators when the schema class is defined.

        Runs once per subclass, so validator discovery is not repeated on every
        load(). Field setup is handled by PydanticSchemaMeta, which ensures fields
        are visible to Marshmallow BEFORE it processes Meta.fields/exclude.

        Args:
            **kwargs: Forwarded to the parent ``__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)

        # Cache validators at class creation (PERFORMANCE OPTIMIZATION).
        # Presumably fills _field_validators_cache / _schema_validators_cache
        # for this subclass - confirm in .validators.cache_validators.
        cache_validators(cls)
210
+
211
+ def __init__(
212
+ self,
213
+ *,
214
+ only: Sequence[str] | None = None,
215
+ exclude: Sequence[str] = (),
216
+ context: dict[str, Any] | None = None,
217
+ load_only: Sequence[str] = (),
218
+ dump_only: Sequence[str] = (),
219
+ partial: bool | Sequence[str] | AbstractSet[str] | None = None,
220
+ unknown: str | None = None,
221
+ many: bool | None = None,
222
+ **kwargs: Any,
223
+ ) -> None:
224
+ # Store filtering options BEFORE calling super().__init__
225
+ self._only_fields: set[str] | None = set(only) if only else None
226
+ self._exclude_fields: set[str] = set(exclude) if exclude else set()
227
+ self._load_only_fields: set[str] = set(load_only) if load_only else set()
228
+ self._dump_only_fields: set[str] = set(dump_only) if dump_only else set()
229
+ self._partial: bool | Sequence[str] | AbstractSet[str] | None = partial
230
+ self._unknown_override: str | None = unknown
231
+ self._context = context or {}
232
+
233
+ # Pass all known kwargs to parent including context
234
+ super().__init__(
235
+ only=only,
236
+ exclude=exclude,
237
+ context=context,
238
+ many=many,
239
+ load_only=load_only,
240
+ dump_only=dump_only,
241
+ partial=partial,
242
+ unknown=unknown,
243
+ **kwargs,
244
+ )
245
+ self._model_class = self._get_model_class()
246
+ if self._model_class:
247
+ self._setup_fields_from_model()
248
+
249
+ # Call on_bind_field for each field
250
+ for field_name, field_obj in self.fields.items():
251
+ self.on_bind_field(field_name, field_obj)
252
+
253
    def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
        """
        Hook called when a field is bound to the schema.

        Override this to customize field binding behavior. It is invoked during
        schema initialization, including for fields that
        ``_setup_fields_from_model`` adds from the Pydantic model. It mirrors
        Marshmallow's ``on_bind_field`` hook.

        NOTE(review): Marshmallow's own field binding may also call this hook,
        so an override can see the same field more than once. Keep overrides
        idempotent.

        Args:
            field_name: Name of the field in this schema.
            field_obj: The Marshmallow field instance; may be mutated in place.

        Example:
            class MySchema(PydanticSchema[MyModel]):
                class Meta:
                    model = MyModel

                def on_bind_field(self, field_name, field_obj):
                    # Make all fields allow None
                    field_obj.allow_none = True
                    super().on_bind_field(field_name, field_obj)
        """
        # Default implementation does nothing
272
+
273
    def handle_error(
        self,
        error: MarshmallowValidationError,
        data: Any,
        *,
        many: bool,
        **kwargs: Any,
    ) -> None:
        """
        Custom error handler hook, compatible with Marshmallow.

        Override this method to customize error handling behavior. The loader
        (_do_load) calls it when validation errors occur.

        By default, re-raises the error. Overrides may log it or transform it
        (e.g. raise a different exception type). Returning normally does not
        reliably suppress the error: the loader re-raises Pydantic validation
        errors after this hook returns.

        Args:
            error: The validation error that occurred.
            data: The raw input that was being loaded.
            many: Whether the error concerns a ``many`` call; _do_load passes
                ``False`` for per-item errors.
            **kwargs: Additional keyword arguments; ignored by this default.

        Raises:
            MarshmallowValidationError: Always, in this default implementation.

        Example:
            class MySchema(PydanticSchema[MyModel]):
                class Meta:
                    model = MyModel

                def handle_error(self, error, data, *, many, **kwargs):
                    # Log the error
                    logger.error(f"Validation failed: {error.messages}")
                    # Re-raise (required to propagate the error)
                    raise error
        """
        raise error
302
+
303
+ @property
304
+ def context(self) -> dict[str, Any]:
305
+ """Get the validation context."""
306
+ return self._context
307
+
308
+ @context.setter
309
+ def context(self, value: dict[str, Any]) -> None:
310
+ """Set the validation context."""
311
+ self._context = value
312
+
313
+ def _get_model_class(self) -> type[BaseModel] | None:
314
+ """Get the Pydantic model class from Meta or generic parameter."""
315
+ # Try Meta.model first
316
+ if hasattr(self, "Meta") and hasattr(self.Meta, "model") and self.Meta.model:
317
+ return self.Meta.model
318
+
319
+ # Try to get from generic parameter
320
+ self_type = type(self)
321
+ orig_bases = getattr(self_type, "__orig_bases__", ())
322
+ for base in orig_bases:
323
+ origin = get_origin(base)
324
+ if origin is PydanticSchema:
325
+ args = get_args(base)
326
+ if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
327
+ return args[0]
328
+
329
+ return None
330
+
331
+ def _setup_fields_from_model(self) -> None:
332
+ """
333
+ Set up Marshmallow fields from Pydantic model for serialization.
334
+
335
+ NOTE: For schemas created via from_model/schema_for, fields are already
336
+ set up in the class dict. This method only adds fields that are missing
337
+ and does NOT override load_fields/dump_fields which Marshmallow's __init__
338
+ has already filtered based on only/exclude/load_only/dump_only.
339
+
340
+ Also respects Meta.fields (whitelist) and Meta.exclude (blacklist).
341
+ If declared_fields is not empty, we use it as the source of truth for
342
+ which fields should exist (i.e., fields were filtered at class creation).
343
+ """
344
+ if not self._model_class:
345
+ return
346
+
347
+ # Get Meta.fields and Meta.exclude for filtering
348
+ meta_fields = getattr(self.Meta, 'fields', None) if hasattr(self, 'Meta') else None
349
+ meta_exclude = getattr(self.Meta, 'exclude', None) if hasattr(self, 'Meta') else None
350
+
351
+ # Combine filtering: respect both only= param and Meta.fields
352
+ allowed_fields: set[str] | None = None
353
+ if self._only_fields is not None:
354
+ allowed_fields = self._only_fields
355
+ elif meta_fields is not None:
356
+ allowed_fields = set(meta_fields)
357
+ elif self.declared_fields:
358
+ # If declared_fields is set (e.g., from from_model()), use it as whitelist
359
+ # This means fields were already filtered at class creation time
360
+ allowed_fields = set(self.declared_fields.keys())
361
+
362
+ # Combine exclusion: respect both exclude= param and Meta.exclude
363
+ excluded_fields = set(self._exclude_fields)
364
+ if meta_exclude:
365
+ excluded_fields.update(meta_exclude)
366
+
367
+ for field_name, field_info in self._model_class.model_fields.items():
368
+ # Skip if field is excluded
369
+ if field_name in excluded_fields:
370
+ continue
371
+ # Skip if whitelist is set and field not in it
372
+ if allowed_fields is not None and field_name not in allowed_fields:
373
+ continue
374
+
375
+ # Only add to self.fields if not already present
376
+ # Do NOT touch load_fields/dump_fields - Marshmallow manages those
377
+ if field_name not in self.fields:
378
+ # Use centralized field conversion
379
+ ma_field = convert_pydantic_field(field_name, field_info)
380
+ self.fields[field_name] = ma_field
381
+ # Only add to load_fields/dump_fields if filter allows
382
+ if field_name not in self._dump_only_fields:
383
+ self.load_fields[field_name] = ma_field
384
+ if field_name not in self._load_only_fields:
385
+ self.dump_fields[field_name] = ma_field
386
+
387
+ def _validate_with_pydantic(
388
+ self,
389
+ data: dict[str, Any],
390
+ partial: bool | Sequence[str] | AbstractSet[str] | None = None,
391
+ original_data: Any | None = None,
392
+ ) -> tuple[dict[str, Any], M | None]:
393
+ """
394
+ Use Pydantic to validate and coerce the input data.
395
+
396
+ Filters out marshmallow.missing sentinel values before Pydantic validation,
397
+ allowing Pydantic to use its own defaults for missing fields.
398
+
399
+ Returns:
400
+ Tuple of (validated_data_dict, model_instance)
401
+ The instance is returned to avoid redundant validation later.
402
+ """
403
+ if not self._model_class:
404
+ return data, None
405
+
406
+ # Filter out marshmallow.missing values - Pydantic should use its defaults
407
+ from marshmallow.utils import missing as ma_missing
408
+
409
+ clean_data = {
410
+ k: v for k, v in data.items()
411
+ if v is not ma_missing
412
+ }
413
+
414
+ try:
415
+ # Handle partial loading - temporarily make fields optional
416
+ if partial:
417
+ # Create a partial model dynamically
418
+ validated_data = self._validate_partial(clean_data, partial, original_data)
419
+ return validated_data, None # Partial returns dict, no instance
420
+ else:
421
+ # Let Pydantic do all the validation - KEEP THE INSTANCE
422
+ instance = self._model_class.model_validate(clean_data)
423
+ # Return both the dict (for validators) and instance (for result)
424
+ validated_data = instance.model_dump(by_alias=False)
425
+ # Cast to M since model_validate returns the correct model type
426
+ return validated_data, cast(M, instance)
427
+ except PydanticValidationError as e:
428
+ # Use centralized error conversion
429
+ raise convert_pydantic_errors(e, self._model_class, original_data or data) from e
430
+
431
+ def _validate_partial(
432
+ self,
433
+ data: dict[str, Any],
434
+ partial: bool | Sequence[str] | AbstractSet[str],
435
+ original_data: Any | None = None,
436
+ ) -> dict[str, Any]:
437
+ """Validate data with partial loading - missing required fields allowed."""
438
+ if not self._model_class:
439
+ return data
440
+
441
+ # If partial is True, all fields are optional
442
+ # If partial is a tuple/list, only those fields are optional
443
+ partial_fields: set[str] = set()
444
+ if partial is True:
445
+ partial_fields = set(self._model_class.model_fields.keys())
446
+ elif isinstance(partial, (list, tuple)):
447
+ partial_fields = set(partial)
448
+
449
+ # Check for required but missing fields (not in partial list)
450
+ errors: dict[str, Any] = {}
451
+ for field_name, field_info in self._model_class.model_fields.items():
452
+ if field_name not in data and field_name not in partial_fields:
453
+ # Check if field has a default
454
+ from pydantic_core import PydanticUndefined
455
+ if field_info.default is PydanticUndefined and field_info.default_factory is None:
456
+ errors[field_name] = ["Missing data for required field."]
457
+
458
+ if errors:
459
+ # Include valid_data even on partial validation errors
460
+ valid_data = {
461
+ k: v for k, v in data.items()
462
+ if k not in errors and k in self._model_class.model_fields
463
+ }
464
+ raise BridgeValidationError(
465
+ errors,
466
+ data=original_data or data,
467
+ valid_data=valid_data,
468
+ )
469
+
470
+ # For validation, we need to provide defaults for missing fields
471
+ # Create a data dict with defaults for unprovided fields
472
+ validation_data = {}
473
+ for field_name, field_info in self._model_class.model_fields.items():
474
+ if field_name in data:
475
+ validation_data[field_name] = data[field_name]
476
+ else:
477
+ # Use default if available
478
+ from pydantic_core import PydanticUndefined
479
+ if field_info.default is not PydanticUndefined:
480
+ validation_data[field_name] = field_info.default
481
+ elif field_info.default_factory is not None:
482
+ # Cast to satisfy type checker - Pydantic's factory takes no args
483
+ factory = cast(Callable[[], Any], field_info.default_factory)
484
+ validation_data[field_name] = factory()
485
+ # else: field is in partial_fields, we'll skip validation for it
486
+
487
+ # Validate provided fields by doing full validation with defaults filled in
488
+ # Only validate if we have all required fields covered
489
+ try:
490
+ instance = self._model_class.model_validate(validation_data)
491
+ # Return only the originally provided fields and their validated values
492
+ result = {}
493
+ for field_name in data:
494
+ if field_name in self._model_class.model_fields:
495
+ result[field_name] = getattr(instance, field_name)
496
+ return result
497
+ except PydanticValidationError as e:
498
+ # Convert to Marshmallow errors, only for provided fields
499
+ errors = {}
500
+ failed_fields: set[str] = set()
501
+ for error in e.errors():
502
+ loc = error.get("loc", ())
503
+ if loc:
504
+ field_name = str(loc[0])
505
+ failed_fields.add(field_name)
506
+ # Only report errors for fields that were actually provided
507
+ if field_name in data:
508
+ if field_name not in errors:
509
+ errors[field_name] = []
510
+ errors[field_name].append(format_pydantic_error(error, self._model_class))
511
+
512
+ if errors:
513
+ valid_data = {
514
+ k: v for k, v in data.items()
515
+ if k not in failed_fields and k in self._model_class.model_fields
516
+ }
517
+ raise BridgeValidationError(
518
+ errors,
519
+ data=original_data or data,
520
+ valid_data=valid_data,
521
+ ) from e
522
+
523
+ # No errors for provided fields - return the validated values for provided fields
524
+ return {k: v for k, v in data.items() if k in self._model_class.model_fields}
525
+
526
+ def _do_load(
527
+ self,
528
+ data: Any,
529
+ *,
530
+ many: bool | None = None,
531
+ partial: bool | Sequence[str] | AbstractSet[str] | None = None,
532
+ unknown: str | None = None,
533
+ postprocess: bool = True,
534
+ return_instance: bool = True,
535
+ ) -> Any:
536
+ """
537
+ Override Marshmallow's _do_load to ensure proper hook ordering:
538
+
539
+ 1. User's @pre_load hooks run FIRST (transform input)
540
+ 2. Field filtering (only/exclude) applied
541
+ 3. Pydantic validates the TRANSFORMED data
542
+ 4. @validates("field") decorators run
543
+ 5. @validates_schema decorators run
544
+ 6. User's @post_load hooks run LAST
545
+
546
+ This ensures 100% Marshmallow hook compatibility.
547
+ """
548
+ # Resolve settings
549
+ if many is None:
550
+ many = self.many
551
+
552
+ # Resolve partial - instance attribute takes precedence
553
+ if partial is None:
554
+ partial = self._partial
555
+
556
+ # Resolve unknown setting
557
+ unknown_setting = unknown if unknown is not None else self._unknown_override
558
+ if unknown_setting is None:
559
+ unknown_setting = getattr(self.Meta, "unknown", RAISE)
560
+
561
+ # Handle many=True
562
+ if many:
563
+ if not isinstance(data, list):
564
+ raise MarshmallowValidationError({"_schema": ["Expected a list."]})
565
+ return [
566
+ self._do_load(
567
+ item,
568
+ many=False,
569
+ partial=partial,
570
+ unknown=unknown_setting,
571
+ postprocess=postprocess,
572
+ return_instance=return_instance,
573
+ )
574
+ for item in data
575
+ ]
576
+
577
+ # Step 1: Run pre_load hooks ONLY if they exist (PERFORMANCE OPTIMIZATION)
578
+ # Skipping _invoke_load_processors when empty saves ~5ms per 10k loads
579
+ if self._hooks.get("pre_load"):
580
+ processed_data = self._invoke_load_processors(
581
+ "pre_load",
582
+ data,
583
+ many=False,
584
+ original_data=data,
585
+ partial=partial,
586
+ )
587
+ else:
588
+ processed_data = data
589
+
590
+ # Step 2: Handle unknown fields based on setting
591
+ if self._model_class:
592
+ model_fields = set(self._model_class.model_fields.keys())
593
+ # Also include aliases in known fields
594
+ for _field_name, field_info in self._model_class.model_fields.items():
595
+ if field_info.alias:
596
+ model_fields.add(field_info.alias)
597
+
598
+ unkn_fields = set(processed_data.keys()) - model_fields
599
+
600
+ if unkn_fields:
601
+ if unknown_setting == RAISE:
602
+ errors = {field: ["Unknown field."] for field in unkn_fields}
603
+ raise MarshmallowValidationError(errors)
604
+ if unknown_setting == EXCLUDE:
605
+ # Remove unknown fields
606
+ processed_data = {
607
+ k: v for k, v in processed_data.items() if k in model_fields
608
+ }
609
+ # INCLUDE: keep unknown fields in the result (handled below)
610
+
611
+ # Step 3: Pydantic validates the transformed data
612
+ # Returns (validated_dict, instance) - instance reused to avoid double validation
613
+ pydantic_instance = None
614
+ if self._model_class:
615
+ try:
616
+ validated_data, pydantic_instance = self._validate_with_pydantic(
617
+ processed_data, partial=partial, original_data=data
618
+ )
619
+ except MarshmallowValidationError as pydantic_error:
620
+ # Call handle_error for Pydantic validation errors
621
+ self.handle_error(pydantic_error, data, many=False)
622
+ # handle_error should re-raise; if it doesn't, we do
623
+ raise
624
+
625
+ # If INCLUDE, add unknown fields back to validated data
626
+ if unknown_setting == INCLUDE and self._model_class:
627
+ for field in (set(processed_data.keys()) - model_fields):
628
+ validated_data[field] = processed_data[field]
629
+ else:
630
+ validated_data = processed_data
631
+
632
+ # Step 4: Run field validators (BOTH Marshmallow native AND our custom)
633
+ # This ensures validators work regardless of import source
634
+ # ErrorStore is marshmallow internal without type hints - cast constructor for type safety
635
+ error_store_cls: Callable[[], Any] = cast(Callable[[], Any], ErrorStore)
636
+ error_store: Any = error_store_cls()
637
+
638
+ # 4a: Run Marshmallow's native @validates decorators (from _hooks)
639
+ if self._hooks[VALIDATES]:
640
+ self._invoke_field_validators(
641
+ error_store=error_store,
642
+ data=validated_data,
643
+ many=False,
644
+ )
645
+
646
+ # 4b: Run our custom @validates decorators (backwards compatibility)
647
+ try:
648
+ self._run_field_validators(validated_data)
649
+ except MarshmallowValidationError as field_error:
650
+ # Merge into error_store
651
+ if isinstance(field_error.messages, dict):
652
+ for key, msgs in field_error.messages.items():
653
+ error_store.store_error({key: msgs if isinstance(msgs, list) else [msgs]})
654
+
655
+ has_field_errors = bool(error_store.errors)
656
+
657
+ # Step 5: Run schema validators (BOTH Marshmallow native AND our custom)
658
+ # 5a: Run Marshmallow's native @validates_schema decorators
659
+ if self._hooks[VALIDATES_SCHEMA]:
660
+ self._invoke_schema_validators(
661
+ error_store=error_store,
662
+ pass_many=True,
663
+ data=validated_data,
664
+ original_data=data,
665
+ many=False,
666
+ partial=partial,
667
+ field_errors=has_field_errors,
668
+ )
669
+ self._invoke_schema_validators(
670
+ error_store=error_store,
671
+ pass_many=False,
672
+ data=validated_data,
673
+ original_data=data,
674
+ many=False,
675
+ partial=partial,
676
+ field_errors=has_field_errors,
677
+ )
678
+
679
+ # 5b: Run our custom @validates_schema decorators (backwards compatibility)
680
+ try:
681
+ self._run_schema_validators(validated_data, has_field_errors=has_field_errors)
682
+ except MarshmallowValidationError as schema_error:
683
+ if isinstance(schema_error.messages, dict):
684
+ for key, msgs in schema_error.messages.items():
685
+ error_store.store_error({key: msgs if isinstance(msgs, list) else [msgs]})
686
+
687
+ # Raise combined errors if any
688
+ if error_store.errors:
689
+ error = MarshmallowValidationError(dict(error_store.errors))
690
+ self.handle_error(error, data, many=False)
691
+
692
+ # Step 6: Prepare result based on return_instance flag
693
+ if self._model_class and return_instance:
694
+ if not partial:
695
+ # OPTIMIZATION: Reuse the instance from _validate_with_pydantic
696
+ result = pydantic_instance if pydantic_instance is not None else validated_data
697
+ else:
698
+ # For partial loading, create model with provided fields set
699
+ # Fill in defaults for unprovided fields to avoid AttributeError
700
+ from pydantic_core import PydanticUndefined
701
+ construct_data = {}
702
+ fields_set = set()
703
+
704
+ for field_name, field_info in self._model_class.model_fields.items():
705
+ if field_name in validated_data:
706
+ construct_data[field_name] = validated_data[field_name]
707
+ fields_set.add(field_name)
708
+ else:
709
+ # Use default for unprovided fields
710
+ if field_info.default is not PydanticUndefined:
711
+ construct_data[field_name] = field_info.default
712
+ elif field_info.default_factory is not None:
713
+ # Cast to satisfy type checker
714
+ factory = cast(Callable[[], Any], field_info.default_factory)
715
+ construct_data[field_name] = factory()
716
+ else:
717
+ # No default - leave as None to avoid issues
718
+ construct_data[field_name] = None
719
+
720
+ result = cast(
721
+ M, self._model_class.model_construct(_fields_set=fields_set, **construct_data)
722
+ )
723
+ else:
724
+ # Return dict instead of instance
725
+ result = validated_data
726
+
727
+ # Step 7: Run post_load hooks ONLY if they exist (PERFORMANCE OPTIMIZATION)
728
+ if postprocess and self._hooks.get("post_load"):
729
+ result = self._invoke_load_processors(
730
+ "post_load",
731
+ result,
732
+ many=False,
733
+ original_data=data,
734
+ partial=partial,
735
+ )
736
+
737
+ return result
738
+
739
+ def _run_field_validators(self, data: dict[str, Any]) -> None:
740
+ """Run @validates("field") decorated methods."""
741
+ errors: dict[str, list[str]] = {}
742
+
743
+ # Check cached validators (built at class creation)
744
+ # Cache structure: {field_name: [validator_method_names]}
745
+ for field_name, validator_names in self._field_validators_cache.items():
746
+ if field_name not in data:
747
+ continue
748
+ for attr_name in validator_names:
749
+ attr = getattr(self, attr_name, None)
750
+ if callable(attr) and hasattr(attr, "_validates_field"):
751
+ try:
752
+ attr(data[field_name])
753
+ except MarshmallowValidationError as e:
754
+ if field_name not in errors:
755
+ errors[field_name] = []
756
+ if isinstance(e.messages, dict):
757
+ errors[field_name].extend(e.messages.get(field_name, [str(e)]))
758
+ else:
759
+ # e.messages is list or set-like
760
+ errors[field_name].extend(e.messages)
761
+
762
+ if errors:
763
+ raise MarshmallowValidationError(errors)
764
+
765
+ def _run_schema_validators(self, data: dict[str, Any], has_field_errors: bool = False) -> None:
766
+ """Run @validates_schema decorated methods."""
767
+ errors: dict[str, list[str]] = {}
768
+
769
+ # Check cached validators (built at class creation)
770
+ for attr_name in self._schema_validators_cache:
771
+ attr = getattr(self, attr_name, None)
772
+ if callable(attr) and hasattr(attr, "_validates_schema"):
773
+ # Check skip_on_field_errors
774
+ skip_on_errors = getattr(attr, "_skip_on_field_errors", True)
775
+ if skip_on_errors and has_field_errors:
776
+ continue
777
+
778
+ try:
779
+ attr(data)
780
+ except MarshmallowValidationError as e:
781
+ if "_schema" not in errors:
782
+ errors["_schema"] = []
783
+ if isinstance(e.messages, dict):
784
+ for key, msgs in e.messages.items():
785
+ if key not in errors:
786
+ errors[key] = []
787
+ if isinstance(msgs, (list, set, frozenset)):
788
+ errors[key].extend(msgs)
789
+ else:
790
+ errors[key].append(str(msgs))
791
+ else:
792
+ # e.messages is list or set-like
793
+ errors["_schema"].extend(e.messages)
794
+
795
+ if errors:
796
+ raise MarshmallowValidationError(errors)
797
+
798
+ def validate(
799
+ self,
800
+ data: Any,
801
+ *,
802
+ many: bool | None = None,
803
+ partial: bool | Sequence[str] | AbstractSet[str] | None = None,
804
+ ) -> dict[str, Any]:
805
+ """
806
+ Validate data without raising an exception.
807
+
808
+ Returns a dict of errors (empty dict if valid).
809
+
810
+ Example:
811
+ errors = schema.validate({"name": "", "email": "invalid"})
812
+ if errors:
813
+ print(errors) # {'name': ['...'], 'email': ['...']}
814
+ """
815
+ try:
816
+ self.load(data, many=many, partial=partial)
817
+ return {}
818
+ except MarshmallowValidationError as e:
819
+ return e.messages if isinstance(e.messages, dict) else {"_schema": e.messages}
820
+
821
def load(
    self,
    data: Any,
    *,
    many: bool | None = None,
    partial: bool | Sequence[str] | AbstractSet[str] | None = None,
    unknown: str | None = None,
    return_instance: bool = True,
) -> Any:
    """
    Deserialize ``data`` into a Pydantic model instance (or a plain dict).

    All work is delegated to ``self._do_load`` with ``postprocess=True``.

    Args:
        data: Input data to deserialize.
        many: If True, expect a list of objects.
        partial: True, or a collection of field names, to allow missing
            required fields.
        unknown: How to handle unknown fields (RAISE, EXCLUDE, INCLUDE).
        return_instance: If True (default), return the Pydantic model instance.
            If False, return a dict.

    Returns:
        The Pydantic model instance when ``return_instance`` is True, otherwise a dict.

    Example:
        user = schema.load(data)                                # User instance
        user_dict = schema.load(data, return_instance=False)    # dict
    """
    load_options: dict[str, Any] = {
        "many": many,
        "partial": partial,
        "unknown": unknown,
        "postprocess": True,
        "return_instance": return_instance,
    }
    return self._do_load(data, **load_options)
859
+
860
def dump(
    self,
    obj: Any,
    *,
    many: bool | None = None,
    include_computed: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> Any:
    """
    Serialize a Pydantic model instance or a dict, or a list of them.

    Each object is handed to ``self._dump_single`` with the same exclusion
    options. Pydantic ``@computed_field`` values are included by default.

    Args:
        obj: The object to serialize, or an iterable of objects when ``many``.
        many: If True, ``obj`` is iterated and each item is dumped. When None,
            ``self.many`` decides.
        include_computed: Include ``@computed_field`` values (default True).
        exclude_unset: Exclude fields that were not explicitly set.
        exclude_defaults: Exclude fields equal to their default value.
        exclude_none: Exclude fields whose value is None.

    Returns:
        A serialized dict, or a list of dicts when dumping many.

    Example:
        class User(BaseModel):
            first: str
            last: str
            nickname: str | None = None

            @computed_field
            @property
            def full_name(self) -> str:
                return f"{self.first} {self.last}"

        user = User(first="Alice", last="Smith")
        schema.dump(user)                      # includes 'full_name'; 'nickname' is None
        schema.dump(user, exclude_none=True)   # 'nickname' omitted
        schema.dump(user, exclude_unset=True)  # 'nickname' omitted (never set)
    """
    dump_options: dict[str, bool] = {
        "include_computed": include_computed,
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
        "exclude_none": exclude_none,
    }
    # Fall back to the instance-level setting when the caller did not choose.
    is_collection = self.many if many is None else many
    if not is_collection:
        return self._dump_single(obj, **dump_options)
    return [self._dump_single(item, **dump_options) for item in obj]
930
+
931
def _dump_single(
    self,
    obj: Any,
    *,
    include_computed: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> dict[str, Any]:
    """
    Dump one object, applying Pydantic-style exclusion rules and computed fields.

    For a Pydantic model instance:
    - the exclusion flags are evaluated against the model's fields and
      computed fields;
    - the instance is converted with ``model_dump(by_alias=False)``;
    - the resulting dict goes through Marshmallow's ``Schema.dump``.

    Any other object (e.g. a dict) is passed to Marshmallow unchanged, and the
    exclusion flags have no effect on it.

    Args:
        obj: A Pydantic model instance or a dict.
        include_computed: Include ``@computed_field`` values. When False they
            are removed from the output.
        exclude_unset: Drop fields that are not in ``model_fields_set``.
        exclude_defaults: Drop fields whose value equals the field default, or
            the value produced by its ``default_factory``.
        exclude_none: Drop fields, including computed fields, whose value is None.

    Returns:
        The serialized dict.
    """
    computed_values: dict[str, Any] = {}
    # Attribute (Python) names to drop. They are translated to output keys below.
    fields_to_exclude: set[str] = set()

    if isinstance(obj, BaseModel):
        model_class = type(obj)
        model_fields = model_class.model_fields

        if exclude_unset:
            # Fields absent from model_fields_set were never explicitly provided.
            fields_to_exclude.update(
                name for name in model_fields if name not in obj.model_fields_set
            )

        if exclude_defaults:
            from pydantic_core import PydanticUndefined

            for name, info in model_fields.items():
                if info.default is not PydanticUndefined:
                    default_value = info.default
                elif info.default_factory is not None:
                    factory = cast(Callable[..., Any], info.default_factory)
                    # Pydantic >= 2.10 allows factories that take the validated
                    # data as a dict. Calling those with no arguments raises
                    # TypeError, so pass the instance's field values instead.
                    if getattr(info, "default_factory_takes_data", False):
                        default_value = factory(
                            {key: getattr(obj, key) for key in model_fields}
                        )
                    else:
                        default_value = factory()
                else:
                    continue  # required field: it has no default to compare against
                if getattr(obj, name) == default_value:
                    fields_to_exclude.add(name)

        if exclude_none:
            fields_to_exclude.update(
                name for name in model_fields if getattr(obj, name) is None
            )

        computed_names = getattr(model_class, "model_computed_fields", None) or {}
        if include_computed:
            # Read computed values from the instance BEFORE converting it to a dict.
            for name in computed_names:
                value = getattr(obj, name)
                if exclude_none and value is None:
                    fields_to_exclude.add(name)
                    continue
                computed_values[name] = value
        else:
            # model_dump() also emits computed fields, and from_model() builds
            # schemas with include_computed=True. Drop them explicitly, or they
            # would leak into the output.
            fields_to_exclude.update(computed_names)

        # by_alias=False: Marshmallow applies aliases itself via data_key.
        obj = obj.model_dump(by_alias=False)

    result: dict[str, Any] = super().dump(obj, many=False)

    if fields_to_exclude:
        # Marshmallow writes each field under its data_key when one is set, so
        # translate attribute names to output keys before filtering.
        dump_fields = getattr(self, "dump_fields", None) or {}
        excluded_keys: set[str] = set()
        for name in fields_to_exclude:
            data_key = getattr(dump_fields.get(name), "data_key", None)
            excluded_keys.add(data_key if data_key is not None else name)
        result = {key: value for key, value in result.items() if key not in excluded_keys}

    if computed_values:
        result.update(computed_values)

    return result
1003
+
1004
@classmethod
def from_model(
    cls,
    model: type[M],
    *,
    schema_name: str | None = None,
    **meta_options: Any,
) -> type[PydanticSchema[M]]:
    """
    Build a PydanticSchema subclass for the given Pydantic model.

    ``fields`` (whitelist) and ``exclude`` (blacklist) in ``meta_options`` are
    applied here, by choosing which converted fields go into the class body.
    They are not forwarded to Marshmallow's ``Meta``. Every other option
    becomes an attribute of the generated ``Meta`` class, next to ``model``.

    Example:
        from pydantic import BaseModel

        class User(BaseModel):
            name: str
            email: str

        UserSchema = PydanticSchema.from_model(User)
        user = UserSchema().load({"name": "Alice", "email": "alice@example.com"})

    Args:
        model: The Pydantic model class.
        schema_name: Class name for the schema. Defaults to ``"<Model>Schema"``.
        **meta_options: Additional Meta options (fields, exclude, etc.).

    Returns:
        A new subclass of ``cls``.
    """
    class_name = schema_name or f"{model.__name__}Schema"

    whitelist = meta_options.get("fields")
    blacklist = meta_options.get("exclude", ())

    meta_attrs: dict[str, Any] = {
        "model": model,
        **{
            option: setting
            for option, setting in meta_options.items()
            if option not in ("fields", "exclude")
        },
    }
    Meta = type("Meta", (), meta_attrs)  # noqa: N806 - Class name convention

    converted_fields = convert_model_fields(
        model,
        include=set(whitelist) if whitelist else None,
        exclude=set(blacklist) if blacklist else None,
        include_computed=True,
    )

    return type(class_name, (cls,), {"Meta": Meta, **converted_fields})
1067
+
1068
+
1069
def schema_for(model: type[M], **meta_options: Any) -> type[PydanticSchema[M]]:
    """
    Return a Marshmallow schema class generated from a Pydantic model.

    This is a convenience alias for ``PydanticSchema.from_model``. All keyword
    arguments (including ``schema_name``) are forwarded unchanged.

    Example:
        from pydantic import BaseModel, EmailStr

        class User(BaseModel):
            name: str
            email: EmailStr

        UserSchema = schema_for(User)
        user = UserSchema().load({"name": "Alice", "email": "alice@example.com"})
        print(user.name)  # "Alice" - a User instance
    """
    generated_schema = PydanticSchema.from_model(model, **meta_options)
    return generated_schema
1088
+
1089
+
1090
def pydantic_schema(cls: type[M]) -> type[M]:
    """
    Class decorator that attaches a generated Marshmallow schema as ``cls.Schema``.

    The schema is built with ``PydanticSchema.from_model(cls)``. The decorated
    class itself is returned, so it remains a normal Pydantic model.

    Example:
        from pydantic import BaseModel, EmailStr
        from pydantic_marshmallow import pydantic_schema

        @pydantic_schema
        class User(BaseModel):
            name: str
            email: EmailStr

        user = User.Schema().load({"name": "Alice", "email": "alice@example.com"})
        # user is a User instance

        # webargs:
        @use_args(User.Schema(), location="json")
        def create_user(user): ...

        # apispec:
        spec.components.schema("User", schema=User.Schema)

        # Marshmallow hooks via subclassing:
        class UserSchema(User.Schema):
            @pre_load
            def normalize(self, data, **kwargs):
                data["email"] = data["email"].lower()
                return data
    """
    generated_schema = PydanticSchema.from_model(cls)
    # type[M] does not declare a ``Schema`` attribute, so assign via setattr
    # to keep static type checkers quiet.
    setattr(cls, "Schema", generated_schema)  # noqa: B010
    return cls
1129
+
1130
+
1131
class HybridModel(BaseModel):
    """
    A Pydantic model that can also work as a Marshmallow schema.

    The Marshmallow schema class is generated on first use with
    ``PydanticSchema.from_model`` and cached per model class.

    Example:
        class User(HybridModel):
            name: str
            email: EmailStr
            age: int = Field(ge=0)

        # Use as Pydantic model
        user = User(name="Alice", email="alice@example.com", age=30)

        # Use for Marshmallow-style loading
        user = User.ma_load({"name": "Alice", "email": "alice@example.com", "age": 30})

        # Get the Marshmallow schema
        schema = User.marshmallow_schema()
    """

    model_config = ConfigDict(extra="forbid")  # Match Marshmallow's default strict behavior

    @classmethod
    def marshmallow_schema(cls) -> type[PydanticSchema[Any]]:
        """Return the Marshmallow schema class for this model, building and caching it on first use."""
        schema_cls = _hybrid_schema_cache.get(cls)
        if schema_cls is None:
            schema_cls = PydanticSchema.from_model(cls)
            _hybrid_schema_cache[cls] = schema_cls
        return schema_cls

    @classmethod
    def ma_load(cls, data: dict[str, Any], **kwargs: Any) -> HybridModel:
        """Load ``data`` through the Marshmallow schema. ``kwargs`` are forwarded to ``load``."""
        schema = cls.marshmallow_schema()()
        result = schema.load(data, **kwargs)
        return cast(HybridModel, result)

    @classmethod
    def ma_loads(cls, json_str: str, **kwargs: Any) -> HybridModel:
        """Load a JSON string through the Marshmallow schema. ``kwargs`` are forwarded to ``loads``."""
        schema = cls.marshmallow_schema()()
        result = schema.loads(json_str, **kwargs)
        return cast(HybridModel, result)

    def ma_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump this instance through the Marshmallow schema. ``kwargs`` are forwarded to ``dump``."""
        schema = self.__class__.marshmallow_schema()()
        result = schema.dump(self, **kwargs)
        return cast(dict[str, Any], result)

    def ma_dumps(self, **kwargs: Any) -> str:
        """
        Dump this instance to a JSON string through the Marshmallow schema.

        The dump options accepted by ``ma_dump`` (``include_computed``,
        ``exclude_unset``, ``exclude_defaults``, ``exclude_none``, plus ``many``
        when any of those is given) are applied during serialization. All
        remaining keyword arguments are passed to the schema's render module.
        That module is Marshmallow's ``json`` unless ``Meta.render_module``
        overrides it. Without any dump option, this is exactly ``schema.dumps(self, **kwargs)``.
        """
        schema = self.__class__.marshmallow_schema()()
        # Previously these reached json.dumps via schema.dumps and raised TypeError.
        dump_options = {
            name: kwargs.pop(name)
            for name in ("include_computed", "exclude_unset", "exclude_defaults", "exclude_none")
            if name in kwargs
        }
        if not dump_options:
            return cast(str, schema.dumps(self, **kwargs))
        if "many" in kwargs:
            dump_options["many"] = kwargs.pop("many")
        serialized = schema.dump(self, **dump_options)
        return cast(str, schema.opts.render_module.dumps(serialized, **kwargs))