datamodel-code-generator: datamodel_code_generator-0.26.3-py3-none-any.whl → datamodel_code_generator-0.27.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datamodel-code-generator might be problematic. More details are available on the original diff page.

Files changed (29):
  1. datamodel_code_generator/__init__.py +39 -6
  2. datamodel_code_generator/__main__.py +42 -21
  3. datamodel_code_generator/arguments.py +8 -1
  4. datamodel_code_generator/format.py +1 -0
  5. datamodel_code_generator/http.py +2 -1
  6. datamodel_code_generator/imports.py +2 -2
  7. datamodel_code_generator/model/__init__.py +22 -9
  8. datamodel_code_generator/model/base.py +18 -8
  9. datamodel_code_generator/model/enum.py +15 -3
  10. datamodel_code_generator/model/msgspec.py +3 -2
  11. datamodel_code_generator/model/pydantic/base_model.py +1 -1
  12. datamodel_code_generator/model/pydantic/types.py +1 -1
  13. datamodel_code_generator/model/pydantic_v2/base_model.py +2 -2
  14. datamodel_code_generator/model/pydantic_v2/types.py +4 -1
  15. datamodel_code_generator/parser/base.py +24 -12
  16. datamodel_code_generator/parser/graphql.py +6 -4
  17. datamodel_code_generator/parser/jsonschema.py +12 -5
  18. datamodel_code_generator/parser/openapi.py +16 -6
  19. datamodel_code_generator/pydantic_patch.py +1 -1
  20. datamodel_code_generator/reference.py +19 -10
  21. datamodel_code_generator/types.py +26 -22
  22. datamodel_code_generator/util.py +7 -11
  23. datamodel_code_generator/version.py +1 -1
  24. {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/METADATA +37 -32
  25. {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/RECORD +35 -35
  26. {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/WHEEL +1 -1
  27. datamodel_code_generator-0.27.0.dist-info/entry_points.txt +2 -0
  28. datamodel_code_generator-0.26.3.dist-info/entry_points.txt +0 -3
  29. {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info/licenses}/LICENSE +0 -0
@@ -65,6 +65,7 @@ def get_special_path(keyword: str, path: List[str]) -> List[str]:
65
65
 
66
66
  escape_characters = str.maketrans(
67
67
  {
68
+ '\u0000': r'\x00', # Null byte
68
69
  '\\': r'\\',
69
70
  "'": r'\'',
70
71
  '\b': r'\b',
@@ -410,8 +411,9 @@ class Parser(ABC):
410
411
  treat_dots_as_module: bool = False,
411
412
  use_exact_imports: bool = False,
412
413
  default_field_extras: Optional[Dict[str, Any]] = None,
413
- target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
414
+ target_datetime_class: Optional[DatetimeClassType] = DatetimeClassType.Datetime,
414
415
  keyword_only: bool = False,
416
+ no_alias: bool = False,
415
417
  ) -> None:
416
418
  self.keyword_only = keyword_only
417
419
  self.data_type_manager: DataTypeManager = data_type_manager_type(
@@ -512,6 +514,7 @@ class Parser(ABC):
512
514
  special_field_name_prefix=special_field_name_prefix,
513
515
  remove_special_field_name_prefix=remove_special_field_name_prefix,
514
516
  capitalise_enum_members=capitalise_enum_members,
517
+ no_alias=no_alias,
515
518
  )
516
519
  self.class_name: Optional[str] = class_name
517
520
  self.wrap_string_literal: Optional[bool] = wrap_string_literal
@@ -665,7 +668,16 @@ class Parser(ABC):
665
668
  for model, duplicate_models in model_to_duplicate_models.items():
666
669
  for duplicate_model in duplicate_models:
667
670
  for child in duplicate_model.reference.children[:]:
668
- child.replace_reference(model.reference)
671
+ if isinstance(child, DataType):
672
+ child.replace_reference(model.reference)
673
+ # simplify if introduce duplicate base classes
674
+ if isinstance(child, DataModel):
675
+ child.base_classes = list(
676
+ {
677
+ f'{c.module_name}.{c.type_hint}': c
678
+ for c in child.base_classes
679
+ }.values()
680
+ )
669
681
  models.remove(duplicate_model)
670
682
 
671
683
  @classmethod
@@ -846,12 +858,12 @@ class Parser(ABC):
846
858
 
847
859
  # Check the main discriminator model path
848
860
  if mapping:
849
- check_paths(discriminator_model, mapping)
861
+ check_paths(discriminator_model, mapping) # pyright: ignore [reportArgumentType]
850
862
 
851
863
  # Check the base_classes if they exist
852
864
  if len(type_names) == 0:
853
865
  for base_class in discriminator_model.base_classes:
854
- check_paths(base_class.reference, mapping)
866
+ check_paths(base_class.reference, mapping) # pyright: ignore [reportArgumentType]
855
867
  else:
856
868
  type_names = [discriminator_model.path.split('/')[-1]]
857
869
  if not type_names: # pragma: no cover
@@ -866,10 +878,8 @@ class Parser(ABC):
866
878
  ) != property_name:
867
879
  continue
868
880
  literals = discriminator_field.data_type.literals
869
- if (
870
- len(literals) == 1 and literals[0] == type_names[0]
871
- if type_names
872
- else None
881
+ if len(literals) == 1 and literals[0] == (
882
+ type_names[0] if type_names else None
873
883
  ):
874
884
  has_one_literal = True
875
885
  if isinstance(
@@ -882,7 +892,8 @@ class Parser(ABC):
882
892
  'tag', discriminator_field.represented_default
883
893
  )
884
894
  discriminator_field.extras['is_classvar'] = True
885
- continue
895
+ # Found the discriminator field, no need to keep looking
896
+ break
886
897
  for (
887
898
  field_data_type
888
899
  ) in discriminator_field.data_type.all_data_types:
@@ -1036,7 +1047,7 @@ class Parser(ABC):
1036
1047
  and any(
1037
1048
  d
1038
1049
  for d in model_field.data_type.all_data_types
1039
- if d.is_dict or d.is_union
1050
+ if d.is_dict or d.is_union or d.is_list
1040
1051
  )
1041
1052
  ):
1042
1053
  continue # pragma: no cover
@@ -1059,7 +1070,7 @@ class Parser(ABC):
1059
1070
 
1060
1071
  data_type.parent.data_type = copied_data_type
1061
1072
 
1062
- elif data_type.parent.is_list:
1073
+ elif data_type.parent is not None and data_type.parent.is_list:
1063
1074
  if self.field_constraints:
1064
1075
  model_field.constraints = ConstraintsBase.merge_constraints(
1065
1076
  root_type_field.constraints, model_field.constraints
@@ -1071,6 +1082,7 @@ class Parser(ABC):
1071
1082
  discriminator = root_type_field.extras.get('discriminator')
1072
1083
  if discriminator:
1073
1084
  model_field.extras['discriminator'] = discriminator
1085
+ assert isinstance(data_type.parent, DataType)
1074
1086
  data_type.parent.data_types.remove(
1075
1087
  data_type
1076
1088
  ) # pragma: no cover
@@ -1356,7 +1368,7 @@ class Parser(ABC):
1356
1368
  module_to_import: Dict[Tuple[str, ...], Imports] = {}
1357
1369
 
1358
1370
  previous_module = () # type: Tuple[str, ...]
1359
- for module, models in ((k, [*v]) for k, v in grouped_models): # type: Tuple[str, ...], List[DataModel]
1371
+ for module, models in ((k, [*v]) for k, v in grouped_models):
1360
1372
  for model in models:
1361
1373
  model_to_module_models[model] = module, models
1362
1374
  self.__delete_duplicate_models(models)
@@ -160,6 +160,7 @@ class GraphQLParser(Parser):
160
160
  default_field_extras: Optional[Dict[str, Any]] = None,
161
161
  target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
162
162
  keyword_only: bool = False,
163
+ no_alias: bool = False,
163
164
  ) -> None:
164
165
  super().__init__(
165
166
  source=source,
@@ -232,6 +233,7 @@ class GraphQLParser(Parser):
232
233
  default_field_extras=default_field_extras,
233
234
  target_datetime_class=target_datetime_class,
234
235
  keyword_only=keyword_only,
236
+ no_alias=no_alias,
235
237
  )
236
238
 
237
239
  self.data_model_scalar_type = data_model_scalar_type
@@ -371,7 +373,7 @@ class GraphQLParser(Parser):
371
373
  def parse_field(
372
374
  self,
373
375
  field_name: str,
374
- alias: str,
376
+ alias: Optional[str],
375
377
  field: Union[graphql.GraphQLField, graphql.GraphQLInputField],
376
378
  ) -> DataModelFieldBase:
377
379
  final_data_type = DataType(
@@ -397,9 +399,9 @@ class GraphQLParser(Parser):
397
399
  elif graphql.is_non_null_type(obj): # pragma: no cover
398
400
  data_type.is_optional = False
399
401
 
400
- obj = obj.of_type
402
+ obj = obj.of_type # pyright: ignore [reportAttributeAccessIssue]
401
403
 
402
- data_type.type = obj.name
404
+ data_type.type = obj.name # pyright: ignore [reportAttributeAccessIssue]
403
405
 
404
406
  required = (not self.force_optional_for_required_fields) and (
405
407
  not final_data_type.is_optional
@@ -454,7 +456,7 @@ class GraphQLParser(Parser):
454
456
 
455
457
  base_classes = []
456
458
  if hasattr(obj, 'interfaces'): # pragma: no cover
457
- base_classes = [self.references[i.name] for i in obj.interfaces]
459
+ base_classes = [self.references[i.name] for i in obj.interfaces] # pyright: ignore [reportAttributeAccessIssue]
458
460
 
459
461
  data_model_type = self.data_model_type(
460
462
  reference=self.references[obj.name],
@@ -258,7 +258,7 @@ class JsonSchemaObject(BaseModel):
258
258
  extras: Dict[str, Any] = Field(alias=__extra_key__, default_factory=dict)
259
259
  discriminator: Union[Discriminator, str, None] = None
260
260
  if PYDANTIC_V2:
261
- model_config = ConfigDict(
261
+ model_config = ConfigDict( # pyright: ignore [reportPossiblyUnboundVariable]
262
262
  arbitrary_types_allowed=True,
263
263
  ignored_types=(cached_property,),
264
264
  )
@@ -320,7 +320,7 @@ class JsonSchemaObject(BaseModel):
320
320
  return isinstance(self.type, list) and 'null' in self.type
321
321
 
322
322
 
323
- @lru_cache()
323
+ @lru_cache
324
324
  def get_ref_type(ref: str) -> JSONReference:
325
325
  if ref[0] == '#':
326
326
  return JSONReference.LOCAL
@@ -338,7 +338,7 @@ def _get_type(type_: str, format__: Optional[str] = None) -> Types:
338
338
  if data_formats is not None:
339
339
  return data_formats
340
340
 
341
- warn(f'format of {format__!r} not understood for {type_!r} - using default' '')
341
+ warn(f'format of {format__!r} not understood for {type_!r} - using default')
342
342
  return json_schema_data_formats[type_]['default']
343
343
 
344
344
 
@@ -360,7 +360,7 @@ EXCLUDE_FIELD_KEYS_IN_JSON_SCHEMA: Set[str] = {
360
360
  }
361
361
 
362
362
  EXCLUDE_FIELD_KEYS = (
363
- set(JsonSchemaObject.get_fields())
363
+ set(JsonSchemaObject.get_fields()) # pyright: ignore [reportAttributeAccessIssue]
364
364
  - DEFAULT_FIELD_KEYS
365
365
  - EXCLUDE_FIELD_KEYS_IN_JSON_SCHEMA
366
366
  ) | {
@@ -448,6 +448,7 @@ class JsonSchemaParser(Parser):
448
448
  default_field_extras: Optional[Dict[str, Any]] = None,
449
449
  target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
450
450
  keyword_only: bool = False,
451
+ no_alias: bool = False,
451
452
  ) -> None:
452
453
  super().__init__(
453
454
  source=source,
@@ -520,13 +521,14 @@ class JsonSchemaParser(Parser):
520
521
  default_field_extras=default_field_extras,
521
522
  target_datetime_class=target_datetime_class,
522
523
  keyword_only=keyword_only,
524
+ no_alias=no_alias,
523
525
  )
524
526
 
525
527
  self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
526
528
  self.raw_obj: Dict[Any, Any] = {}
527
529
  self._root_id: Optional[str] = None
528
530
  self._root_id_base_path: Optional[str] = None
529
- self.reserved_refs: DefaultDict[Tuple[str], Set[str]] = defaultdict(set)
531
+ self.reserved_refs: DefaultDict[Tuple[str, ...], Set[str]] = defaultdict(set)
530
532
  self.field_keys: Set[str] = {
531
533
  *DEFAULT_FIELD_KEYS,
532
534
  *self.field_extra_keys,
@@ -618,6 +620,7 @@ class JsonSchemaParser(Parser):
618
620
  use_default_kwarg=self.use_default_kwarg,
619
621
  original_name=original_field_name,
620
622
  has_default=field.has_default,
623
+ type_has_null=field.type_has_null,
621
624
  )
622
625
 
623
626
  def get_data_type(self, obj: JsonSchemaObject) -> DataType:
@@ -1715,6 +1718,9 @@ class JsonSchemaParser(Parser):
1715
1718
  def parse_raw(self) -> None:
1716
1719
  for source, path_parts in self._get_context_source_path_parts():
1717
1720
  self.raw_obj = load_yaml(source.text)
1721
+ if self.raw_obj is None: # pragma: no cover
1722
+ warn(f'{source.path} is empty. Skipping this file')
1723
+ continue
1718
1724
  if self.custom_class_name_generator:
1719
1725
  obj_name = self.raw_obj.get('title', 'Model')
1720
1726
  else:
@@ -1792,6 +1798,7 @@ class JsonSchemaParser(Parser):
1792
1798
  root_obj = self.SCHEMA_OBJECT_TYPE.parse_obj(raw)
1793
1799
  self.parse_id(root_obj, path_parts)
1794
1800
  definitions: Optional[Dict[Any, Any]] = None
1801
+ schema_path = ''
1795
1802
  for schema_path, split_schema_path in self.schema_paths:
1796
1803
  try:
1797
1804
  definitions = get_model_by_path(raw, split_schema_path)
@@ -228,6 +228,7 @@ class OpenAPIParser(JsonSchemaParser):
228
228
  default_field_extras: Optional[Dict[str, Any]] = None,
229
229
  target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
230
230
  keyword_only: bool = False,
231
+ no_alias: bool = False,
231
232
  ):
232
233
  super().__init__(
233
234
  source=source,
@@ -300,6 +301,7 @@ class OpenAPIParser(JsonSchemaParser):
300
301
  default_field_extras=default_field_extras,
301
302
  target_datetime_class=target_datetime_class,
302
303
  keyword_only=keyword_only,
304
+ no_alias=no_alias,
303
305
  )
304
306
  self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
305
307
  OpenAPIScope.Schemas
@@ -314,8 +316,10 @@ class OpenAPIParser(JsonSchemaParser):
314
316
  return get_model_by_path(ref_body, ref_path.split('/')[1:])
315
317
 
316
318
  def get_data_type(self, obj: JsonSchemaObject) -> DataType:
317
- # OpenAPI doesn't allow `null` in `type` field and list of types
319
+ # OpenAPI 3.0 doesn't allow `null` in the `type` field and list of types
318
320
  # https://swagger.io/docs/specification/data-models/data-types/#null
321
+ # OpenAPI 3.1 does allow `null` in the `type` field and is equivalent to
322
+ # a `nullable` flag on the property itself
319
323
  if obj.nullable and self.strict_nullable and isinstance(obj.type, str):
320
324
  obj.type = [obj.type, 'null']
321
325
 
@@ -363,7 +367,7 @@ class OpenAPIParser(JsonSchemaParser):
363
367
  for (
364
368
  media_type,
365
369
  media_obj,
366
- ) in request_body.content.items(): # type: str, MediaObject
370
+ ) in request_body.content.items():
367
371
  if isinstance(media_obj.schema_, JsonSchemaObject):
368
372
  self.parse_schema(name, media_obj.schema_, [*path, media_type])
369
373
 
@@ -396,11 +400,13 @@ class OpenAPIParser(JsonSchemaParser):
396
400
  if not object_schema: # pragma: no cover
397
401
  continue
398
402
  if isinstance(object_schema, JsonSchemaObject):
399
- data_types[status_code][content_type] = self.parse_schema(
400
- name, object_schema, [*path, str(status_code), content_type]
403
+ data_types[status_code][content_type] = self.parse_schema( # pyright: ignore [reportArgumentType]
404
+ name,
405
+ object_schema,
406
+ [*path, str(status_code), content_type], # pyright: ignore [reportArgumentType]
401
407
  )
402
408
  else:
403
- data_types[status_code][content_type] = self.get_ref_data_type(
409
+ data_types[status_code][content_type] = self.get_ref_data_type( # pyright: ignore [reportArgumentType]
404
410
  object_schema.ref
405
411
  )
406
412
 
@@ -504,6 +510,9 @@ class OpenAPIParser(JsonSchemaParser):
504
510
  has_default=object_schema.has_default
505
511
  if object_schema
506
512
  else False,
513
+ type_has_null=object_schema.type_has_null
514
+ if object_schema
515
+ else None,
507
516
  )
508
517
  )
509
518
 
@@ -513,6 +522,7 @@ class OpenAPIParser(JsonSchemaParser):
513
522
  fields=fields,
514
523
  reference=reference,
515
524
  custom_base_class=self.base_class,
525
+ custom_template_dir=self.custom_template_dir,
516
526
  keyword_only=self.keyword_only,
517
527
  )
518
528
  )
@@ -596,7 +606,7 @@ class OpenAPIParser(JsonSchemaParser):
596
606
  for (
597
607
  obj_name,
598
608
  raw_obj,
599
- ) in schemas.items(): # type: str, Dict[Any, Any]
609
+ ) in schemas.items():
600
610
  self.parse_raw_obj(
601
611
  obj_name,
602
612
  raw_obj,
@@ -19,4 +19,4 @@ def patched_evaluate_forwardref(
19
19
 
20
20
  # Patch only Python3.12
21
21
  if sys.version_info >= (3, 12):
22
- pydantic.typing.evaluate_forwardref = patched_evaluate_forwardref
22
+ pydantic.typing.evaluate_forwardref = patched_evaluate_forwardref # pyright: ignore [reportAttributeAccessIssue]
@@ -138,7 +138,7 @@ class Reference(_BaseModel):
138
138
  if PYDANTIC_V2:
139
139
  # TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
140
140
  # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
141
- model_config = ConfigDict(
141
+ model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
142
142
  arbitrary_types_allowed=True,
143
143
  ignored_types=(cached_property,),
144
144
  revalidate_instances='never',
@@ -182,7 +182,7 @@ _UNDER_SCORE_1: Pattern[str] = re.compile(r'([^_])([A-Z][a-z]+)')
182
182
  _UNDER_SCORE_2: Pattern[str] = re.compile('([a-z0-9])([A-Z])')
183
183
 
184
184
 
185
- @lru_cache()
185
+ @lru_cache
186
186
  def camel_to_snake(string: str) -> str:
187
187
  subbed = _UNDER_SCORE_1.sub(r'\1_\2', string)
188
188
  return _UNDER_SCORE_2.sub(r'\1_\2', subbed).lower()
@@ -198,6 +198,7 @@ class FieldNameResolver:
198
198
  special_field_name_prefix: Optional[str] = None,
199
199
  remove_special_field_name_prefix: bool = False,
200
200
  capitalise_enum_members: bool = False,
201
+ no_alias: bool = False,
201
202
  ):
202
203
  self.aliases: Mapping[str, str] = {} if aliases is None else {**aliases}
203
204
  self.empty_field_name: str = empty_field_name or '_'
@@ -208,6 +209,7 @@ class FieldNameResolver:
208
209
  )
209
210
  self.remove_special_field_name_prefix: bool = remove_special_field_name_prefix
210
211
  self.capitalise_enum_members: bool = capitalise_enum_members
212
+ self.no_alias = no_alias
211
213
 
212
214
  @classmethod
213
215
  def _validate_field_name(cls, field_name: str) -> bool:
@@ -274,7 +276,10 @@ class FieldNameResolver:
274
276
  if field_name in self.aliases:
275
277
  return self.aliases[field_name], field_name
276
278
  valid_name = self.get_valid_name(field_name, excludes=excludes)
277
- return valid_name, None if field_name == valid_name else field_name
279
+ return (
280
+ valid_name,
281
+ None if self.no_alias or field_name == valid_name else field_name,
282
+ )
278
283
 
279
284
 
280
285
  class PydanticFieldNameResolver(FieldNameResolver):
@@ -354,6 +359,7 @@ class ModelResolver:
354
359
  special_field_name_prefix: Optional[str] = None,
355
360
  remove_special_field_name_prefix: bool = False,
356
361
  capitalise_enum_members: bool = False,
362
+ no_alias: bool = False,
357
363
  ) -> None:
358
364
  self.references: Dict[str, Reference] = {}
359
365
  self._current_root: Sequence[str] = []
@@ -383,6 +389,7 @@ class ModelResolver:
383
389
  capitalise_enum_members=capitalise_enum_members
384
390
  if k == ModelType.ENUM
385
391
  else False,
392
+ no_alias=no_alias,
386
393
  )
387
394
  for k, v in merged_field_name_resolver_classes.items()
388
395
  }
