pyrmute 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,9 @@
1
1
  """Migrations manager."""
2
2
 
3
3
  import contextlib
4
+ import types
4
5
  from collections.abc import Callable
5
- from typing import Any, Self, get_args, get_origin
6
+ from typing import Annotated, Any, Literal, Self, Union, get_args, get_origin
6
7
 
7
8
  from pydantic import BaseModel
8
9
  from pydantic.fields import FieldInfo
@@ -54,16 +55,8 @@ class MigrationManager:
54
55
  ... def migrate_v1_to_v2(data: dict[str, Any]) -> dict[str, Any]:
55
56
  ... return {**data, "email": "unknown@example.com"}
56
57
  """
57
- from_ver = (
58
- ModelVersion.parse(from_version)
59
- if isinstance(from_version, str)
60
- else from_version
61
- )
62
- to_ver = (
63
- ModelVersion.parse(to_version)
64
- if isinstance(to_version, str)
65
- else to_version
66
- )
58
+ from_ver = self._parse_version(from_version)
59
+ to_ver = self._parse_version(to_version)
67
60
 
68
61
  def decorator(func: MigrationFunc) -> MigrationFunc:
69
62
  self.registry._migrations[name][(from_ver, to_ver)] = func
@@ -93,61 +86,14 @@ class MigrationManager:
93
86
  ModelNotFoundError: If model or versions don't exist.
94
87
  MigrationError: If migration path cannot be found.
95
88
  """
96
- from_ver = (
97
- ModelVersion.parse(from_version)
98
- if isinstance(from_version, str)
99
- else from_version
100
- )
101
- to_ver = (
102
- ModelVersion.parse(to_version)
103
- if isinstance(to_version, str)
104
- else to_version
105
- )
89
+ from_ver = self._parse_version(from_version)
90
+ to_ver = self._parse_version(to_version)
106
91
 
107
92
  if from_ver == to_ver:
108
93
  return data
109
94
 
110
95
  path = self.find_migration_path(name, from_ver, to_ver)
111
-
112
- current_data = data
113
- for i in range(len(path) - 1):
114
- migration_key = (path[i], path[i + 1])
115
-
116
- if migration_key in self.registry._migrations[name]:
117
- migration_func = self.registry._migrations[name][migration_key]
118
- try:
119
- current_data = migration_func(current_data)
120
- except Exception as e:
121
- raise MigrationError(
122
- name,
123
- str(path[i]),
124
- str(path[i + 1]),
125
- f"Migration function raised: {type(e).__name__}: {e}",
126
- ) from e
127
- elif path[i + 1] in self.registry._backward_compatible_enabled[name]:
128
- try:
129
- current_data = self._auto_migrate(
130
- current_data, name, path[i], path[i + 1]
131
- )
132
- except Exception as e:
133
- raise MigrationError(
134
- name,
135
- str(path[i]),
136
- str(path[i + 1]),
137
- f"Auto-migration failed: {type(e).__name__}: {e}",
138
- ) from e
139
- else:
140
- raise MigrationError(
141
- name,
142
- str(path[i]),
143
- str(path[i + 1]),
144
- (
145
- "No migration path found. Define a migration function or mark "
146
- "the target version as backward_compatible."
147
- ),
148
- )
149
-
150
- return current_data
96
+ return self._execute_migration_path(data, name, path)
151
97
 
