datamodel-code-generator 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamodel-code-generator might be problematic; see the registry's advisory page for this release for more details.
- datamodel_code_generator/__init__.py +39 -6
- datamodel_code_generator/__main__.py +42 -21
- datamodel_code_generator/arguments.py +8 -1
- datamodel_code_generator/format.py +1 -0
- datamodel_code_generator/http.py +2 -1
- datamodel_code_generator/imports.py +2 -2
- datamodel_code_generator/model/__init__.py +22 -9
- datamodel_code_generator/model/base.py +18 -8
- datamodel_code_generator/model/enum.py +15 -3
- datamodel_code_generator/model/msgspec.py +3 -2
- datamodel_code_generator/model/pydantic/base_model.py +1 -1
- datamodel_code_generator/model/pydantic/types.py +1 -1
- datamodel_code_generator/model/pydantic_v2/base_model.py +2 -2
- datamodel_code_generator/model/pydantic_v2/types.py +4 -1
- datamodel_code_generator/parser/base.py +24 -12
- datamodel_code_generator/parser/graphql.py +6 -4
- datamodel_code_generator/parser/jsonschema.py +12 -5
- datamodel_code_generator/parser/openapi.py +16 -6
- datamodel_code_generator/pydantic_patch.py +1 -1
- datamodel_code_generator/reference.py +19 -10
- datamodel_code_generator/types.py +26 -22
- datamodel_code_generator/util.py +7 -11
- datamodel_code_generator/version.py +1 -1
- {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/METADATA +37 -32
- {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/RECORD +35 -35
- {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info}/WHEEL +1 -1
- datamodel_code_generator-0.27.0.dist-info/entry_points.txt +2 -0
- datamodel_code_generator-0.26.3.dist-info/entry_points.txt +0 -3
- {datamodel_code_generator-0.26.3.dist-info → datamodel_code_generator-0.27.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -81,6 +81,9 @@ def enable_debug_message() -> None: # pragma: no cover
|
|
|
81
81
|
pysnooper.tracer.DISABLED = False
|
|
82
82
|
|
|
83
83
|
|
|
84
|
+
DEFAULT_MAX_VARIABLE_LENGTH: int = 100
|
|
85
|
+
|
|
86
|
+
|
|
84
87
|
def snooper_to_methods( # type: ignore
|
|
85
88
|
output=None,
|
|
86
89
|
watch=(),
|
|
@@ -90,7 +93,7 @@ def snooper_to_methods( # type: ignore
|
|
|
90
93
|
overwrite=False,
|
|
91
94
|
thread_info=False,
|
|
92
95
|
custom_repr=(),
|
|
93
|
-
max_variable_length=
|
|
96
|
+
max_variable_length: Optional[int] = DEFAULT_MAX_VARIABLE_LENGTH,
|
|
94
97
|
) -> Callable[..., Any]:
|
|
95
98
|
def inner(cls: Type[T]) -> Type[T]:
|
|
96
99
|
if not pysnooper:
|
|
@@ -108,7 +111,9 @@ def snooper_to_methods( # type: ignore
|
|
|
108
111
|
overwrite,
|
|
109
112
|
thread_info,
|
|
110
113
|
custom_repr,
|
|
111
|
-
max_variable_length
|
|
114
|
+
max_variable_length
|
|
115
|
+
if max_variable_length is not None
|
|
116
|
+
else DEFAULT_MAX_VARIABLE_LENGTH,
|
|
112
117
|
)(method)
|
|
113
118
|
setattr(cls, name, snooper_method)
|
|
114
119
|
return cls
|
|
@@ -231,7 +236,7 @@ def get_first_file(path: Path) -> Path: # pragma: no cover
|
|
|
231
236
|
|
|
232
237
|
|
|
233
238
|
def generate(
|
|
234
|
-
input_: Union[Path, str, ParseResult],
|
|
239
|
+
input_: Union[Path, str, ParseResult, Mapping[str, Any]],
|
|
235
240
|
*,
|
|
236
241
|
input_filename: Optional[str] = None,
|
|
237
242
|
input_file_type: InputFileType = InputFileType.Auto,
|
|
@@ -303,6 +308,7 @@ def generate(
|
|
|
303
308
|
union_mode: Optional[UnionMode] = None,
|
|
304
309
|
output_datetime_class: Optional[DatetimeClassType] = None,
|
|
305
310
|
keyword_only: bool = False,
|
|
311
|
+
no_alias: bool = False,
|
|
306
312
|
) -> None:
|
|
307
313
|
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
|
|
308
314
|
if isinstance(input_, str):
|
|
@@ -353,6 +359,8 @@ def generate(
|
|
|
353
359
|
parser_class = JsonSchemaParser
|
|
354
360
|
|
|
355
361
|
if input_file_type in RAW_DATA_TYPES:
|
|
362
|
+
import json
|
|
363
|
+
|
|
356
364
|
try:
|
|
357
365
|
if isinstance(input_, Path) and input_.is_dir(): # pragma: no cover
|
|
358
366
|
raise Error(f'Input must be a file for {input_file_type}')
|
|
@@ -371,15 +379,33 @@ def generate(
|
|
|
371
379
|
import io
|
|
372
380
|
|
|
373
381
|
obj = get_header_and_first_line(io.StringIO(input_text))
|
|
374
|
-
|
|
382
|
+
elif input_file_type == InputFileType.Yaml:
|
|
375
383
|
obj = load_yaml(
|
|
376
384
|
input_.read_text(encoding=encoding) # type: ignore
|
|
377
385
|
if isinstance(input_, Path)
|
|
378
386
|
else input_text
|
|
379
387
|
)
|
|
388
|
+
elif input_file_type == InputFileType.Json:
|
|
389
|
+
obj = json.loads(
|
|
390
|
+
input_.read_text(encoding=encoding) # type: ignore
|
|
391
|
+
if isinstance(input_, Path)
|
|
392
|
+
else input_text
|
|
393
|
+
)
|
|
394
|
+
elif input_file_type == InputFileType.Dict:
|
|
395
|
+
import ast
|
|
396
|
+
|
|
397
|
+
# Input can be a dict object stored in a python file
|
|
398
|
+
obj = (
|
|
399
|
+
ast.literal_eval(
|
|
400
|
+
input_.read_text(encoding=encoding) # type: ignore
|
|
401
|
+
)
|
|
402
|
+
if isinstance(input_, Path)
|
|
403
|
+
else input_
|
|
404
|
+
)
|
|
405
|
+
else: # pragma: no cover
|
|
406
|
+
raise Error(f'Unsupported input file type: {input_file_type}')
|
|
380
407
|
except: # noqa
|
|
381
408
|
raise Error('Invalid file format')
|
|
382
|
-
import json
|
|
383
409
|
|
|
384
410
|
from genson import SchemaBuilder
|
|
385
411
|
|
|
@@ -403,8 +429,10 @@ def generate(
|
|
|
403
429
|
data_model_types = get_data_model_types(
|
|
404
430
|
output_model_type, target_python_version, output_datetime_class
|
|
405
431
|
)
|
|
432
|
+
source = input_text or input_
|
|
433
|
+
assert not isinstance(source, Mapping)
|
|
406
434
|
parser = parser_class(
|
|
407
|
-
source=
|
|
435
|
+
source=source,
|
|
408
436
|
data_model_type=data_model_types.data_model,
|
|
409
437
|
data_model_root_type=data_model_types.root_model,
|
|
410
438
|
data_model_field_type=data_model_types.field_model,
|
|
@@ -478,6 +506,7 @@ def generate(
|
|
|
478
506
|
default_field_extras=default_field_extras,
|
|
479
507
|
target_datetime_class=output_datetime_class,
|
|
480
508
|
keyword_only=keyword_only,
|
|
509
|
+
no_alias=no_alias,
|
|
481
510
|
**kwargs,
|
|
482
511
|
)
|
|
483
512
|
|
|
@@ -488,7 +517,11 @@ def generate(
|
|
|
488
517
|
input_filename = '<stdin>'
|
|
489
518
|
elif isinstance(input_, ParseResult):
|
|
490
519
|
input_filename = input_.geturl()
|
|
520
|
+
elif input_file_type == InputFileType.Dict:
|
|
521
|
+
# input_ might be a dict object provided directly, and missing a name field
|
|
522
|
+
input_filename = getattr(input_, 'name', '<dict>')
|
|
491
523
|
else:
|
|
524
|
+
assert isinstance(input_, Path)
|
|
492
525
|
input_filename = input_.name
|
|
493
526
|
if not results:
|
|
494
527
|
raise Error('Models not found in the input data')
|
|
@@ -53,7 +53,6 @@ from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, nam
|
|
|
53
53
|
from datamodel_code_generator.format import (
|
|
54
54
|
DatetimeClassType,
|
|
55
55
|
PythonVersion,
|
|
56
|
-
black_find_project_root,
|
|
57
56
|
is_supported_in_black,
|
|
58
57
|
)
|
|
59
58
|
from datamodel_code_generator.parser import LiteralType
|
|
@@ -86,7 +85,7 @@ signal.signal(signal.SIGINT, sig_int_handler)
|
|
|
86
85
|
|
|
87
86
|
class Config(BaseModel):
|
|
88
87
|
if PYDANTIC_V2:
|
|
89
|
-
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
88
|
+
model_config = ConfigDict(arbitrary_types_allowed=True) # pyright: ignore [reportAssignmentType]
|
|
90
89
|
|
|
91
90
|
def get(self, item: str) -> Any:
|
|
92
91
|
return getattr(self, item)
|
|
@@ -186,8 +185,13 @@ class Config(BaseModel):
|
|
|
186
185
|
|
|
187
186
|
@model_validator(mode='after')
|
|
188
187
|
def validate_keyword_only(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
189
|
-
|
|
190
|
-
|
|
188
|
+
output_model_type: DataModelType = values.get('output_model_type') # pyright: ignore [reportAssignmentType]
|
|
189
|
+
python_target: PythonVersion = values.get('target_python_version') # pyright: ignore [reportAssignmentType]
|
|
190
|
+
if (
|
|
191
|
+
values.get('keyword_only')
|
|
192
|
+
and output_model_type == DataModelType.DataclassesDataclass
|
|
193
|
+
and not python_target.has_kw_only_dataclass
|
|
194
|
+
):
|
|
191
195
|
raise Error(
|
|
192
196
|
f'`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher.'
|
|
193
197
|
)
|
|
@@ -215,7 +219,7 @@ class Config(BaseModel):
|
|
|
215
219
|
def validate_each_item(each_item: Any) -> Tuple[str, str]:
|
|
216
220
|
if isinstance(each_item, str): # pragma: no cover
|
|
217
221
|
try:
|
|
218
|
-
field_name, field_value = each_item.split(':', maxsplit=1)
|
|
222
|
+
field_name, field_value = each_item.split(':', maxsplit=1)
|
|
219
223
|
return field_name, field_value.lstrip()
|
|
220
224
|
except ValueError:
|
|
221
225
|
raise Error(f'Invalid http header: {each_item!r}')
|
|
@@ -232,7 +236,7 @@ class Config(BaseModel):
|
|
|
232
236
|
def validate_each_item(each_item: Any) -> Tuple[str, str]:
|
|
233
237
|
if isinstance(each_item, str): # pragma: no cover
|
|
234
238
|
try:
|
|
235
|
-
field_name, field_value = each_item.split('=', maxsplit=1)
|
|
239
|
+
field_name, field_value = each_item.split('=', maxsplit=1)
|
|
236
240
|
return field_name, field_value.lstrip()
|
|
237
241
|
except ValueError:
|
|
238
242
|
raise Error(f'Invalid http query parameter: {each_item!r}')
|
|
@@ -244,14 +248,16 @@ class Config(BaseModel):
|
|
|
244
248
|
|
|
245
249
|
@model_validator(mode='before')
|
|
246
250
|
def validate_additional_imports(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
247
|
-
|
|
248
|
-
|
|
251
|
+
additional_imports = values.get('additional_imports')
|
|
252
|
+
if additional_imports is not None:
|
|
253
|
+
values['additional_imports'] = additional_imports.split(',')
|
|
249
254
|
return values
|
|
250
255
|
|
|
251
256
|
@model_validator(mode='before')
|
|
252
257
|
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
253
|
-
|
|
254
|
-
|
|
258
|
+
custom_formatters = values.get('custom_formatters')
|
|
259
|
+
if custom_formatters is not None:
|
|
260
|
+
values['custom_formatters'] = custom_formatters.split(',')
|
|
255
261
|
return values
|
|
256
262
|
|
|
257
263
|
if PYDANTIC_V2:
|
|
@@ -278,7 +284,7 @@ class Config(BaseModel):
|
|
|
278
284
|
disable_warnings: bool = False
|
|
279
285
|
target_python_version: PythonVersion = PythonVersion.PY_38
|
|
280
286
|
base_class: str = ''
|
|
281
|
-
additional_imports: Optional[List[str]] =
|
|
287
|
+
additional_imports: Optional[List[str]] = None
|
|
282
288
|
custom_template_dir: Optional[Path] = None
|
|
283
289
|
extra_template_data: Optional[TextIOBase] = None
|
|
284
290
|
validation: bool = False
|
|
@@ -341,6 +347,7 @@ class Config(BaseModel):
|
|
|
341
347
|
union_mode: Optional[UnionMode] = None
|
|
342
348
|
output_datetime_class: Optional[DatetimeClassType] = None
|
|
343
349
|
keyword_only: bool = False
|
|
350
|
+
no_alias: bool = False
|
|
344
351
|
|
|
345
352
|
def merge_args(self, args: Namespace) -> None:
|
|
346
353
|
set_args = {
|
|
@@ -360,6 +367,26 @@ class Config(BaseModel):
|
|
|
360
367
|
setattr(self, field_name, getattr(parsed_args, field_name))
|
|
361
368
|
|
|
362
369
|
|
|
370
|
+
def _get_pyproject_toml_config(source: Path) -> Optional[Dict[str, Any]]:
|
|
371
|
+
"""Find and return the [tool.datamodel-codgen] section of the closest
|
|
372
|
+
pyproject.toml if it exists.
|
|
373
|
+
"""
|
|
374
|
+
|
|
375
|
+
current_path = source
|
|
376
|
+
while current_path != current_path.parent:
|
|
377
|
+
if (current_path / 'pyproject.toml').is_file():
|
|
378
|
+
pyproject_toml = load_toml(current_path / 'pyproject.toml')
|
|
379
|
+
if 'datamodel-codegen' in pyproject_toml.get('tool', {}):
|
|
380
|
+
return pyproject_toml['tool']['datamodel-codegen']
|
|
381
|
+
|
|
382
|
+
if (current_path / '.git').exists():
|
|
383
|
+
# Stop early if we see a git repository root.
|
|
384
|
+
return None
|
|
385
|
+
|
|
386
|
+
current_path = current_path.parent
|
|
387
|
+
return None
|
|
388
|
+
|
|
389
|
+
|
|
363
390
|
def main(args: Optional[Sequence[str]] = None) -> Exit:
|
|
364
391
|
"""Main function."""
|
|
365
392
|
|
|
@@ -377,16 +404,9 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
|
|
377
404
|
print(version)
|
|
378
405
|
exit(0)
|
|
379
406
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
pyproject_toml: Dict[str, Any] = {
|
|
384
|
-
k.replace('-', '_'): v
|
|
385
|
-
for k, v in load_toml(pyproject_toml_path)
|
|
386
|
-
.get('tool', {})
|
|
387
|
-
.get('datamodel-codegen', {})
|
|
388
|
-
.items()
|
|
389
|
-
}
|
|
407
|
+
pyproject_config = _get_pyproject_toml_config(Path().resolve())
|
|
408
|
+
if pyproject_config is not None:
|
|
409
|
+
pyproject_toml = {k.replace('-', '_'): v for k, v in pyproject_config.items()}
|
|
390
410
|
else:
|
|
391
411
|
pyproject_toml = {}
|
|
392
412
|
|
|
@@ -542,6 +562,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
|
|
542
562
|
union_mode=config.union_mode,
|
|
543
563
|
output_datetime_class=config.output_datetime_class,
|
|
544
564
|
keyword_only=config.keyword_only,
|
|
565
|
+
no_alias=config.no_alias,
|
|
545
566
|
)
|
|
546
567
|
return Exit.OK
|
|
547
568
|
except InvalidClassNameError as e:
|
|
@@ -98,7 +98,7 @@ base_options.add_argument(
|
|
|
98
98
|
# ======================================================================================
|
|
99
99
|
model_options.add_argument(
|
|
100
100
|
'--allow-extra-fields',
|
|
101
|
-
help='Allow
|
|
101
|
+
help='Allow passing extra fields, if this flag is not passed, extra fields are forbidden.',
|
|
102
102
|
action='store_true',
|
|
103
103
|
default=None,
|
|
104
104
|
)
|
|
@@ -381,6 +381,13 @@ field_options.add_argument(
|
|
|
381
381
|
choices=[u.value for u in UnionMode],
|
|
382
382
|
default=None,
|
|
383
383
|
)
|
|
384
|
+
field_options.add_argument(
|
|
385
|
+
'--no-alias',
|
|
386
|
+
help="""Do not add a field alias. E.g., if --snake-case-field is used along with a base class, which has an
|
|
387
|
+
alias_generator""",
|
|
388
|
+
action='store_true',
|
|
389
|
+
default=None,
|
|
390
|
+
)
|
|
384
391
|
|
|
385
392
|
# ======================================================================================
|
|
386
393
|
# Options for templating output
|
datamodel_code_generator/http.py
CHANGED
|
@@ -14,7 +14,7 @@ class Import(BaseModel):
|
|
|
14
14
|
reference_path: Optional[str] = None
|
|
15
15
|
|
|
16
16
|
@classmethod
|
|
17
|
-
@lru_cache
|
|
17
|
+
@lru_cache
|
|
18
18
|
def from_full_path(cls, class_path: str) -> Import:
|
|
19
19
|
split_class_path: List[str] = class_path.split('.')
|
|
20
20
|
return Import(
|
|
@@ -43,7 +43,7 @@ class Imports(DefaultDict[Optional[str], Set[str]]):
|
|
|
43
43
|
|
|
44
44
|
def create_line(self, from_: Optional[str], imports: Set[str]) -> str:
|
|
45
45
|
if from_:
|
|
46
|
-
return f
|
|
46
|
+
return f'from {from_} import {", ".join(self._set_alias(from_, imports))}'
|
|
47
47
|
return '\n'.join(f'import {i}' for i in self._set_alias(from_, imports))
|
|
48
48
|
|
|
49
49
|
def dump(self) -> str:
|
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import sys
|
|
3
4
|
from typing import TYPE_CHECKING, Callable, Iterable, List, NamedTuple, Optional, Type
|
|
4
5
|
|
|
6
|
+
from .. import DatetimeClassType, PythonVersion
|
|
5
7
|
from ..types import DataTypeManager as DataTypeManagerABC
|
|
6
8
|
from .base import ConstraintsBase, DataModel, DataModelFieldBase
|
|
7
9
|
|
|
8
10
|
if TYPE_CHECKING:
|
|
9
|
-
from .. import DataModelType
|
|
11
|
+
from .. import DataModelType
|
|
12
|
+
|
|
13
|
+
DEFAULT_TARGET_DATETIME_CLASS = DatetimeClassType.Datetime
|
|
14
|
+
DEFAULT_TARGET_PYTHON_VERSION = PythonVersion(
|
|
15
|
+
f'{sys.version_info.major}.{sys.version_info.minor}'
|
|
16
|
+
)
|
|
10
17
|
|
|
11
18
|
|
|
12
19
|
class DataModelSet(NamedTuple):
|
|
@@ -20,13 +27,15 @@ class DataModelSet(NamedTuple):
|
|
|
20
27
|
|
|
21
28
|
def get_data_model_types(
|
|
22
29
|
data_model_type: DataModelType,
|
|
23
|
-
target_python_version: PythonVersion,
|
|
24
|
-
target_datetime_class: DatetimeClassType,
|
|
30
|
+
target_python_version: PythonVersion = DEFAULT_TARGET_PYTHON_VERSION,
|
|
31
|
+
target_datetime_class: Optional[DatetimeClassType] = None,
|
|
25
32
|
) -> DataModelSet:
|
|
26
33
|
from .. import DataModelType
|
|
27
34
|
from . import dataclass, msgspec, pydantic, pydantic_v2, rootmodel, typed_dict
|
|
28
35
|
from .types import DataTypeManager
|
|
29
36
|
|
|
37
|
+
if target_datetime_class is None:
|
|
38
|
+
target_datetime_class = DEFAULT_TARGET_DATETIME_CLASS
|
|
30
39
|
if data_model_type == DataModelType.PydanticBaseModel:
|
|
31
40
|
return DataModelSet(
|
|
32
41
|
data_model=pydantic.BaseModel,
|
|
@@ -53,13 +62,17 @@ def get_data_model_types(
|
|
|
53
62
|
)
|
|
54
63
|
elif data_model_type == DataModelType.TypingTypedDict:
|
|
55
64
|
return DataModelSet(
|
|
56
|
-
data_model=
|
|
57
|
-
|
|
58
|
-
|
|
65
|
+
data_model=(
|
|
66
|
+
typed_dict.TypedDict
|
|
67
|
+
if target_python_version.has_typed_dict
|
|
68
|
+
else typed_dict.TypedDictBackport
|
|
69
|
+
),
|
|
59
70
|
root_model=rootmodel.RootModel,
|
|
60
|
-
field_model=
|
|
61
|
-
|
|
62
|
-
|
|
71
|
+
field_model=(
|
|
72
|
+
typed_dict.DataModelField
|
|
73
|
+
if target_python_version.has_typed_dict_non_required
|
|
74
|
+
else typed_dict.DataModelFieldBackport
|
|
75
|
+
),
|
|
63
76
|
data_type_manager=DataTypeManager,
|
|
64
77
|
dump_resolve_reference_action=None,
|
|
65
78
|
)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from collections import defaultdict
|
|
3
|
+
from copy import deepcopy
|
|
3
4
|
from functools import lru_cache
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import (
|
|
@@ -52,7 +53,7 @@ class ConstraintsBase(_BaseModel):
|
|
|
52
53
|
unique_items: Optional[bool] = Field(None, alias='uniqueItems')
|
|
53
54
|
_exclude_fields: ClassVar[Set[str]] = {'has_constraints'}
|
|
54
55
|
if PYDANTIC_V2:
|
|
55
|
-
model_config = ConfigDict(
|
|
56
|
+
model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
|
|
56
57
|
arbitrary_types_allowed=True, ignored_types=(cached_property,)
|
|
57
58
|
)
|
|
58
59
|
else:
|
|
@@ -86,7 +87,9 @@ class ConstraintsBase(_BaseModel):
|
|
|
86
87
|
else:
|
|
87
88
|
model_field_constraints = {}
|
|
88
89
|
|
|
89
|
-
if
|
|
90
|
+
if constraints_class is None or not issubclass(
|
|
91
|
+
constraints_class, ConstraintsBase
|
|
92
|
+
): # pragma: no cover
|
|
90
93
|
return None
|
|
91
94
|
|
|
92
95
|
return constraints_class.parse_obj(
|
|
@@ -118,6 +121,7 @@ class DataModelFieldBase(_BaseModel):
|
|
|
118
121
|
_exclude_fields: ClassVar[Set[str]] = {'parent'}
|
|
119
122
|
_pass_fields: ClassVar[Set[str]] = {'parent', 'data_type'}
|
|
120
123
|
can_have_extra_keys: ClassVar[bool] = True
|
|
124
|
+
type_has_null: Optional[bool] = None
|
|
121
125
|
|
|
122
126
|
if not TYPE_CHECKING:
|
|
123
127
|
|
|
@@ -150,6 +154,8 @@ class DataModelFieldBase(_BaseModel):
|
|
|
150
154
|
return get_optional_type(type_hint, self.data_type.use_union_operator)
|
|
151
155
|
return type_hint
|
|
152
156
|
elif self.required:
|
|
157
|
+
if self.type_has_null:
|
|
158
|
+
return get_optional_type(type_hint, self.data_type.use_union_operator)
|
|
153
159
|
return type_hint
|
|
154
160
|
elif self.fall_back_to_nullable:
|
|
155
161
|
return get_optional_type(type_hint, self.data_type.use_union_operator)
|
|
@@ -161,7 +167,7 @@ class DataModelFieldBase(_BaseModel):
|
|
|
161
167
|
type_hint = self.type_hint
|
|
162
168
|
has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
|
|
163
169
|
imports: List[Union[Tuple[Import], Iterator[Import]]] = [
|
|
164
|
-
(
|
|
170
|
+
iter(
|
|
165
171
|
i
|
|
166
172
|
for i in self.data_type.all_imports
|
|
167
173
|
if not (not has_union and i == IMPORT_UNION)
|
|
@@ -225,7 +231,7 @@ class DataModelFieldBase(_BaseModel):
|
|
|
225
231
|
return True
|
|
226
232
|
|
|
227
233
|
|
|
228
|
-
@lru_cache
|
|
234
|
+
@lru_cache
|
|
229
235
|
def get_template(template_file_path: Path) -> Template:
|
|
230
236
|
loader = FileSystemLoader(str(TEMPLATE_DIR / template_file_path.parent))
|
|
231
237
|
environment: Environment = Environment(loader=loader)
|
|
@@ -247,7 +253,7 @@ def get_module_name(name: str, file_path: Optional[Path]) -> str:
|
|
|
247
253
|
|
|
248
254
|
|
|
249
255
|
class TemplateBase(ABC):
|
|
250
|
-
@
|
|
256
|
+
@cached_property
|
|
251
257
|
@abstractmethod
|
|
252
258
|
def template_file_path(self) -> Path:
|
|
253
259
|
raise NotImplementedError
|
|
@@ -316,6 +322,8 @@ class DataModel(TemplateBase, Nullable, ABC):
|
|
|
316
322
|
self.reference.source = self
|
|
317
323
|
|
|
318
324
|
self.extra_template_data = (
|
|
325
|
+
# The supplied defaultdict will either create a new entry,
|
|
326
|
+
# or already contain a predefined entry for this type
|
|
319
327
|
extra_template_data[self.name]
|
|
320
328
|
if extra_template_data is not None
|
|
321
329
|
else defaultdict(dict)
|
|
@@ -327,10 +335,12 @@ class DataModel(TemplateBase, Nullable, ABC):
|
|
|
327
335
|
if base_class.reference:
|
|
328
336
|
base_class.reference.children.append(self)
|
|
329
337
|
|
|
330
|
-
if extra_template_data:
|
|
338
|
+
if extra_template_data is not None:
|
|
331
339
|
all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
|
|
332
340
|
if all_model_extra_template_data:
|
|
333
|
-
|
|
341
|
+
# The deepcopy is needed here to ensure that different models don't
|
|
342
|
+
# end up inadvertently sharing state (such as "base_class_kwargs")
|
|
343
|
+
self.extra_template_data.update(deepcopy(all_model_extra_template_data))
|
|
334
344
|
|
|
335
345
|
self.methods: List[str] = methods or []
|
|
336
346
|
|
|
@@ -415,7 +425,7 @@ class DataModel(TemplateBase, Nullable, ABC):
|
|
|
415
425
|
def class_name(self, class_name: str) -> None:
|
|
416
426
|
if '.' in self.reference.name:
|
|
417
427
|
self.reference.name = (
|
|
418
|
-
f
|
|
428
|
+
f'{self.reference.name.rsplit(".", 1)[0]}.{class_name}'
|
|
419
429
|
)
|
|
420
430
|
else:
|
|
421
431
|
self.reference.name = class_name
|
|
@@ -82,10 +82,22 @@ class Enum(DataModel):
|
|
|
82
82
|
|
|
83
83
|
def find_member(self, value: Any) -> Optional[Member]:
|
|
84
84
|
repr_value = repr(value)
|
|
85
|
-
|
|
86
|
-
|
|
85
|
+
# Remove surrounding quotes from the string representation
|
|
86
|
+
str_value = str(value).strip('\'"')
|
|
87
|
+
|
|
88
|
+
for field in self.fields:
|
|
89
|
+
# Remove surrounding quotes from field default value
|
|
90
|
+
field_default = (field.default or '').strip('\'"')
|
|
91
|
+
|
|
92
|
+
# Compare values after removing quotes
|
|
93
|
+
if field_default == str_value:
|
|
87
94
|
return self.get_member(field)
|
|
88
|
-
|
|
95
|
+
|
|
96
|
+
# Keep original comparison for backwards compatibility
|
|
97
|
+
if field.default == repr_value: # pragma: no cover
|
|
98
|
+
return self.get_member(field)
|
|
99
|
+
|
|
100
|
+
return None
|
|
89
101
|
|
|
90
102
|
@property
|
|
91
103
|
def imports(self) -> Tuple[Import, ...]:
|
|
@@ -31,7 +31,6 @@ from datamodel_code_generator.model.imports import (
|
|
|
31
31
|
IMPORT_MSGSPEC_CONVERT,
|
|
32
32
|
IMPORT_MSGSPEC_FIELD,
|
|
33
33
|
IMPORT_MSGSPEC_META,
|
|
34
|
-
IMPORT_MSGSPEC_STRUCT,
|
|
35
34
|
)
|
|
36
35
|
from datamodel_code_generator.model.pydantic.base_model import (
|
|
37
36
|
Constraints as _Constraints,
|
|
@@ -88,7 +87,7 @@ class RootModel(_RootModel):
|
|
|
88
87
|
class Struct(DataModel):
|
|
89
88
|
TEMPLATE_FILE_PATH: ClassVar[str] = 'msgspec.jinja2'
|
|
90
89
|
BASE_CLASS: ClassVar[str] = 'msgspec.Struct'
|
|
91
|
-
DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = (
|
|
90
|
+
DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = ()
|
|
92
91
|
|
|
93
92
|
def __init__(
|
|
94
93
|
self,
|
|
@@ -123,6 +122,8 @@ class Struct(DataModel):
|
|
|
123
122
|
keyword_only=keyword_only,
|
|
124
123
|
)
|
|
125
124
|
self.extra_template_data.setdefault('base_class_kwargs', {})
|
|
125
|
+
if self.keyword_only:
|
|
126
|
+
self.add_base_class_kwarg('kw_only', 'True')
|
|
126
127
|
|
|
127
128
|
def add_base_class_kwarg(self, name: str, value):
|
|
128
129
|
self.extra_template_data['base_class_kwargs'][name] = value
|
|
@@ -322,4 +322,4 @@ class BaseModel(BaseModelBase):
|
|
|
322
322
|
if config_parameters:
|
|
323
323
|
from datamodel_code_generator.model.pydantic import Config
|
|
324
324
|
|
|
325
|
-
self.extra_template_data['config'] = Config.parse_obj(config_parameters)
|
|
325
|
+
self.extra_template_data['config'] = Config.parse_obj(config_parameters) # pyright: ignore [reportArgumentType]
|
|
@@ -180,7 +180,7 @@ class DataTypeManager(_DataTypeManager):
|
|
|
180
180
|
self.data_type,
|
|
181
181
|
strict_types=self.strict_types,
|
|
182
182
|
pattern_key=self.PATTERN_KEY,
|
|
183
|
-
target_datetime_class=target_datetime_class,
|
|
183
|
+
target_datetime_class=self.target_datetime_class,
|
|
184
184
|
)
|
|
185
185
|
self.strict_type_map: Dict[StrictTypes, DataType] = strict_type_map_factory(
|
|
186
186
|
self.data_type,
|
|
@@ -98,7 +98,7 @@ class DataModelField(DataModelFieldV1):
|
|
|
98
98
|
'max_length',
|
|
99
99
|
'union_mode',
|
|
100
100
|
}
|
|
101
|
-
constraints: Optional[Constraints] = None
|
|
101
|
+
constraints: Optional[Constraints] = None # pyright: ignore [reportIncompatibleVariableOverride]
|
|
102
102
|
_PARSE_METHOD: ClassVar[str] = 'model_validate'
|
|
103
103
|
can_have_extra_keys: ClassVar[bool] = False
|
|
104
104
|
|
|
@@ -234,7 +234,7 @@ class BaseModel(BaseModelBase):
|
|
|
234
234
|
if config_parameters:
|
|
235
235
|
from datamodel_code_generator.model.pydantic_v2 import ConfigDict
|
|
236
236
|
|
|
237
|
-
self.extra_template_data['config'] = ConfigDict.parse_obj(config_parameters)
|
|
237
|
+
self.extra_template_data['config'] = ConfigDict.parse_obj(config_parameters) # pyright: ignore [reportArgumentType]
|
|
238
238
|
self._additional_imports.append(IMPORT_CONFIG_DICT)
|
|
239
239
|
|
|
240
240
|
def _get_config_extra(self) -> Optional[Literal["'allow'", "'forbid'"]]:
|
|
@@ -24,7 +24,10 @@ class DataTypeManager(_DataTypeManager):
|
|
|
24
24
|
) -> Dict[Types, DataType]:
|
|
25
25
|
result = {
|
|
26
26
|
**super().type_map_factory(
|
|
27
|
-
data_type,
|
|
27
|
+
data_type,
|
|
28
|
+
strict_types,
|
|
29
|
+
pattern_key,
|
|
30
|
+
target_datetime_class or DatetimeClassType.Datetime,
|
|
28
31
|
),
|
|
29
32
|
Types.hostname: self.data_type.from_import(
|
|
30
33
|
IMPORT_CONSTR,
|