@@ -466,7 +473,7 @@ class ModelResolver:
466
473
  else:
467
474
  joined_path = self.join_path(path)
468
475
  if joined_path == '#':
469
- return f"{'/'.join(self.current_root)}#"
476
+ return f'{"/".join(self.current_root)}#'
470
477
  if (
471
478
  self.current_base_path
472
479
  and not self.base_url
@@ -491,7 +498,7 @@ class ModelResolver:
491
498
 
492
499
  delimiter = joined_path.index('#')
493
500
  file_path = ''.join(joined_path[:delimiter])
494
- ref = f"{''.join(joined_path[:delimiter])}#{''.join(joined_path[delimiter + 1:])}"
501
+ ref = f'{"".join(joined_path[:delimiter])}#{"".join(joined_path[delimiter + 1 :])}'
495
502
  if self.root_id_base_path and not (
496
503
  is_url(joined_path) or Path(self._base_path, file_path).is_file()
497
504
  ):
@@ -566,11 +573,13 @@ class ModelResolver:
566
573
  split_ref = ref.rsplit('/', 1)
567
574
  if len(split_ref) == 1:
568
575
  original_name = Path(
569
- split_ref[0][:-1] if self.is_external_root_ref(path) else split_ref[0]
576
+ split_ref[0].rstrip('#')
577
+ if self.is_external_root_ref(path)
578
+ else split_ref[0]
570
579
  ).stem
571
580
  else:
572
581
  original_name = (
573
- Path(split_ref[1][:-1]).stem
582
+ Path(split_ref[1].rstrip('#')).stem
574
583
  if self.is_external_root_ref(path)
575
584
  else split_ref[1]
576
585
  )
@@ -741,15 +750,15 @@ class ModelResolver:
741
750
  )