152
98
  def find_migration_path(
153
99
  self: Self,
@@ -168,6 +114,9 @@ class MigrationManager:
168
114
  Raises:
169
115
  ModelNotFoundError: If the model or versions don't exist.
170
116
  """
117
+ if (from_ver, to_ver) in self.registry._migrations.get(name, {}):
118
+ return [from_ver, to_ver]
119
+
171
120
  versions = sorted(self.registry.get_versions(name))
172
121
 
173
122
  if from_ver not in versions:
@@ -204,14 +153,8 @@ class MigrationManager:
204
153
  for i in range(len(path) - 1):
205
154
  current_ver = path[i]
206
155
  next_ver = path[i + 1]
207
- migration_key = (current_ver, next_ver)
208
-
209
- has_explicit = migration_key in self.registry._migrations.get(name, {})
210
- has_auto = next_ver in self.registry._backward_compatible_enabled.get(
211
- name, set()
212
- )
213
156
 
214
- if not has_explicit and not has_auto:
157
+ if not self._has_migration_step(name, current_ver, next_ver):
215
158
  raise MigrationError(
216
159
  name,
217
160
  str(current_ver),
@@ -222,6 +165,85 @@ class MigrationManager:
222
165
  ),
223
166
  )
224
167
 
168
+ def _parse_version(self: Self, version: str | ModelVersion) -> ModelVersion:
169
+ """Parse version string or return ModelVersion as-is."""
170
+ return ModelVersion.parse(version) if isinstance(version, str) else version
171
+
172
+ def _has_migration_step(
173
+ self: Self, name: ModelName, from_ver: ModelVersion, to_ver: ModelVersion
174
+ ) -> bool:
175
+ """Check if a migration step exists (explicit or auto)."""
176
+ migration_key = (from_ver, to_ver)
177
+ has_explicit = migration_key in self.registry._migrations.get(name, {})
178
+ has_auto = to_ver in self.registry._backward_compatible_enabled.get(name, set())
179
+ return has_explicit or has_auto
180
+
181
+ def _execute_migration_path(
182
+ self: Self, data: ModelData, name: ModelName, path: list[ModelVersion]
183
+ ) -> ModelData:
184
+ """Execute migration through a path of versions."""
185
+ current_data = data
186
+
187
+ for i in range(len(path) - 1):
188
+ try:
189
+ current_data = self._execute_single_step(
190
+ current_data, name, path[i], path[i + 1]
191
+ )
192
+ except Exception as e:
193
+ if isinstance(e, MigrationError):
194
+ raise
195
+ raise MigrationError(
196
+ name,
197
+ str(path[i]),
198
+ str(path[i + 1]),
199
+ f"Migration failed: {type(e).__name__}: {e}",
200
+ ) from e
201
+
202
+ return current_data
203
+
204
+ def _execute_single_step(
205
+ self: Self,
206
+ data: ModelData,
207
+ name: ModelName,
208
+ from_ver: ModelVersion,
209
+ to_ver: ModelVersion,
210
+ ) -> ModelData:
211
+ """Execute a single migration step."""
212
+ migration_key = (from_ver, to_ver)
213
+
214
+ if migration_key in self.registry._migrations[name]:
215
+ migration_func = self.registry._migrations[name][migration_key]
216
+ try:
217
+ return migration_func(data)
218
+ except Exception as e:
219
+ raise MigrationError(
220
+ name,
221
+ str(from_ver),
222
+ str(to_ver),
223
+ f"Migration function raised: {type(e).__name__}: {e}",
224
+ ) from e
225
+
226
+ if to_ver in self.registry._backward_compatible_enabled[name]:
227
+ try:
228
+ return self._auto_migrate(data, name, from_ver, to_ver)
229
+ except Exception as e:
230
+ raise MigrationError(
231
+ name,
232
+ str(from_ver),
233
+ str(to_ver),
234
+ f"Auto-migration failed: {type(e).__name__}: {e}",
235
+ ) from e
236
+
237
+ raise MigrationError(
238
+ name,
239
+ str(from_ver),
240
+ str(to_ver),
241
+ (
242
+ "No migration path found. Define a migration function or mark "
243
+ "the target version as backward_compatible."
244
+ ),
245
+ )
246
+
225
247
  def _auto_migrate(
226
248
  self: Self,
227
249
  data: ModelData,
@@ -232,7 +254,7 @@ class MigrationManager:
232
254
  """Automatically migrate data when no explicit migration exists.
233
255
 
234
256
  This method handles nested Pydantic models recursively, migrating them to their
235
- corresponding versions.
257
+ corresponding versions. Handles field aliases by building a lookup map.
236
258
 
237
259
  Args:
238
260
  data: Data dictionary to migrate.
@@ -249,36 +271,142 @@ class MigrationManager:
249
271
  from_fields = from_model.model_fields
250
272
  to_fields = to_model.model_fields
251
273
 
274
+ key_to_field_name = self._build_alias_map(from_fields)
275
+
252
276
  result: ModelData = {}
277
+ processed_keys: set[str] = set()
253
278
 
254
279
  for field_name, to_field_info in to_fields.items():
255
- # Field exists in data, migrate it
256
- if field_name in data:
257
- value = data[field_name]
258
- from_field_info = from_fields.get(field_name)
259
- result[field_name] = self._migrate_field_value(
260
- value, from_field_info, to_field_info
261
- )
280
+ field_result = self._migrate_single_field(
281
+ data, field_name, to_field_info, from_fields, key_to_field_name
282
+ )
262
283
 
263
- # Field missing from data, use default if available
264
- elif to_field_info.default is not PydanticUndefined:
265
- result[field_name] = to_field_info.default
266
- elif to_field_info.default_factory is not None:
267
- with contextlib.suppress(Exception):
268
- result[field_name] = to_field_info.default_factory() # type: ignore
284
+ if field_result is not None:
285
+ output_key, migrated_value, source_keys = field_result
286
+ result[output_key] = migrated_value
287
+ processed_keys.update(source_keys)
269
288
 
270
- # Migrate all extra data not in the field, too
271
- for field_name, value in data.items():
272
- if field_name not in to_fields:
273
- result[field_name] = value
289
+ # Preserve unprocessed extra fields
290
+ for data_key, value in data.items():
291
+ if data_key not in processed_keys:
292
+ result[data_key] = value
274
293
 
275
294
  return result
276
295
 
277
- def _migrate_field_value(
296
+ def _build_alias_map(self: Self, fields: dict[str, FieldInfo]) -> dict[str, str]:
297
+ """Build a mapping from all possible input keys to canonical field names.
298
+
299
+ Args:
300
+ fields: Model fields to extract aliases from.
301
+
302
+ Returns:
303
+ Dictionary mapping data keys (field names and aliases) to field names.
304
+ """
305
+ key_to_field_name: dict[str, str] = {}
306
+
307
+ for field_name, field_info in fields.items():
308
+ key_to_field_name[field_name] = field_name
309
+
310
+ if field_info.alias:
311
+ key_to_field_name[field_info.alias] = field_name
312
+ if field_info.serialization_alias:
313
+ key_to_field_name[field_info.serialization_alias] = field_name
314
+ if isinstance(field_info.validation_alias, str):
315
+ key_to_field_name[field_info.validation_alias] = field_name
316
+
317
+ return key_to_field_name
318
+
319
+ def _migrate_single_field(
278
320
  self: Self,
279
- value: Any,
280
- from_field: FieldInfo | None,
281
- to_field: FieldInfo,
321
+ data: ModelData,
322
+ field_name: str,
323
+ to_field_info: FieldInfo,
324
+ from_fields: dict[str, FieldInfo],
325
+ key_to_field_name: dict[str, str],
326
+ ) -> tuple[str, Any, set[str]] | None:
327
+ """Migrate a single field, handling aliases and defaults.
328
+
329
+ Args:
330
+ data: Source data dictionary.
331
+ field_name: Target field name.
332
+ to_field_info: Target field info.
333
+ from_fields: Source model fields.
334
+ key_to_field_name: Mapping from data keys to field names.
335
+
336
+ Returns:
337
+ Tuple of (output_key, migrated_value, source_keys_to_mark_processed) if
338
+ field should be included, None if field should be skipped.
339
+ """
340
+ value, data_key_used = self._find_field_value(
341
+ data, field_name, key_to_field_name
342
+ )
343
+
344
+ if data_key_used is not None:
345
+ from_field_info = from_fields.get(field_name)
346
+ migrated_value = self._migrate_field_value(
347
+ value, from_field_info, to_field_info
348
+ )
349
+ keys_to_process = {
350
+ k for k, v in key_to_field_name.items() if v == field_name
351
+ }
352
+ return (data_key_used, migrated_value, keys_to_process)
353
+
354
+ _NO_DEFAULT = object()
355
+ default_value = self._get_field_default(to_field_info, _NO_DEFAULT)
356
+
357
+ if default_value is not _NO_DEFAULT:
358
+ output_key = (
359
+ to_field_info.serialization_alias or to_field_info.alias or field_name
360
+ )
361
+ return (output_key, default_value, set())
362
+
363
+ return None
364
+
365
+ def _find_field_value(
366
+ self: Self, data: ModelData, field_name: str, key_to_field_name: dict[str, str]
367
+ ) -> tuple[Any, str | None]:
368
+ """Find a field's value in data, checking both field name and aliases.
369
+
370
+ Args:
371
+ data: Source data dictionary.
372
+ field_name: Target field name to find.
373
+ key_to_field_name: Mapping from data keys to field names.
374
+
375
+ Returns:
376
+ Tuple of (value, data_key) if found, (None, None) otherwise.
377
+ """
378
+ if field_name in data:
379
+ return (data[field_name], field_name)
380
+
381
+ for data_key in data:
382
+ if key_to_field_name.get(data_key) == field_name:
383
+ return (data[data_key], data_key)
384
+
385
+ return (None, None)
386
+
387
+ def _get_field_default(
388
+ self: Self, field_info: FieldInfo, sentinel: Any = None
389
+ ) -> Any:
390
+ """Get the default value for a field.
391
+
392
+ Args:
393
+ field_info: Field info to extract default from.
394
+ sentinel: Sentinel value to return if no default is available.
395
+
396
+ Returns:
397
+ Default value if available, sentinel otherwise.
398
+ """
399
+ if field_info.default is not PydanticUndefined:
400
+ return field_info.default
401
+
402
+ if field_info.default_factory is not None:
403
+ with contextlib.suppress(Exception):
404
+ return field_info.default_factory() # type: ignore
405
+
406
+ return sentinel
407
+
408
+ def _migrate_field_value(
409
+ self: Self, value: Any, from_field: FieldInfo | None, to_field: FieldInfo
282
410
  ) -> Any:
283
411
  """Migrate a single field value, handling nested models.
284
412
 
@@ -294,31 +422,104 @@ class MigrationManager:
294
422
  return None
295
423
 
296
424
  if isinstance(value, dict):
297
- nested_info = self._extract_nested_model_info(value, from_field, to_field)
298
- if nested_info:
299
- nested_name, nested_from_ver, nested_to_ver = nested_info
300
- return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
301
-
302
- return {
303
- k: self._migrate_field_value(v, from_field, to_field)
304
- for k, v in value.items()
305
- }
425
+ return self._migrate_dict_value(value, from_field, to_field)
306
426
 
307
427
  if isinstance(value, list):
308
- return [
309
- self._migrate_field_value(item, from_field, to_field) for item in value
310
- ]
428
+ return self._migrate_list_value(value, from_field, to_field)
311
429
 
312
430
  return value
313
431
 
432
+ def _migrate_dict_value(
433
+ self, value: dict[str, Any], from_field: FieldInfo | None, to_field: FieldInfo
434
+ ) -> dict[str, Any]:
435
+ """Migrate a dictionary value (might be a nested model)."""
436
+ nested_info = self._extract_nested_model_info(value, from_field, to_field)
437
+ if nested_info:
438
+ nested_name, nested_from_ver, nested_to_ver = nested_info
439
+ return self.migrate(value, nested_name, nested_from_ver, nested_to_ver)
440
+
441
+ return {
442
+ k: self._migrate_field_value(v, from_field, to_field)
443
+ for k, v in value.items()
444
+ }
445
+
446
+ def _migrate_list_value(
447
+ self, value: list[Any], from_field: FieldInfo | None, to_field: FieldInfo
448
+ ) -> list[Any]:
449
+ """Migrate a list value (might contain nested models)."""
450
+ from_item_field, to_item_field = self._get_list_item_fields(
451
+ from_field, to_field
452
+ )
453
+
454
+ return [
455
+ self._migrate_list_item(item, from_item_field, to_item_field)
456
+ for item in value
457
+ ]
458
+
459
+ def _migrate_list_item(
460
+ self, item: Any, from_field: FieldInfo | None, to_field: FieldInfo | None
461
+ ) -> Any:
462
+ """Migrate a single item from a list."""
463
+ if not isinstance(item, dict) or to_field is None:
464
+ return item
465
+
466
+ nested_info = self._extract_nested_model_info(item, from_field, to_field)
467
+ if nested_info:
468
+ nested_name, nested_from_ver, nested_to_ver = nested_info
469
+ return self.migrate(item, nested_name, nested_from_ver, nested_to_ver)
470
+
471
+ return item
472
+
473
+ def _get_list_item_fields(
474
+ self, from_field: FieldInfo | None, to_field: FieldInfo
475
+ ) -> tuple[FieldInfo | None, FieldInfo | None]:
476
+ """Extract field info for items in a list field."""
477
+ to_item_field = self._extract_list_item_field(to_field)
478
+ from_item_field = (
479
+ self._extract_list_item_field(from_field) if from_field else None
480
+ )
481
+ return from_item_field, to_item_field
482
+
483
+ def _extract_list_item_field(self, field: FieldInfo | None) -> FieldInfo | None:
484
+ """Extract field info for the items of a list field."""
485
+ if field is None or field.annotation is None:
486
+ return None
487
+
488
+ origin = get_origin(field.annotation)
489
+ if origin is not list:
490
+ return None
491
+
492
+ args = get_args(field.annotation)
493
+ if not args:
494
+ return None
495
+
496
+ item_annotation = args[0]
497
+
498
+ discriminator = None
499
+ if get_origin(item_annotation) is Annotated:
500
+ annotated_args = get_args(item_annotation)
501
+ item_annotation = annotated_args[0]
502
+ for metadata in annotated_args[1:]:
503
+ if hasattr(metadata, "discriminator"):
504
+ discriminator = metadata.discriminator
505
+
506
+ synthetic_field = FieldInfo(
507
+ annotation=item_annotation, default=PydanticUndefined
508
+ )
509
+ if discriminator:
510
+ synthetic_field.discriminator = discriminator
511
+
512
+ return synthetic_field
513
+
314
514
  def _extract_nested_model_info(
315
- self: Self,
316
- value: ModelData,
317
- from_field: FieldInfo | None,
318
- to_field: FieldInfo,
515
+ self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
319
516
  ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
320
517
  """Extract nested model migration information.
321
518
 
519
+ Handles discriminated unions by using the discriminator field to determine which
520
+ model type to migrate.
521
+
522
+
322
523
  Args:
323
524
  value: The nested model data.
324
525
  from_field: Source field info.
@@ -328,8 +529,51 @@ class MigrationManager:
328
529
  Tuple of (model_name, from_version, to_version) if this is a
329
530
  versioned nested model, None otherwise.
330
531
  """
532
+ discriminated_info = self._try_extract_discriminated_model(
533
+ value, from_field, to_field
534
+ )
535
+ if discriminated_info:
536
+ return discriminated_info
537
+
538
+ return self._try_extract_simple_nested_model(from_field, to_field)
539
+
540
+ def _try_extract_discriminated_model(
541
+ self, value: ModelData, from_field: FieldInfo | None, to_field: FieldInfo
542
+ ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
543
+ """Try to extract model info from a discriminated union field."""
544
+ discriminator_key = self._get_discriminator_key(to_field)
545
+ if not discriminator_key:
546
+ return None
547
+
548
+ discriminator_value = self._find_discriminator_value(
549
+ value, discriminator_key, to_field
550
+ )
551
+ if discriminator_value is None:
552
+ return None
553
+
554
+ to_model_type = self._find_discriminated_type(
555
+ to_field, discriminator_key, discriminator_value
556
+ )
557
+ if not to_model_type:
558
+ return None
559
+
560
+ to_info = self.registry.get_model_info(to_model_type)
561
+ if not to_info:
562
+ return None
563
+
564
+ model_name, to_version = to_info
565
+ from_version = self._get_discriminated_source_version(
566
+ from_field, discriminator_key, discriminator_value, model_name, to_version
567
+ )
568
+
569
+ return (model_name, from_version, to_version)
570
+
571
+ def _try_extract_simple_nested_model(
572
+ self, from_field: FieldInfo | None, to_field: FieldInfo
573
+ ) -> tuple[ModelName, ModelVersion, ModelVersion] | None:
574
+ """Try to extract model info from a simple (non-discriminated) nested field."""
331
575
  to_model_type = self._get_model_type_from_field(to_field)
332
- if not to_model_type or not issubclass(to_model_type, BaseModel):
576
+ if not to_model_type:
333
577
  return None
334
578
 
335
579
  to_info = self.registry.get_model_info(to_model_type)
@@ -337,34 +581,175 @@ class MigrationManager:
337
581
  return None
338
582
 
339
583
  model_name, to_version = to_info
584
+ from_version = self._get_simple_source_version(from_field, model_name)
585
+
586
+ return (model_name, from_version or to_version, to_version)
587
+
588
+ def _get_simple_source_version(
589
+ self, from_field: FieldInfo | None, model_name: str
590
+ ) -> ModelVersion | None:
591
+ """Get the source version for a simple nested field."""
592
+ if not from_field:
593
+ return None
594
+
595
+ from_model_type = self._get_model_type_from_field(from_field)
596
+ if not from_model_type:
597
+ return None
598
+
599
+ from_info = self.registry.get_model_info(from_model_type)
600
+ if not from_info or from_info[0] != model_name:
601
+ return None
602
+
603
+ return from_info[1]
604
+
605
+ def _get_discriminator_key(self, field: FieldInfo) -> str | None:
606
+ """Extract the discriminator key from a field."""
607
+ discriminator = field.discriminator
608
+ if discriminator is None:
609
+ return None
610
+
611
+ if isinstance(discriminator, str):
612
+ return discriminator
613
+
614
+ if hasattr(discriminator, "discriminator") and isinstance(
615
+ discriminator.discriminator, str
616
+ ):
617
+ return discriminator.discriminator
618
+
619
+ return None
620
+
621
+ def _find_discriminator_value(
622
+ self, value: ModelData, discriminator_key: str, field: FieldInfo
623
+ ) -> Any:
624
+ """Find the discriminator value in data, checking field name and aliases."""
625
+ if discriminator_key in value:
626
+ return value[discriminator_key]
627
+
628
+ for model_type in self._get_union_members(field):
629
+ if discriminator_key not in model_type.model_fields:
630
+ continue
631
+
632
+ disc_field = model_type.model_fields[discriminator_key]
633
+
634
+ for alias_attr in ["alias", "serialization_alias"]:
635
+ alias = getattr(disc_field, alias_attr, None)
636
+ if alias and alias in value:
637
+ return value[alias]
340
638
 
341
- # Get the source version
342
- if from_field:
343
- from_model_type = self._get_model_type_from_field(from_field)
344
- if from_model_type and issubclass(from_model_type, BaseModel):
345
- from_info = self.registry.get_model_info(from_model_type)
346
- if from_info and from_info[0] == model_name:
347
- from_version = from_info[1]
348
- return (model_name, from_version, to_version)
639
+ val_alias = disc_field.validation_alias
640
+ if isinstance(val_alias, str) and val_alias in value:
641
+ return value[val_alias]
349
642
 
350
- # If we can't determine the source version, assume it's the same as target
351
- return (model_name, to_version, to_version)
643
+ return None
644
+
645
+ def _get_discriminated_source_version(
646
+ self,
647
+ from_field: FieldInfo | None,
648
+ discriminator_key: str,
649
+ discriminator_value: Any,
650
+ model_name: str,
651
+ default_version: ModelVersion,
652
+ ) -> ModelVersion:
653
+ """Get the source version for a discriminated union field."""
654
+ if not from_field:
655
+ return default_version
656
+
657
+ from_model_type = self._find_discriminated_type(
658
+ from_field, discriminator_key, discriminator_value
659
+ )
660
+ if not from_model_type:
661
+ return default_version
662
+
663
+ from_info = self.registry.get_model_info(from_model_type)
664
+ if not from_info or from_info[0] != model_name:
665
+ return default_version
666
+
667
+ return from_info[1]
352
668
 
353
- def _get_model_type_from_field(
354
- self: Self, field: FieldInfo
669
+ def _find_discriminated_type(
670
+ self, field: FieldInfo, discriminator_key: str, discriminator_value: Any
355
671
  ) -> type[BaseModel] | None:
356
- """Extract the Pydantic model type from a field.
672
+ """Find the right type in a discriminated union based on discriminator value."""
673
+ for model_type in self._get_union_members(field):
674
+ if self._model_matches_discriminator(
675
+ model_type, discriminator_key, discriminator_value
676
+ ):
677
+ return model_type
678
+ return None
357
679
 
358
- Handles Optional, List, and other generic types.
680
+ def _model_matches_discriminator(
681
+ self,
682
+ model_type: type[BaseModel],
683
+ discriminator_key: str,
684
+ discriminator_value: Any,
685
+ ) -> bool:
686
+ """Check if a model type matches a discriminator value."""
687
+ if discriminator_key not in model_type.model_fields:
688
+ return False
359
689
 
360
- Args:
361
- field: The field info to extract from.
690
+ field_info = model_type.model_fields[discriminator_key]
362
691
 
363
- Returns:
364
- The model type if found, None otherwise.
365
- """
692
+ if self._literal_matches_value(field_info.annotation, discriminator_value):
693
+ return True
694
+
695
+ return (
696
+ field_info.default is not PydanticUndefined
697
+ and field_info.default == discriminator_value
698
+ )
699
+
700
+ def _literal_matches_value(self, annotation: Any, value: Any) -> bool:
701
+ """Check if a Literal type annotation contains a specific value."""
702
+ if annotation is None:
703
+ return False
704
+
705
+ origin = get_origin(annotation)
706
+ if origin is not Literal:
707
+ return False
708
+
709
+ literal_values = get_args(annotation)
710
+ return value in literal_values
711
+
712
+ def _get_union_members(self, field: FieldInfo) -> list[type[BaseModel]]:
713
+ """Extract all BaseModel types from a union field."""
366
714
  annotation = field.annotation
715
+ if annotation is None:
716
+ return []
367
717
 
718
+ origin = get_origin(annotation)
719
+
720
+ if origin is None:
721
+ if isinstance(annotation, type) and issubclass(annotation, BaseModel):
722
+ return [annotation]
723
+ return []
724
+
725
+ if not self._is_union_type(origin):
726
+ return []
727
+
728
+ args = get_args(annotation)
729
+ return [
730
+ arg
731
+ for arg in args
732
+ if arg is not type(None)
733
+ and isinstance(arg, type)
734
+ and issubclass(arg, BaseModel)
735
+ ]
736
+
737
+ def _is_union_type(self, origin: Any) -> bool:
738
+ """Check if an origin type represents a Union."""
739
+ if origin is Union:
740
+ return True
741
+
742
+ if hasattr(types, "UnionType"):
743
+ try:
744
+ return origin is types.UnionType
745
+ except (ImportError, AttributeError):
746
+ pass
747
+
748
+ return False
749
+
750
+ def _get_model_type_from_field(self, field: FieldInfo) -> type[BaseModel] | None:
751
+ """Extract the Pydantic model type from a field."""
752
+ annotation = field.annotation
368
753
  if annotation is None:
369
754
  return None
370
755