742
751
 
743
752
 
744
- @lru_cache()
753
+ @lru_cache
745
754
  def get_singular_name(name: str, suffix: str = SINGULAR_NAME_SUFFIX) -> str:
746
755
  singular_name = inflect_engine.singular_noun(name)
747
756
  if singular_name is False:
748
757
  singular_name = f'{name}{suffix}'
749
- return singular_name
758
+ return singular_name # pyright: ignore [reportReturnType]
750
759
 
751
760
 
752
- @lru_cache()
761
+ @lru_cache
753
762
  def snake_to_upper_camel(word: str, delimiter: str = '_') -> str:
754
763
  prefix = ''
755
764
  if word.startswith(delimiter):
@@ -114,25 +114,25 @@ class UnionIntFloat:
114
114
  def __get_pydantic_core_schema__(
115
115
  cls, _source_type: Any, _handler: 'GetCoreSchemaHandler'
116
116
  ) -> 'core_schema.CoreSchema':
117
- from_int_schema = core_schema.chain_schema(
117
+ from_int_schema = core_schema.chain_schema( # pyright: ignore [reportPossiblyUnboundVariable]
118
118
  [
119
- core_schema.union_schema(
120
- [core_schema.int_schema(), core_schema.float_schema()]
119
+ core_schema.union_schema( # pyright: ignore [reportPossiblyUnboundVariable]
120
+ [core_schema.int_schema(), core_schema.float_schema()] # pyright: ignore [reportPossiblyUnboundVariable]
121
121
  ),
122
- core_schema.no_info_plain_validator_function(cls.validate),
122
+ core_schema.no_info_plain_validator_function(cls.validate), # pyright: ignore [reportPossiblyUnboundVariable]
123
123
  ]
124
124
  )
125
125
 
126
- return core_schema.json_or_python_schema(
126
+ return core_schema.json_or_python_schema( # pyright: ignore [reportPossiblyUnboundVariable]
127
127
  json_schema=from_int_schema,
128
- python_schema=core_schema.union_schema(
128
+ python_schema=core_schema.union_schema( # pyright: ignore [reportPossiblyUnboundVariable]
129
129
  [
130
130
  # check if it's an instance first before doing any further work
131
- core_schema.is_instance_schema(UnionIntFloat),
131
+ core_schema.is_instance_schema(UnionIntFloat), # pyright: ignore [reportPossiblyUnboundVariable]
132
132
  from_int_schema,
133
133
  ]
134
134
  ),
135
- serialization=core_schema.plain_serializer_function_ser_schema(
135
+ serialization=core_schema.plain_serializer_function_ser_schema( # pyright: ignore [reportPossiblyUnboundVariable]
136
136
  lambda instance: instance.value
137
137
  ),
138
138
  )
@@ -161,7 +161,7 @@ def chain_as_tuple(*iterables: Iterable[T]) -> Tuple[T, ...]:
161
161
  return tuple(chain(*iterables))
162
162
 
163
163
 
164
- @lru_cache()
164
+ @lru_cache
165
165
  def _remove_none_from_type(
166
166
  type_: str, split_pattern: Pattern[str], delimiter: str
167
167
  ) -> List[str]:
@@ -207,7 +207,7 @@ def _remove_none_from_union(type_: str, use_union_operator: bool) -> str:
207
207
  return f'{UNION_PREFIX}{UNION_DELIMITER.join(inner_types)}]'
208
208
 
209
209
 
210
- @lru_cache()
210
+ @lru_cache
211
211
  def get_optional_type(type_: str, use_union_operator: bool) -> str:
212
212
  type_ = _remove_none_from_union(type_, use_union_operator)
213
213
 
@@ -236,7 +236,7 @@ class DataType(_BaseModel):
236
236
  if PYDANTIC_V2:
237
237
  # TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
238
238
  # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
239
- model_config = ConfigDict(
239
+ model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
240
240
  extra='forbid',
241
241
  revalidate_instances='never',
242
242
  )
@@ -362,22 +362,21 @@ class DataType(_BaseModel):
362
362
 
363
363
  @property
364
364
  def imports(self) -> Iterator[Import]:
365
+ # Add base import if exists
365
366
  if self.import_:
366
367
  yield self.import_
368
+
369
+ # Define required imports based on type features and conditions
367
370
  imports: Tuple[Tuple[bool, Import], ...] = (
368
371
  (self.is_optional and not self.use_union_operator, IMPORT_OPTIONAL),
369
372
  (len(self.data_types) > 1 and not self.use_union_operator, IMPORT_UNION),
370
- )
371
- if any(self.literals):
372
- import_literal = (
373
+ (
374
+ bool(self.literals),
373
375
  IMPORT_LITERAL
374
376
  if self.python_version.has_literal_type
375
- else IMPORT_LITERAL_BACKPORT
376
- )
377
- imports = (
378
- *imports,
379
- (any(self.literals), import_literal),
380
- )
377
+ else IMPORT_LITERAL_BACKPORT,
378
+ ),
379
+ )
381
380
 
382
381
  if self.use_generic_container:
383
382
  if self.use_standard_collections:
@@ -401,10 +400,13 @@ class DataType(_BaseModel):
401
400
  (self.is_set, IMPORT_SET),
402
401
  (self.is_dict, IMPORT_DICT),
403
402
  )
403
+
404
+ # Yield imports based on conditions
404
405
  for field, import_ in imports:
405
406
  if field and import_ != self.import_:
406
407
  yield import_
407
408
 
409
+ # Propagate imports from any dict_key type
408
410
  if self.dict_key:
409
411
  yield from self.dict_key.imports
410
412
 
@@ -463,7 +465,7 @@ class DataType(_BaseModel):
463
465
  elif len(self.data_types) == 1:
464
466
  type_ = self.data_types[0].type_hint
465
467
  elif self.literals:
466
- type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
468
+ type_ = f'{LITERAL}[{", ".join(repr(literal) for literal in self.literals)}]'
467
469
  else:
468
470
  if self.reference:
469
471
  type_ = self.reference.short_name
@@ -586,7 +588,9 @@ class DataTypeManager(ABC):
586
588
  )
587
589
  self.use_union_operator: bool = use_union_operator
588
590
  self.use_pendulum: bool = use_pendulum
589
- self.target_datetime_class: DatetimeClassType = target_datetime_class
591
+ self.target_datetime_class: DatetimeClassType = (
592
+ target_datetime_class or DatetimeClassType.Datetime
593
+ )
590
594
 
591
595
  if (
592
596
  use_generic_container_types and python_version == PythonVersion.PY_36
@@ -37,17 +37,13 @@ else:
37
37
  from yaml import SafeLoader
38
38
 
39
39
  try:
40
- import tomllib
41
-
42
- def load_toml(path: Path) -> Dict[str, Any]:
43
- with path.open('rb') as f:
44
- return tomllib.load(f)
45
-
40
+ from tomllib import load as load_tomllib
46
41
  except ImportError:
47
- import toml
42
+ from tomli import load as load_tomllib
48
43
 
49
- def load_toml(path: Path) -> Dict[str, Any]:
50
- return toml.load(path)
44
+ def load_toml(path: Path) -> Dict[str, Any]:
45
+ with path.open('rb') as f:
46
+ return load_tomllib(f)
51
47
 
52
48
 
53
49
  SafeLoaderTemp = copy.deepcopy(SafeLoader)
@@ -81,7 +77,7 @@ def field_validator(
81
77
  field_name: str,
82
78
  *fields: str,
83
79
  mode: Literal['before', 'after'] = 'after',
84
- ) -> Callable[[Any], Callable[[Model, Any], Any]]:
80
+ ) -> Callable[[Any], Callable[[BaseModel, Any], Any]]:
85
81
  def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]:
86
82
  if PYDANTIC_V2:
87
83
  from pydantic import field_validator as field_validator_v2
@@ -103,4 +99,4 @@ else:
103
99
 
104
100
  class BaseModel(_BaseModel):
105
101
  if PYDANTIC_V2:
106
- model_config = ConfigDict(strict=False)
102
+ model_config = ConfigDict(strict=False) # pyright: ignore [reportAssignmentType]
@@ -1 +1 @@
1
- version: str = '0.26.3'
1
+ version: str = '0.0.0'