datamodel-code-generator 0.11.12__py3-none-any.whl → 0.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. datamodel_code_generator/__init__.py +654 -185
  2. datamodel_code_generator/__main__.py +872 -388
  3. datamodel_code_generator/arguments.py +798 -0
  4. datamodel_code_generator/cli_options.py +295 -0
  5. datamodel_code_generator/format.py +292 -54
  6. datamodel_code_generator/http.py +85 -10
  7. datamodel_code_generator/imports.py +152 -43
  8. datamodel_code_generator/model/__init__.py +138 -1
  9. datamodel_code_generator/model/base.py +531 -120
  10. datamodel_code_generator/model/dataclass.py +211 -0
  11. datamodel_code_generator/model/enum.py +133 -12
  12. datamodel_code_generator/model/imports.py +22 -0
  13. datamodel_code_generator/model/msgspec.py +462 -0
  14. datamodel_code_generator/model/pydantic/__init__.py +30 -25
  15. datamodel_code_generator/model/pydantic/base_model.py +304 -100
  16. datamodel_code_generator/model/pydantic/custom_root_type.py +11 -2
  17. datamodel_code_generator/model/pydantic/dataclass.py +15 -4
  18. datamodel_code_generator/model/pydantic/imports.py +40 -27
  19. datamodel_code_generator/model/pydantic/types.py +188 -96
  20. datamodel_code_generator/model/pydantic_v2/__init__.py +51 -0
  21. datamodel_code_generator/model/pydantic_v2/base_model.py +268 -0
  22. datamodel_code_generator/model/pydantic_v2/imports.py +15 -0
  23. datamodel_code_generator/model/pydantic_v2/root_model.py +35 -0
  24. datamodel_code_generator/model/pydantic_v2/types.py +143 -0
  25. datamodel_code_generator/model/scalar.py +124 -0
  26. datamodel_code_generator/model/template/Enum.jinja2 +15 -2
  27. datamodel_code_generator/model/template/ScalarTypeAliasAnnotation.jinja2 +6 -0
  28. datamodel_code_generator/model/template/ScalarTypeAliasType.jinja2 +6 -0
  29. datamodel_code_generator/model/template/ScalarTypeStatement.jinja2 +6 -0
  30. datamodel_code_generator/model/template/TypeAliasAnnotation.jinja2 +20 -0
  31. datamodel_code_generator/model/template/TypeAliasType.jinja2 +20 -0
  32. datamodel_code_generator/model/template/TypeStatement.jinja2 +20 -0
  33. datamodel_code_generator/model/template/TypedDict.jinja2 +5 -0
  34. datamodel_code_generator/model/template/TypedDictClass.jinja2 +25 -0
  35. datamodel_code_generator/model/template/TypedDictFunction.jinja2 +24 -0
  36. datamodel_code_generator/model/template/UnionTypeAliasAnnotation.jinja2 +10 -0
  37. datamodel_code_generator/model/template/UnionTypeAliasType.jinja2 +10 -0
  38. datamodel_code_generator/model/template/UnionTypeStatement.jinja2 +10 -0
  39. datamodel_code_generator/model/template/dataclass.jinja2 +50 -0
  40. datamodel_code_generator/model/template/msgspec.jinja2 +55 -0
  41. datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 +17 -4
  42. datamodel_code_generator/model/template/pydantic/BaseModel_root.jinja2 +12 -4
  43. datamodel_code_generator/model/template/pydantic/Config.jinja2 +1 -1
  44. datamodel_code_generator/model/template/pydantic/dataclass.jinja2 +15 -2
  45. datamodel_code_generator/model/template/pydantic_v2/BaseModel.jinja2 +57 -0
  46. datamodel_code_generator/model/template/pydantic_v2/ConfigDict.jinja2 +5 -0
  47. datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2 +48 -0
  48. datamodel_code_generator/model/type_alias.py +70 -0
  49. datamodel_code_generator/model/typed_dict.py +161 -0
  50. datamodel_code_generator/model/types.py +106 -0
  51. datamodel_code_generator/model/union.py +105 -0
  52. datamodel_code_generator/parser/__init__.py +30 -12
  53. datamodel_code_generator/parser/_graph.py +67 -0
  54. datamodel_code_generator/parser/_scc.py +171 -0
  55. datamodel_code_generator/parser/base.py +2426 -380
  56. datamodel_code_generator/parser/graphql.py +652 -0
  57. datamodel_code_generator/parser/jsonschema.py +2518 -647
  58. datamodel_code_generator/parser/openapi.py +631 -222
  59. datamodel_code_generator/py.typed +0 -0
  60. datamodel_code_generator/pydantic_patch.py +28 -0
  61. datamodel_code_generator/reference.py +672 -290
  62. datamodel_code_generator/types.py +521 -145
  63. datamodel_code_generator/util.py +155 -0
  64. datamodel_code_generator/watch.py +65 -0
  65. datamodel_code_generator-0.45.0.dist-info/METADATA +301 -0
  66. datamodel_code_generator-0.45.0.dist-info/RECORD +69 -0
  67. {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info}/WHEEL +1 -1
  68. datamodel_code_generator-0.45.0.dist-info/entry_points.txt +2 -0
  69. datamodel_code_generator/version.py +0 -1
  70. datamodel_code_generator-0.11.12.dist-info/METADATA +0 -440
  71. datamodel_code_generator-0.11.12.dist-info/RECORD +0 -31
  72. datamodel_code_generator-0.11.12.dist-info/entry_points.txt +0 -3
  73. {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,443 +1,858 @@
1
- #! /usr/bin/env python
1
+ """Main module for datamodel-code-generator CLI."""
2
2
 
3
- """
4
- Main function.
5
- """
3
+ from __future__ import annotations
6
4
 
5
+ import difflib
7
6
  import json
8
- import locale
7
+ import shlex
9
8
  import signal
10
9
  import sys
11
- from argparse import ArgumentParser, FileType, Namespace
10
+ import tempfile
11
+ import warnings
12
12
  from collections import defaultdict
13
+ from collections.abc import Sequence # noqa: TC003 # pydantic needs it
13
14
  from enum import IntEnum
14
15
  from io import TextIOBase
15
16
  from pathlib import Path
16
- from typing import (
17
- Any,
18
- DefaultDict,
19
- Dict,
20
- List,
21
- Optional,
22
- Sequence,
23
- Set,
24
- Tuple,
25
- Union,
26
- cast,
27
- )
17
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
28
18
  from urllib.parse import ParseResult, urlparse
29
19
 
30
20
  import argcomplete
31
- import black
32
- import toml
33
- from pydantic import BaseModel, root_validator, validator
21
+ from pydantic import BaseModel
22
+ from typing_extensions import TypeAlias
34
23
 
35
24
  from datamodel_code_generator import (
36
- DEFAULT_BASE_CLASS,
25
+ DEFAULT_SHARED_MODULE_NAME,
26
+ AllExportsCollisionStrategy,
27
+ AllExportsScope,
28
+ AllOfMergeMode,
29
+ DataclassArguments,
30
+ DataModelType,
37
31
  Error,
38
32
  InputFileType,
39
33
  InvalidClassNameError,
34
+ ModuleSplitMode,
40
35
  OpenAPIScope,
36
+ ReadOnlyWriteOnlyModelType,
37
+ ReuseScope,
41
38
  enable_debug_message,
42
39
  generate,
43
40
  )
44
- from datamodel_code_generator.format import PythonVersion, is_supported_in_black
45
- from datamodel_code_generator.parser import LiteralType
41
+ from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, namespace
42
+ from datamodel_code_generator.format import (
43
+ DEFAULT_FORMATTERS,
44
+ DatetimeClassType,
45
+ Formatter,
46
+ PythonVersion,
47
+ PythonVersionMin,
48
+ _get_black,
49
+ is_supported_in_black,
50
+ )
51
+ from datamodel_code_generator.model.pydantic_v2 import UnionMode # noqa: TC001 # needed for pydantic
52
+ from datamodel_code_generator.parser import LiteralType # noqa: TC001 # needed for pydantic
46
53
  from datamodel_code_generator.reference import is_url
47
- from datamodel_code_generator.types import StrictTypes
54
+ from datamodel_code_generator.types import StrictTypes # noqa: TC001 # needed for pydantic
55
+ from datamodel_code_generator.util import (
56
+ PYDANTIC_V2,
57
+ ConfigDict,
58
+ field_validator,
59
+ load_toml,
60
+ model_validator,
61
+ )
62
+
63
+ if TYPE_CHECKING:
64
+ from argparse import Namespace
65
+
66
+ from typing_extensions import Self
67
+
68
+ # Options that should be excluded from pyproject.toml config generation
69
+ EXCLUDED_CONFIG_OPTIONS: frozenset[str] = frozenset({
70
+ "check",
71
+ "generate_pyproject_config",
72
+ "generate_cli_command",
73
+ "ignore_pyproject",
74
+ "profile",
75
+ "version",
76
+ "help",
77
+ "debug",
78
+ "no_color",
79
+ "disable_warnings",
80
+ "watch",
81
+ "watch_delay",
82
+ })
83
+
84
+ BOOLEAN_OPTIONAL_OPTIONS: frozenset[str] = frozenset({
85
+ "use_specialized_enum",
86
+ })
48
87
 
49
88
 
50
89
  class Exit(IntEnum):
51
90
  """Exit reasons."""
52
91
 
53
92
  OK = 0
54
- ERROR = 1
55
- KeyboardInterrupt = 2
93
+ DIFF = 1
94
+ ERROR = 2
95
+ KeyboardInterrupt = 3
56
96
 
57
97
 
58
98
  def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover
59
- exit(Exit.OK)
99
+ """Handle SIGINT signal gracefully."""
100
+ sys.exit(Exit.OK)
60
101
 
61
102
 
62
103
  signal.signal(signal.SIGINT, sig_int_handler)
63
104
 
64
- DEFAULT_ENCODING = locale.getpreferredencoding()
65
-
66
- arg_parser = ArgumentParser()
67
- arg_parser.add_argument(
68
- '--input',
69
- help='Input file/directory (default: stdin)',
70
- )
71
- arg_parser.add_argument(
72
- '--url',
73
- help='Input file URL. `--input` is ignore when `--url` is used',
74
- )
75
-
76
- arg_parser.add_argument(
77
- '--http-headers',
78
- nargs='+',
79
- metavar='HTTP_HEADER',
80
- help='Set headers in HTTP requests to the remote host. (example: "Authorization: Basic dXNlcjpwYXNz")',
81
- )
82
- arg_parser.add_argument(
83
- '--input-file-type',
84
- help='Input file type (default: auto)',
85
- choices=[i.value for i in InputFileType],
86
- )
87
- arg_parser.add_argument(
88
- '--openapi-scopes',
89
- help='Scopes of OpenAPI model generation (default: schemas)',
90
- choices=[o.value for o in OpenAPIScope],
91
- nargs='+',
92
- default=[OpenAPIScope.Schemas.value],
93
- )
94
- arg_parser.add_argument('--output', help='Output file (default: stdout)')
95
-
96
- arg_parser.add_argument(
97
- '--base-class',
98
- help='Base Class (default: pydantic.BaseModel)',
99
- type=str,
100
- )
101
- arg_parser.add_argument(
102
- '--field-constraints',
103
- help='Use field constraints and not con* annotations',
104
- action='store_true',
105
- default=None,
106
- )
107
- arg_parser.add_argument(
108
- '--use-annotated',
109
- help='Use typing.Annotated for Field(). Also, `--field-constraints` option will be enabled.',
110
- action='store_true',
111
- default=None,
112
- )
113
- arg_parser.add_argument(
114
- '--field-extra-keys',
115
- help='Add extra keys to field parameters',
116
- type=str,
117
- nargs='+',
118
- )
119
- arg_parser.add_argument(
120
- '--field-include-all-keys',
121
- help='Add all keys to field parameters',
122
- action='store_true',
123
- default=None,
124
- )
125
- arg_parser.add_argument(
126
- '--snake-case-field',
127
- help='Change camel-case field name to snake-case',
128
- action='store_true',
129
- default=None,
130
- )
131
- arg_parser.add_argument(
132
- '--strip-default-none',
133
- help='Strip default None on fields',
134
- action='store_true',
135
- default=None,
136
- )
137
- arg_parser.add_argument(
138
- '--disable-appending-item-suffix',
139
- help='Disable appending `Item` suffix to model name in an array',
140
- action='store_true',
141
- default=None,
142
- )
143
- arg_parser.add_argument(
144
- '--allow-population-by-field-name',
145
- help='Allow population by field name',
146
- action='store_true',
147
- default=None,
148
- )
149
-
150
- arg_parser.add_argument(
151
- '--enable-faux-immutability',
152
- help='Enable faux immutability',
153
- action='store_true',
154
- default=None,
155
- )
156
-
157
- arg_parser.add_argument(
158
- '--use-default',
159
- help='Use default value even if a field is required',
160
- action='store_true',
161
- default=None,
162
- )
163
-
164
- arg_parser.add_argument(
165
- '--force-optional',
166
- help='Force optional for required fields',
167
- action='store_true',
168
- default=None,
169
- )
170
-
171
- arg_parser.add_argument(
172
- '--strict-nullable',
173
- help='Treat default field as a non-nullable field (Only OpenAPI)',
174
- action='store_true',
175
- default=None,
176
- )
177
-
178
- arg_parser.add_argument(
179
- '--strict-types',
180
- help='Use strict types',
181
- choices=[t.value for t in StrictTypes],
182
- nargs='+',
183
- )
184
-
185
- arg_parser.add_argument(
186
- '--disable-timestamp',
187
- help='Disable timestamp on file headers',
188
- action='store_true',
189
- default=None,
190
- )
191
-
192
- arg_parser.add_argument(
193
- '--use-standard-collections',
194
- help='Use standard collections for type hinting (list, dict)',
195
- action='store_true',
196
- default=None,
197
- )
198
-
199
- arg_parser.add_argument(
200
- '--use-generic-container-types',
201
- help='Use generic container types for type hinting (typing.Sequence, typing.Mapping). '
202
- 'If `--use-standard-collections` option is set, then import from collections.abc instead of typing',
203
- action='store_true',
204
- default=None,
205
- )
206
-
207
- arg_parser.add_argument(
208
- '--use-schema-description',
209
- help='Use schema description to populate class docstring',
210
- action='store_true',
211
- default=None,
212
- )
213
-
214
- arg_parser.add_argument(
215
- '--reuse-model',
216
- help='Re-use models on the field when a module has the model with the same content',
217
- action='store_true',
218
- default=None,
219
- )
220
-
221
- arg_parser.add_argument(
222
- '--enum-field-as-literal',
223
- help='Parse enum field as literal. '
224
- 'all: all enum field type are Literal. '
225
- 'one: field type is Literal when an enum has only one possible value',
226
- choices=[l.value for l in LiteralType],
227
- default=None,
228
- )
229
-
230
- arg_parser.add_argument(
231
- '--set-default-enum-member',
232
- help='Set enum members as default values for enum field',
233
- action='store_true',
234
- default=None,
235
- )
236
-
237
- arg_parser.add_argument(
238
- '--empty-enum-field-name',
239
- help='Set field name when enum value is empty (default: `_`)',
240
- default=None,
241
- )
242
105
 
106
+ class Config(BaseModel):
107
+ """Configuration model for code generation."""
243
108
 
244
- arg_parser.add_argument(
245
- '--class-name',
246
- help='Set class name of root model',
247
- default=None,
248
- )
109
+ if PYDANTIC_V2:
110
+ model_config = ConfigDict(arbitrary_types_allowed=True) # pyright: ignore[reportAssignmentType]
249
111
 
250
- arg_parser.add_argument(
251
- '--use-title-as-name',
252
- help='use titles as class names of models',
253
- action='store_true',
254
- default=None,
255
- )
112
+ def get(self, item: str) -> Any: # pragma: no cover
113
+ """Get attribute value by name."""
114
+ return getattr(self, item)
256
115
 
257
- arg_parser.add_argument(
258
- '--custom-template-dir', help='Custom template directory', type=str
259
- )
260
- arg_parser.add_argument(
261
- '--extra-template-data', help='Extra template data', type=FileType('rt')
262
- )
263
- arg_parser.add_argument('--aliases', help='Alias mapping file', type=FileType('rt'))
264
- arg_parser.add_argument(
265
- '--target-python-version',
266
- help='target python version (default: 3.7)',
267
- choices=[v.value for v in PythonVersion],
268
- )
116
+ def __getitem__(self, item: str) -> Any: # pragma: no cover
117
+ """Get item by key."""
118
+ return self.get(item)
269
119
 
270
- arg_parser.add_argument(
271
- '--wrap-string-literal',
272
- help='Wrap string literal by using black `experimental-string-processing` option (require black 20.8b0 or later)',
273
- action='store_true',
274
- default=None,
275
- )
120
+ @classmethod
121
+ def parse_obj(cls, obj: Any) -> Self:
122
+ """Parse object into Config model."""
123
+ return cls.model_validate(obj)
276
124
 
277
- arg_parser.add_argument(
278
- '--validation',
279
- help='Enable validation (Only OpenAPI)',
280
- action='store_true',
281
- default=None,
282
- )
125
+ @classmethod
126
+ def get_fields(cls) -> dict[str, Any]:
127
+ """Get model fields."""
128
+ return cls.model_fields
283
129
 
284
- arg_parser.add_argument(
285
- '--encoding',
286
- help=f'The encoding of input and output (default: {DEFAULT_ENCODING})',
287
- default=DEFAULT_ENCODING,
288
- )
130
+ else:
289
131
 
290
- arg_parser.add_argument(
291
- '--debug', help='show debug message', action='store_true', default=None
292
- )
293
- arg_parser.add_argument('--version', help='show version', action='store_true')
132
+ class Config:
133
+ """Pydantic v1 configuration."""
294
134
 
135
+ # Pydantic 1.5.1 doesn't support validate_assignment correctly
136
+ arbitrary_types_allowed = (TextIOBase,)
295
137
 
296
- class Config(BaseModel):
297
- class Config:
298
- validate_assignment = True
299
- arbitrary_types_allowed = (TextIOBase,)
138
+ @classmethod
139
+ def get_fields(cls) -> dict[str, Any]:
140
+ """Get model fields."""
141
+ return cls.__fields__
300
142
 
301
- @validator("aliases", "extra_template_data", pre=True)
302
- def validate_file(cls, value: Any) -> Optional[TextIOBase]:
303
- if value is None or isinstance(value, TextIOBase):
143
+ @field_validator("aliases", "extra_template_data", "custom_formatters_kwargs", mode="before")
144
+ def validate_file(cls, value: Any) -> TextIOBase | None: # noqa: N805
145
+ """Validate and open file path."""
146
+ if value is None: # pragma: no cover
304
147
  return value
305
- return cast(TextIOBase, Path(value).expanduser().resolve().open("rt"))
306
148
 
307
- @validator("input", "output", "custom_template_dir", pre=True)
308
- def validate_path(cls, value: Any) -> Optional[Path]:
149
+ path = Path(value)
150
+ if path.is_file():
151
+ return cast("TextIOBase", path.expanduser().resolve().open("rt"))
152
+
153
+ msg = f"A file was expected but {value} is not a file." # pragma: no cover
154
+ raise Error(msg) # pragma: no cover
155
+
156
+ @field_validator(
157
+ "input",
158
+ "output",
159
+ "custom_template_dir",
160
+ "custom_file_header_path",
161
+ mode="before",
162
+ )
163
+ def validate_path(cls, value: Any) -> Path | None: # noqa: N805
164
+ """Validate and resolve path."""
309
165
  if value is None or isinstance(value, Path):
310
166
  return value # pragma: no cover
311
167
  return Path(value).expanduser().resolve()
312
168
 
313
- @validator('url', pre=True)
314
- def validate_url(cls, value: Any) -> Optional[ParseResult]:
169
+ @field_validator("url", mode="before")
170
+ def validate_url(cls, value: Any) -> ParseResult | None: # noqa: N805
171
+ """Validate and parse URL."""
315
172
  if isinstance(value, str) and is_url(value): # pragma: no cover
316
173
  return urlparse(value)
317
- elif value is None: # pragma: no cover
174
+ if value is None: # pragma: no cover
318
175
  return None
319
- raise Error(
320
- f'This protocol doesn\'t support only http/https. --input={value}'
321
- ) # pragma: no cover
322
-
323
- @root_validator
324
- def validate_use_generic_container_types(
325
- cls, values: Dict[str, Any]
326
- ) -> Dict[str, Any]:
327
- if values.get('use_generic_container_types'):
328
- target_python_version: PythonVersion = values['target_python_version']
329
- if target_python_version == target_python_version.PY_36:
330
- raise Error(
331
- f"`--use-generic-container-types` can not be used with `--target-python_version` {target_python_version.PY_36.value}.\n" # type: ignore
332
- " The version will be not supported in a future version"
333
- )
176
+ msg = f"Unsupported URL scheme. Supported: http, https, file. --input={value}" # pragma: no cover
177
+ raise Error(msg) # pragma: no cover
178
+
179
+ # Pydantic 1.5.1 doesn't support each_item=True correctly
180
+ @field_validator("http_headers", mode="before")
181
+ def validate_http_headers(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
182
+ """Validate HTTP headers."""
183
+ if value is None: # pragma: no cover
184
+ return None
185
+
186
+ def validate_each_item(each_item: str | tuple[str, str]) -> tuple[str, str]:
187
+ if isinstance(each_item, str): # pragma: no cover
188
+ try:
189
+ field_name, field_value = each_item.split(":", maxsplit=1)
190
+ return field_name, field_value.lstrip()
191
+ except ValueError as exc:
192
+ msg = f"Invalid http header: {each_item!r}"
193
+ raise Error(msg) from exc
194
+ return each_item # pragma: no cover
195
+
196
+ if isinstance(value, list):
197
+ return [validate_each_item(each_item) for each_item in value]
198
+ msg = f"Invalid http_headers value: {value!r}" # pragma: no cover
199
+ raise Error(msg) # pragma: no cover
200
+
201
+ @field_validator("http_query_parameters", mode="before")
202
+ def validate_http_query_parameters(cls, value: Any) -> list[tuple[str, str]] | None: # noqa: N805
203
+ """Validate HTTP query parameters."""
204
+ if value is None: # pragma: no cover
205
+ return None
206
+
207
+ def validate_each_item(each_item: str | tuple[str, str]) -> tuple[str, str]:
208
+ if isinstance(each_item, str): # pragma: no cover
209
+ try:
210
+ field_name, field_value = each_item.split("=", maxsplit=1)
211
+ return field_name, field_value.lstrip()
212
+ except ValueError as exc:
213
+ msg = f"Invalid http query parameter: {each_item!r}"
214
+ raise Error(msg) from exc
215
+ return each_item # pragma: no cover
216
+
217
+ if isinstance(value, list):
218
+ return [validate_each_item(each_item) for each_item in value]
219
+ msg = f"Invalid http_query_parameters value: {value!r}" # pragma: no cover
220
+ raise Error(msg) # pragma: no cover
221
+
222
+ @model_validator(mode="before")
223
+ def validate_additional_imports(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
224
+ """Validate and split additional imports."""
225
+ additional_imports = values.get("additional_imports")
226
+ if additional_imports is not None:
227
+ values["additional_imports"] = additional_imports.split(",")
334
228
  return values
335
229
 
336
- @validator('http_headers', pre=True, each_item=True)
337
- def validate_http_headers(cls, value: Any) -> Optional[Tuple[str, str]]:
338
- if isinstance(value, str): # pragma: no cover
339
- try:
340
- field_name, field_value = value.split(':', maxsplit=1) # type: str, str
341
- return field_name, field_value.lstrip()
342
- except ValueError:
343
- raise Error(f'Invalid http header: {value!r}')
344
- return value # pragma: no cover
345
-
346
- @root_validator()
347
- def validate_root(cls, values: Dict[str, Any]) -> Dict[str, Any]:
348
- if values.get('use_annotated'):
349
- values['field_constraints'] = True
230
+ @model_validator(mode="before")
231
+ def validate_custom_formatters(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
232
+ """Validate and split custom formatters."""
233
+ custom_formatters = values.get("custom_formatters")
234
+ if custom_formatters is not None:
235
+ values["custom_formatters"] = custom_formatters.split(",")
350
236
  return values
351
237
 
352
- input: Optional[Union[Path, str]]
238
+ __validate_output_datetime_class_err: ClassVar[str] = (
239
+ '`--output-datetime-class` only allows "datetime" for '
240
+ f"`--output-model-type` {DataModelType.DataclassesDataclass.value}"
241
+ )
242
+
243
+ __validate_original_field_name_delimiter_err: ClassVar[str] = (
244
+ "`--original-field-name-delimiter` can not be used without `--snake-case-field`."
245
+ )
246
+
247
+ __validate_custom_file_header_err: ClassVar[str] = (
248
+ "`--custom_file_header_path` can not be used with `--custom_file_header`."
249
+ )
250
+ __validate_keyword_only_err: ClassVar[str] = (
251
+ f"`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher."
252
+ )
253
+
254
+ __validate_all_exports_collision_strategy_err: ClassVar[str] = (
255
+ "`--all-exports-collision-strategy` can only be used with `--all-exports-scope=recursive`."
256
+ )
257
+
258
+ if PYDANTIC_V2:
259
+
260
+ @model_validator() # pyright: ignore[reportArgumentType]
261
+ def validate_output_datetime_class(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
262
+ """Validate output datetime class compatibility."""
263
+ datetime_class_type: DatetimeClassType | None = self.output_datetime_class
264
+ if (
265
+ datetime_class_type
266
+ and datetime_class_type is not DatetimeClassType.Datetime
267
+ and self.output_model_type == DataModelType.DataclassesDataclass
268
+ ):
269
+ raise Error(self.__validate_output_datetime_class_err)
270
+ return self
271
+
272
+ @model_validator() # pyright: ignore[reportArgumentType]
273
+ def validate_original_field_name_delimiter(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
274
+ """Validate original field name delimiter requires snake case."""
275
+ if self.original_field_name_delimiter is not None and not self.snake_case_field:
276
+ raise Error(self.__validate_original_field_name_delimiter_err)
277
+ return self
278
+
279
+ @model_validator() # pyright: ignore[reportArgumentType]
280
+ def validate_custom_file_header(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
281
+ """Validate custom file header options are mutually exclusive."""
282
+ if self.custom_file_header and self.custom_file_header_path:
283
+ raise Error(self.__validate_custom_file_header_err)
284
+ return self
285
+
286
+ @model_validator() # pyright: ignore[reportArgumentType]
287
+ def validate_keyword_only(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
288
+ """Validate keyword-only compatibility with target Python version."""
289
+ output_model_type: DataModelType = self.output_model_type
290
+ python_target: PythonVersion = self.target_python_version
291
+ if (
292
+ self.keyword_only
293
+ and output_model_type == DataModelType.DataclassesDataclass
294
+ and not python_target.has_kw_only_dataclass
295
+ ):
296
+ raise Error(self.__validate_keyword_only_err)
297
+ return self
298
+
299
+ @model_validator() # pyright: ignore[reportArgumentType]
300
+ def validate_root(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
301
+ """Validate root model configuration."""
302
+ if self.use_annotated:
303
+ self.field_constraints = True
304
+ return self
305
+
306
+ @model_validator() # pyright: ignore[reportArgumentType]
307
+ def validate_all_exports_collision_strategy(self: Self) -> Self: # pyright: ignore[reportRedeclaration]
308
+ """Validate all_exports_collision_strategy requires recursive scope."""
309
+ if self.all_exports_collision_strategy is not None and self.all_exports_scope != AllExportsScope.Recursive:
310
+ raise Error(self.__validate_all_exports_collision_strategy_err)
311
+ return self
312
+
313
+ else:
314
+
315
+ @model_validator() # pyright: ignore[reportArgumentType]
316
+ def validate_output_datetime_class(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
317
+ """Validate output datetime class compatibility."""
318
+ datetime_class_type: DatetimeClassType | None = values.get("output_datetime_class")
319
+ if (
320
+ datetime_class_type
321
+ and datetime_class_type is not DatetimeClassType.Datetime
322
+ and values.get("output_model_type") == DataModelType.DataclassesDataclass
323
+ ):
324
+ raise Error(cls.__validate_output_datetime_class_err)
325
+ return values
326
+
327
+ @model_validator() # pyright: ignore[reportArgumentType]
328
+ def validate_original_field_name_delimiter(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
329
+ """Validate original field name delimiter requires snake case."""
330
+ if values.get("original_field_name_delimiter") is not None and not values.get("snake_case_field"):
331
+ raise Error(cls.__validate_original_field_name_delimiter_err)
332
+ return values
333
+
334
+ @model_validator() # pyright: ignore[reportArgumentType]
335
+ def validate_custom_file_header(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
336
+ """Validate custom file header options are mutually exclusive."""
337
+ if values.get("custom_file_header") and values.get("custom_file_header_path"):
338
+ raise Error(cls.__validate_custom_file_header_err)
339
+ return values
340
+
341
+ @model_validator() # pyright: ignore[reportArgumentType]
342
+ def validate_keyword_only(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
343
+ """Validate keyword-only compatibility with target Python version."""
344
+ output_model_type: DataModelType = cast("DataModelType", values.get("output_model_type"))
345
+ python_target: PythonVersion = cast("PythonVersion", values.get("target_python_version"))
346
+ if (
347
+ values.get("keyword_only")
348
+ and output_model_type == DataModelType.DataclassesDataclass
349
+ and not python_target.has_kw_only_dataclass
350
+ ):
351
+ raise Error(cls.__validate_keyword_only_err)
352
+ return values
353
+
354
+ @model_validator() # pyright: ignore[reportArgumentType]
355
+ def validate_root(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
356
+ """Validate root model configuration."""
357
+ if values.get("use_annotated"):
358
+ values["field_constraints"] = True
359
+ return values
360
+
361
+ @model_validator() # pyright: ignore[reportArgumentType]
362
+ def validate_all_exports_collision_strategy(cls, values: dict[str, Any]) -> dict[str, Any]: # noqa: N805
363
+ """Validate all_exports_collision_strategy requires recursive scope."""
364
+ if (
365
+ values.get("all_exports_collision_strategy") is not None
366
+ and values.get("all_exports_scope") != AllExportsScope.Recursive
367
+ ):
368
+ raise Error(cls.__validate_all_exports_collision_strategy_err)
369
+ return values
370
+
371
+ input: Optional[Union[Path, str]] = None # noqa: UP007, UP045
353
372
  input_file_type: InputFileType = InputFileType.Auto
354
- output: Optional[Path]
373
+ output_model_type: DataModelType = DataModelType.PydanticBaseModel
374
+ output: Optional[Path] = None # noqa: UP045
375
+ check: bool = False
355
376
  debug: bool = False
356
- target_python_version: PythonVersion = PythonVersion.PY_37
357
- base_class: str = DEFAULT_BASE_CLASS
358
- custom_template_dir: Optional[Path]
359
- extra_template_data: Optional[TextIOBase]
377
+ disable_warnings: bool = False
378
+ target_python_version: PythonVersion = PythonVersionMin
379
+ base_class: str = ""
380
+ additional_imports: Optional[list[str]] = None # noqa: UP045
381
+ custom_template_dir: Optional[Path] = None # noqa: UP045
382
+ extra_template_data: Optional[TextIOBase] = None # noqa: UP045
360
383
  validation: bool = False
361
384
  field_constraints: bool = False
362
385
  snake_case_field: bool = False
363
386
  strip_default_none: bool = False
364
- aliases: Optional[TextIOBase]
387
+ aliases: Optional[TextIOBase] = None # noqa: UP045
365
388
  disable_timestamp: bool = False
389
+ enable_version_header: bool = False
390
+ enable_command_header: bool = False
366
391
  allow_population_by_field_name: bool = False
392
+ allow_extra_fields: bool = False
393
+ extra_fields: Optional[str] = None # noqa: UP045
367
394
  use_default: bool = False
368
395
  force_optional: bool = False
369
- class_name: Optional[str] = None
396
+ class_name: Optional[str] = None # noqa: UP045
370
397
  use_standard_collections: bool = False
371
398
  use_schema_description: bool = False
399
+ use_field_description: bool = False
400
+ use_attribute_docstrings: bool = False
401
+ use_inline_field_description: bool = False
402
+ use_default_kwarg: bool = False
372
403
  reuse_model: bool = False
373
- encoding: str = 'utf-8'
374
- enum_field_as_literal: Optional[LiteralType] = None
404
+ reuse_scope: ReuseScope = ReuseScope.Module
405
+ shared_module_name: str = DEFAULT_SHARED_MODULE_NAME
406
+ encoding: str = DEFAULT_ENCODING
407
+ enum_field_as_literal: Optional[LiteralType] = None # noqa: UP045
408
+ use_one_literal_as_default: bool = False
409
+ use_enum_values_in_discriminator: bool = False
375
410
  set_default_enum_member: bool = False
411
+ use_subclass_enum: bool = False
412
+ use_specialized_enum: bool = True
376
413
  strict_nullable: bool = False
377
414
  use_generic_container_types: bool = False
415
+ use_union_operator: bool = False
378
416
  enable_faux_immutability: bool = False
379
- url: Optional[ParseResult] = None
417
+ url: Optional[ParseResult] = None # noqa: UP045
380
418
  disable_appending_item_suffix: bool = False
381
- strict_types: List[StrictTypes] = []
382
- empty_enum_field_name: Optional[str] = None
383
- field_extra_keys: Optional[Set[str]] = None
419
+ strict_types: list[StrictTypes] = []
420
+ empty_enum_field_name: Optional[str] = None # noqa: UP045
421
+ field_extra_keys: Optional[set[str]] = None # noqa: UP045
384
422
  field_include_all_keys: bool = False
385
- openapi_scopes: Optional[List[OpenAPIScope]] = None
386
- wrap_string_literal: Optional[bool] = None
423
+ field_extra_keys_without_x_prefix: Optional[set[str]] = None # noqa: UP045
424
+ openapi_scopes: Optional[list[OpenAPIScope]] = [OpenAPIScope.Schemas] # noqa: UP045
425
+ include_path_parameters: bool = False
426
+ wrap_string_literal: Optional[bool] = None # noqa: UP045
387
427
  use_title_as_name: bool = False
388
- http_headers: Optional[Sequence[Tuple[str, str]]] = None
428
+ use_operation_id_as_name: bool = False
429
+ use_unique_items_as_set: bool = False
430
+ allof_merge_mode: AllOfMergeMode = AllOfMergeMode.Constraints
431
+ http_headers: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
432
+ http_ignore_tls: bool = False
389
433
  use_annotated: bool = False
434
+ use_serialize_as_any: bool = False
435
+ use_non_positive_negative_number_constrained_types: bool = False
436
+ use_decimal_for_multiple_of: bool = False
437
+ original_field_name_delimiter: Optional[str] = None # noqa: UP045
438
+ use_double_quotes: bool = False
439
+ collapse_root_models: bool = False
440
+ skip_root_model: bool = False
441
+ use_type_alias: bool = False
442
+ special_field_name_prefix: Optional[str] = None # noqa: UP045
443
+ remove_special_field_name_prefix: bool = False
444
+ capitalise_enum_members: bool = False
445
+ keep_model_order: bool = False
446
+ custom_file_header: Optional[str] = None # noqa: UP045
447
+ custom_file_header_path: Optional[Path] = None # noqa: UP045
448
+ custom_formatters: Optional[list[str]] = None # noqa: UP045
449
+ custom_formatters_kwargs: Optional[TextIOBase] = None # noqa: UP045
450
+ use_pendulum: bool = False
451
+ http_query_parameters: Optional[Sequence[tuple[str, str]]] = None # noqa: UP045
452
+ treat_dot_as_module: bool = False
453
+ use_exact_imports: bool = False
454
+ union_mode: Optional[UnionMode] = None # noqa: UP045
455
+ output_datetime_class: Optional[DatetimeClassType] = None # noqa: UP045
456
+ keyword_only: bool = False
457
+ frozen_dataclasses: bool = False
458
+ dataclass_arguments: Optional[DataclassArguments] = None # noqa: UP045
459
+ no_alias: bool = False
460
+ use_frozen_field: bool = False
461
+ formatters: list[Formatter] = DEFAULT_FORMATTERS
462
+ parent_scoped_naming: bool = False
463
+ disable_future_imports: bool = False
464
+ type_mappings: Optional[list[str]] = None # noqa: UP045
465
+ read_only_write_only_model_type: Optional[ReadOnlyWriteOnlyModelType] = None # noqa: UP045
466
+ all_exports_scope: Optional[AllExportsScope] = None # noqa: UP045
467
+ all_exports_collision_strategy: Optional[AllExportsCollisionStrategy] = None # noqa: UP045
468
+ module_split_mode: Optional[ModuleSplitMode] = None # noqa: UP045
469
+ watch: bool = False
470
+ watch_delay: float = 0.5
390
471
 
391
472
  def merge_args(self, args: Namespace) -> None:
392
- for field_name in self.__fields__:
393
- arg = getattr(args, field_name)
394
- if arg is None:
395
- continue
396
- setattr(self, field_name, arg)
473
+ """Merge command-line arguments into config."""
474
+ set_args = {f: getattr(args, f) for f in self.get_fields() if getattr(args, f) is not None}
475
+
476
+ if set_args.get("output_model_type") == DataModelType.MsgspecStruct.value:
477
+ set_args["use_annotated"] = True
478
+
479
+ if set_args.get("use_annotated"):
480
+ set_args["field_constraints"] = True
481
+
482
+ parsed_args = Config.parse_obj(set_args)
483
+ for field_name in set_args:
484
+ setattr(self, field_name, getattr(parsed_args, field_name))
485
+
486
+
487
+ def _get_pyproject_toml_config(source: Path, profile: str | None = None) -> dict[str, Any]:
488
+ """Find and return the [tool.datamodel-codegen] section of the closest pyproject.toml if it exists."""
489
+ current_path = source
490
+ while current_path != current_path.parent:
491
+ if (current_path / "pyproject.toml").is_file():
492
+ pyproject_toml = load_toml(current_path / "pyproject.toml")
493
+ if "datamodel-codegen" in pyproject_toml.get("tool", {}):
494
+ tool_config = pyproject_toml["tool"]["datamodel-codegen"]
495
+
496
+ base_config: dict[str, Any] = {k: v for k, v in tool_config.items() if k != "profiles"}
497
+
498
+ if profile:
499
+ profiles = tool_config.get("profiles", {})
500
+ if profile not in profiles:
501
+ available = list(profiles.keys()) if profiles else "none"
502
+ msg = f"Profile '{profile}' not found in pyproject.toml. Available profiles: {available}"
503
+ raise Error(msg)
504
+ profile_config = profiles[profile]
505
+ base_config.update(profile_config)
506
+
507
+ pyproject_config = {k.replace("-", "_"): v for k, v in base_config.items()}
508
+ # Replace US-american spelling if present (ignore if british spelling is present)
509
+ if (
510
+ "capitalize_enum_members" in pyproject_config and "capitalise_enum_members" not in pyproject_config
511
+ ): # pragma: no cover
512
+ pyproject_config["capitalise_enum_members"] = pyproject_config.pop("capitalize_enum_members")
513
+ return pyproject_config
397
514
 
515
+ if (current_path / ".git").exists():
516
+ # Stop early if we see a git repository root.
517
+ break
398
518
 
399
- def main(args: Optional[Sequence[str]] = None) -> Exit:
400
- """Main function."""
519
+ current_path = current_path.parent
401
520
 
402
- # add cli completion support
521
+ # If profile was requested but no pyproject.toml config was found, raise an error
522
+ if profile:
523
+ msg = f"Profile '{profile}' requested but no [tool.datamodel-codegen] section found in pyproject.toml"
524
+ raise Error(msg)
525
+
526
+ return {}
527
+
528
+
529
+ TomlValue: TypeAlias = Union[str, bool, list["TomlValue"], tuple["TomlValue", ...]]
530
+
531
+
532
+ def _format_toml_value(value: TomlValue) -> str:
533
+ """Format a Python value as a TOML value string."""
534
+ if isinstance(value, bool):
535
+ return "true" if value else "false"
536
+ if isinstance(value, str):
537
+ return f'"{value}"'
538
+ formatted_items = [_format_toml_value(item) for item in value]
539
+ return f"[{', '.join(formatted_items)}]"
540
+
541
+
542
def generate_pyproject_config(args: Namespace) -> str:
    """Render parsed CLI arguments as a ``[tool.datamodel-codegen]`` TOML section.

    Attributes whose value is ``None``, or whose name is in ``EXCLUDED_CONFIG_OPTIONS``, are
    left out. The remaining attributes are written one per line, sorted by name, with
    underscores in the name replaced by dashes and values formatted by
    ``_format_toml_value``. The returned text ends with a newline.
    """
    kept_settings = (
        (name, setting)
        for name, setting in sorted(vars(args).items())
        if setting is not None and name not in EXCLUDED_CONFIG_OPTIONS
    )
    body = [
        f"{name.replace('_', '-')} = {_format_toml_value(cast('TomlValue', setting))}"
        for name, setting in kept_settings
    ]
    return "\n".join(["[tool.datamodel-codegen]", *body]) + "\n"
558
+
559
+
560
def _normalize_line_endings(text: str) -> str:
    """Normalize line endings to LF for cross-platform comparison.

    Windows (``\\r\\n``) endings are converted first, so they collapse to a single ``\\n``.
    Any remaining lone ``\\r`` (classic Mac endings) is converted afterwards.
    """
    return text.replace("\r\n", "\n").replace("\r", "\n")
563
+
564
+
565
def _compare_single_file(
    generated_path: Path,
    actual_path: Path,
    encoding: str,
) -> tuple[bool, list[str]]:
    """Check whether the file at *actual_path* matches the freshly generated output.

    Both files are read with *encoding* and passed through ``_normalize_line_endings``
    before comparison. The generated file is always read first.

    Returns:
        ``(has_differences, diff_lines)``:
        - ``has_differences`` is True when the contents differ or *actual_path* does not exist;
        - ``diff_lines`` is one of:
          - a single ``MISSING: ...`` message when *actual_path* does not exist;
          - a unified diff from the existing file to the expected content;
          - an empty list when the contents match.
    """
    expected_text = _normalize_line_endings(generated_path.read_text(encoding=encoding))

    if not actual_path.exists():
        missing_note = f"MISSING: {actual_path} (file does not exist but should be generated)"
        return True, [missing_note]

    current_text = _normalize_line_endings(actual_path.read_text(encoding=encoding))
    if current_text == expected_text:
        return False, []

    patch = difflib.unified_diff(
        current_text.splitlines(keepends=True),
        expected_text.splitlines(keepends=True),
        fromfile=str(actual_path),
        tofile=f"{actual_path} (expected)",
    )
    return True, [*patch]
596
+
597
+
598
def _compare_directories(
    generated_dir: Path,
    actual_dir: Path,
    encoding: str,
) -> tuple[list[str], list[str], list[str]]:
    """Compare a freshly generated output directory with the existing one.

    Only ``*.py`` files (found recursively) are considered, keyed by their path relative to
    each root. In *actual_dir*, paths containing a ``__pycache__`` component are ignored, and
    a missing *actual_dir* is treated as empty.

    Returns:
        ``(diffs, missing_files, extra_files)``:
        - ``diffs`` holds unified-diff lines for files present in both trees whose
          ``_normalize_line_endings``-normalized contents differ, ordered by relative path;
        - ``missing_files`` lists, sorted, the relative paths that were generated but are
          absent on disk;
        - ``extra_files`` lists, sorted, the relative paths on disk that are no longer generated.
    """
    produced = {candidate.relative_to(generated_dir) for candidate in generated_dir.rglob("*.py")}

    on_disk: set[Path] = set()
    if actual_dir.exists():
        on_disk = {
            candidate.relative_to(actual_dir)
            for candidate in actual_dir.rglob("*.py")
            if "__pycache__" not in candidate.parts
        }

    diffs: list[str] = []
    for rel_path in sorted(produced & on_disk):
        expected_text = _normalize_line_endings((generated_dir / rel_path).read_text(encoding=encoding))
        current_text = _normalize_line_endings((actual_dir / rel_path).read_text(encoding=encoding))
        if expected_text == current_text:
            continue
        diffs += difflib.unified_diff(
            current_text.splitlines(keepends=True),
            expected_text.splitlines(keepends=True),
            fromfile=str(rel_path),
            tofile=f"{rel_path} (expected)",
        )

    missing_files = list(map(str, sorted(produced - on_disk)))
    extra_files = list(map(str, sorted(on_disk - produced)))
    return diffs, missing_files, extra_files
631
+
632
+
633
+ def _format_cli_value(value: str | list[str]) -> str:
634
+ """Format a value for CLI argument."""
635
+ if isinstance(value, list):
636
+ return " ".join(f'"{v}"' if " " in v else v for v in value)
637
+ return f'"{value}"' if " " in value else value
638
+
639
+
640
def generate_cli_command(config: dict[str, TomlValue]) -> str:
    """Build a ``datamodel-codegen`` command line equivalent to a pyproject.toml config.

    Options are emitted sorted by key; keys in ``EXCLUDED_CONFIG_OPTIONS`` are skipped, and
    underscores in keys become dashes in the flag name. How each value is rendered:

    - ``True`` → ``--flag``.
    - ``False`` → ``--no-flag``, but only for keys in ``BOOLEAN_OPTIONAL_OPTIONS``;
      otherwise the option is omitted.
    - list → ``--flag`` followed by the items formatted by ``_format_cli_value``.
    - anything else → ``--flag`` followed by ``str(value)`` formatted by ``_format_cli_value``.

    The returned command ends with a newline.
    """
    command = ["datamodel-codegen"]
    for option, setting in sorted(config.items()):
        if option in EXCLUDED_CONFIG_OPTIONS:
            continue
        dashed = option.replace("_", "-")

        if isinstance(setting, bool):
            if setting:
                command.append("--" + dashed)
            elif option in BOOLEAN_OPTIONAL_OPTIONS:
                command.append("--no-" + dashed)
            continue

        if isinstance(setting, list):
            rendered = _format_cli_value(cast("list[str]", setting))
        else:
            rendered = _format_cli_value(str(setting))
        command += ["--" + dashed, rendered]

    return " ".join(command) + "\n"
661
+
662
+
663
def run_generate_from_config(  # noqa: PLR0913, PLR0917
    config: Config,
    input_: Path | str | ParseResult,
    output: Path | None,
    extra_template_data: dict[str, Any] | None,
    aliases: dict[str, str] | None,
    command_line: str | None,
    custom_formatters_kwargs: dict[str, str] | None,
    settings_path: Path | None = None,
) -> None:
    """Call ``generate()`` with every generation option taken from *config*.

    The exceptions are the values the caller has already resolved, which are forwarded
    as given:
    - ``input_`` and ``output``;
    - ``extra_template_data``, ``aliases`` and ``custom_formatters_kwargs`` (the loaded
      mappings rather than the raw config file handles);
    - ``command_line`` and ``settings_path``.

    Two config fields are renamed on the way through:
    - ``config.use_default`` → ``apply_default_values_for_required_fields``;
    - ``config.force_optional`` → ``force_optional_for_required_fields``.

    The keyword arguments below are grouped by theme for readability only; the grouping
    has no effect on ``generate()``.
    """
    generate(
        # -- input, output and target environment
        input_=input_,
        input_file_type=config.input_file_type,
        output=output,
        output_model_type=config.output_model_type,
        target_python_version=config.target_python_version,
        encoding=config.encoding,
        treat_dot_as_module=config.treat_dot_as_module,
        module_split_mode=config.module_split_mode,
        http_headers=config.http_headers,
        http_ignore_tls=config.http_ignore_tls,
        http_query_parameters=config.http_query_parameters,
        openapi_scopes=config.openapi_scopes,
        include_path_parameters=config.include_path_parameters,
        read_only_write_only_model_type=config.read_only_write_only_model_type,
        # -- templates, file headers, imports and formatting
        base_class=config.base_class,
        additional_imports=config.additional_imports,
        custom_template_dir=config.custom_template_dir,
        extra_template_data=extra_template_data,  # pyright: ignore[reportArgumentType]
        disable_timestamp=config.disable_timestamp,
        enable_version_header=config.enable_version_header,
        enable_command_header=config.enable_command_header,
        command_line=command_line,
        custom_file_header=config.custom_file_header,
        custom_file_header_path=config.custom_file_header_path,
        custom_formatters=config.custom_formatters,
        custom_formatters_kwargs=custom_formatters_kwargs,
        formatters=config.formatters,
        settings_path=settings_path,
        use_double_quotes=config.use_double_quotes,
        wrap_string_literal=config.wrap_string_literal,
        disable_future_imports=config.disable_future_imports,
        use_exact_imports=config.use_exact_imports,
        all_exports_scope=config.all_exports_scope,
        all_exports_collision_strategy=config.all_exports_collision_strategy,
        keep_model_order=config.keep_model_order,
        # -- naming
        class_name=config.class_name,
        snake_case_field=config.snake_case_field,
        aliases=aliases,
        no_alias=config.no_alias,
        original_field_name_delimiter=config.original_field_name_delimiter,
        special_field_name_prefix=config.special_field_name_prefix,
        remove_special_field_name_prefix=config.remove_special_field_name_prefix,
        empty_enum_field_name=config.empty_enum_field_name,
        capitalise_enum_members=config.capitalise_enum_members,
        use_title_as_name=config.use_title_as_name,
        use_operation_id_as_name=config.use_operation_id_as_name,
        parent_scoped_naming=config.parent_scoped_naming,
        disable_appending_item_suffix=config.disable_appending_item_suffix,
        # -- model shape, reuse and field defaults
        reuse_model=config.reuse_model,
        reuse_scope=config.reuse_scope,
        shared_module_name=config.shared_module_name,
        collapse_root_models=config.collapse_root_models,
        skip_root_model=config.skip_root_model,
        use_type_alias=config.use_type_alias,
        allof_merge_mode=config.allof_merge_mode,
        allow_population_by_field_name=config.allow_population_by_field_name,
        allow_extra_fields=config.allow_extra_fields,
        extra_fields=config.extra_fields,
        enable_faux_immutability=config.enable_faux_immutability,
        keyword_only=config.keyword_only,
        frozen_dataclasses=config.frozen_dataclasses,
        use_frozen_field=config.use_frozen_field,
        dataclass_arguments=config.dataclass_arguments,
        use_default_kwarg=config.use_default_kwarg,
        apply_default_values_for_required_fields=config.use_default,
        force_optional_for_required_fields=config.force_optional,
        strip_default_none=config.strip_default_none,
        strict_nullable=config.strict_nullable,
        use_annotated=config.use_annotated,
        use_serialize_as_any=config.use_serialize_as_any,
        union_mode=config.union_mode,
        # -- field metadata and constraints
        validation=config.validation,
        field_constraints=config.field_constraints,
        use_schema_description=config.use_schema_description,
        use_field_description=config.use_field_description,
        use_attribute_docstrings=config.use_attribute_docstrings,
        use_inline_field_description=config.use_inline_field_description,
        field_extra_keys=config.field_extra_keys,
        field_include_all_keys=config.field_include_all_keys,
        field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
        use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
        use_decimal_for_multiple_of=config.use_decimal_for_multiple_of,
        # -- type mapping
        strict_types=config.strict_types,
        use_standard_collections=config.use_standard_collections,
        use_generic_container_types=config.use_generic_container_types,
        use_union_operator=config.use_union_operator,
        use_unique_items_as_set=config.use_unique_items_as_set,
        use_pendulum=config.use_pendulum,
        output_datetime_class=config.output_datetime_class,
        type_mappings=config.type_mappings,
        # -- enums
        enum_field_as_literal=config.enum_field_as_literal,
        use_one_literal_as_default=config.use_one_literal_as_default,
        use_enum_values_in_discriminator=config.use_enum_values_in_discriminator,
        set_default_enum_member=config.set_default_enum_member,
        use_subclass_enum=config.use_subclass_enum,
        use_specialized_enum=config.use_specialized_enum,
    )
772
+
773
+
774
+ def main(args: Sequence[str] | None = None) -> Exit: # noqa: PLR0911, PLR0912, PLR0914, PLR0915
775
+ """Execute datamodel code generation from command-line arguments."""
403
776
  argcomplete.autocomplete(arg_parser)
404
777
 
405
- if args is None:
778
+ if args is None: # pragma: no cover
406
779
  args = sys.argv[1:]
407
780
 
408
- namespace: Namespace = arg_parser.parse_args(args)
781
+ arg_parser.parse_args(args, namespace=namespace)
409
782
 
410
783
  if namespace.version:
411
- from datamodel_code_generator.version import version
412
-
413
- print(version)
414
- exit(0)
415
-
416
- root = black.find_project_root((Path().resolve(),))
417
- pyproject_toml_path = root / "pyproject.toml"
418
- if pyproject_toml_path.is_file():
419
- pyproject_toml: Dict[str, Any] = {
420
- k.replace('-', '_'): v
421
- for k, v in toml.load(str(pyproject_toml_path))
422
- .get('tool', {})
423
- .get('datamodel-codegen', {})
424
- .items()
425
- }
784
+ from datamodel_code_generator import get_version # noqa: PLC0415
785
+
786
+ print(get_version()) # noqa: T201
787
+ sys.exit(0)
788
+
789
+ if namespace.generate_pyproject_config:
790
+ config_output = generate_pyproject_config(namespace)
791
+ print(config_output) # noqa: T201
792
+ return Exit.OK
793
+
794
+ # Handle --ignore-pyproject and --profile options
795
+ if namespace.ignore_pyproject:
796
+ pyproject_config: dict[str, Any] = {}
426
797
  else:
427
- pyproject_toml = {}
798
+ try:
799
+ pyproject_config = _get_pyproject_toml_config(Path.cwd(), profile=namespace.profile)
800
+ except Error as e:
801
+ print(e.message, file=sys.stderr) # noqa: T201
802
+ return Exit.ERROR
803
+
804
+ if namespace.generate_cli_command:
805
+ if not pyproject_config:
806
+ print( # noqa: T201
807
+ "No [tool.datamodel-codegen] section found in pyproject.toml",
808
+ file=sys.stderr,
809
+ )
810
+ return Exit.ERROR
811
+ command_output = generate_cli_command(pyproject_config)
812
+ print(command_output) # noqa: T201
813
+ return Exit.OK
428
814
 
429
815
  try:
430
- config = Config.parse_obj(pyproject_toml)
816
+ config = Config.parse_obj(pyproject_config)
431
817
  config.merge_args(namespace)
432
818
  except Error as e:
433
- print(e.message, file=sys.stderr)
819
+ print(e.message, file=sys.stderr) # noqa: T201
820
+ return Exit.ERROR
821
+
822
+ if not config.input and not config.url and sys.stdin.isatty():
823
+ print( # noqa: T201
824
+ "Not Found Input: require `stdin` or arguments `--input` or `--url`",
825
+ file=sys.stderr,
826
+ )
827
+ arg_parser.print_help()
828
+ return Exit.ERROR
829
+
830
+ if config.check and config.output is None:
831
+ print( # noqa: T201
832
+ "Error: --check cannot be used with stdout output (no --output specified)",
833
+ file=sys.stderr,
834
+ )
835
+ return Exit.ERROR
836
+
837
+ if config.watch and config.check:
838
+ print( # noqa: T201
839
+ "Error: --watch and --check cannot be used together",
840
+ file=sys.stderr,
841
+ )
842
+ return Exit.ERROR
843
+
844
+ if config.watch and (config.input is None or is_url(str(config.input))):
845
+ print( # noqa: T201
846
+ "Error: --watch requires --input file path (not URL or stdin)",
847
+ file=sys.stderr,
848
+ )
434
849
  return Exit.ERROR
435
850
 
436
851
  if not is_supported_in_black(config.target_python_version): # pragma: no cover
437
- print(
438
- f"Installed black doesn't support Python version {config.target_python_version.value}.\n" # type: ignore
852
+ print( # noqa: T201
853
+ f"Installed black doesn't support Python version {config.target_python_version.value}.\n"
439
854
  f"You have to install a newer black.\n"
440
- f"Installed black version: {black.__version__}",
855
+ f"Installed black version: {_get_black().__version__}",
441
856
  file=sys.stderr,
442
857
  )
443
858
  return Exit.ERROR
@@ -445,17 +860,38 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
445
860
  if config.debug: # pragma: no cover
446
861
  enable_debug_message()
447
862
 
448
- extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]]
863
+ if config.disable_warnings:
864
+ warnings.simplefilter("ignore")
865
+
866
+ if config.reuse_scope == ReuseScope.Tree and not config.reuse_model:
867
+ print( # noqa: T201
868
+ "Warning: --reuse-scope=tree has no effect without --reuse-model",
869
+ file=sys.stderr,
870
+ )
871
+
872
+ if (
873
+ config.use_specialized_enum
874
+ and namespace.use_specialized_enum is not False # CLI didn't disable it
875
+ and (namespace.use_specialized_enum is True or pyproject_config.get("use_specialized_enum") is True)
876
+ and not config.target_python_version.has_strenum
877
+ ):
878
+ print( # noqa: T201
879
+ f"Error: --use-specialized-enum requires --target-python-version 3.11 or later.\n"
880
+ f"Current target version: {config.target_python_version.value}\n"
881
+ f"StrEnum is only available in Python 3.11+.",
882
+ file=sys.stderr,
883
+ )
884
+ return Exit.ERROR
885
+
886
+ extra_template_data: defaultdict[str, dict[str, Any]] | None
449
887
  if config.extra_template_data is None:
450
888
  extra_template_data = None
451
889
  else:
452
890
  with config.extra_template_data as data:
453
891
  try:
454
- extra_template_data = json.load(
455
- data, object_hook=lambda d: defaultdict(dict, **d)
456
- )
892
+ extra_template_data = json.load(data, object_hook=lambda d: defaultdict(dict, **d))
457
893
  except json.JSONDecodeError as e:
458
- print(f"Unable to load extra template data: {e}", file=sys.stderr)
894
+ print(f"Unable to load extra template data: {e}", file=sys.stderr) # noqa: T201
459
895
  return Exit.ERROR
460
896
 
461
897
  if config.aliases is None:
@@ -465,69 +901,117 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
465
901
  try:
466
902
  aliases = json.load(data)
467
903
  except json.JSONDecodeError as e:
468
- print(f"Unable to load alias mapping: {e}", file=sys.stderr)
904
+ print(f"Unable to load alias mapping: {e}", file=sys.stderr) # noqa: T201
469
905
  return Exit.ERROR
470
906
  if not isinstance(aliases, dict) or not all(
471
907
  isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
472
908
  ):
473
- print(
909
+ print( # noqa: T201
474
910
  'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
475
911
  file=sys.stderr,
476
912
  )
477
913
  return Exit.ERROR
478
914
 
915
+ if config.custom_formatters_kwargs is None:
916
+ custom_formatters_kwargs = None
917
+ else:
918
+ with config.custom_formatters_kwargs as data:
919
+ try:
920
+ custom_formatters_kwargs = json.load(data)
921
+ except json.JSONDecodeError as e: # pragma: no cover
922
+ print( # noqa: T201
923
+ f"Unable to load custom_formatters_kwargs mapping: {e}",
924
+ file=sys.stderr,
925
+ )
926
+ return Exit.ERROR
927
+ if not isinstance(custom_formatters_kwargs, dict) or not all(
928
+ isinstance(k, str) and isinstance(v, str) for k, v in custom_formatters_kwargs.items()
929
+ ): # pragma: no cover
930
+ print( # noqa: T201
931
+ 'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
932
+ file=sys.stderr,
933
+ )
934
+ return Exit.ERROR
935
+
936
+ if config.check:
937
+ config_output = cast("Path", config.output)
938
+ is_directory_output = not config_output.suffix
939
+ temp_context: tempfile.TemporaryDirectory[str] | None = tempfile.TemporaryDirectory()
940
+ temp_dir = Path(temp_context.name)
941
+ if is_directory_output:
942
+ generate_output: Path | None = temp_dir / config_output.name
943
+ else:
944
+ generate_output = temp_dir / "output.py"
945
+ else:
946
+ temp_context = None
947
+ generate_output = config.output
948
+ is_directory_output = False
949
+
479
950
  try:
480
- generate(
951
+ run_generate_from_config(
952
+ config=config,
481
953
  input_=config.url or config.input or sys.stdin.read(),
482
- input_file_type=config.input_file_type,
483
- output=config.output,
484
- target_python_version=config.target_python_version,
485
- base_class=config.base_class,
486
- custom_template_dir=config.custom_template_dir,
487
- validation=config.validation,
488
- field_constraints=config.field_constraints,
489
- snake_case_field=config.snake_case_field,
490
- strip_default_none=config.strip_default_none,
954
+ output=generate_output,
491
955
  extra_template_data=extra_template_data,
492
956
  aliases=aliases,
493
- disable_timestamp=config.disable_timestamp,
494
- allow_population_by_field_name=config.allow_population_by_field_name,
495
- apply_default_values_for_required_fields=config.use_default,
496
- force_optional_for_required_fields=config.force_optional,
497
- class_name=config.class_name,
498
- use_standard_collections=config.use_standard_collections,
499
- use_schema_description=config.use_schema_description,
500
- reuse_model=config.reuse_model,
501
- encoding=config.encoding,
502
- enum_field_as_literal=config.enum_field_as_literal,
503
- set_default_enum_member=config.set_default_enum_member,
504
- strict_nullable=config.strict_nullable,
505
- use_generic_container_types=config.use_generic_container_types,
506
- enable_faux_immutability=config.enable_faux_immutability,
507
- disable_appending_item_suffix=config.disable_appending_item_suffix,
508
- strict_types=config.strict_types,
509
- empty_enum_field_name=config.empty_enum_field_name,
510
- field_extra_keys=config.field_extra_keys,
511
- field_include_all_keys=config.field_include_all_keys,
512
- openapi_scopes=config.openapi_scopes,
513
- wrap_string_literal=config.wrap_string_literal,
514
- use_title_as_name=config.use_title_as_name,
515
- http_headers=config.http_headers,
516
- use_annotated=config.use_annotated,
957
+ command_line=shlex.join(["datamodel-codegen", *args]) if config.enable_command_header else None,
958
+ custom_formatters_kwargs=custom_formatters_kwargs,
959
+ settings_path=config.output if config.check else None,
517
960
  )
518
- return Exit.OK
519
961
  except InvalidClassNameError as e:
520
- print(f'{e} You have to set `--class-name` option', file=sys.stderr)
962
+ print(f"{e} You have to set `--class-name` option", file=sys.stderr) # noqa: T201
963
+ if temp_context is not None:
964
+ temp_context.cleanup()
521
965
  return Exit.ERROR
522
966
  except Error as e:
523
- print(str(e), file=sys.stderr)
967
+ print(str(e), file=sys.stderr) # noqa: T201
968
+ if temp_context is not None:
969
+ temp_context.cleanup()
524
970
  return Exit.ERROR
525
- except Exception:
526
- import traceback
971
+ except Exception: # noqa: BLE001
972
+ import traceback # noqa: PLC0415
527
973
 
528
- print(traceback.format_exc(), file=sys.stderr)
974
+ print(traceback.format_exc(), file=sys.stderr) # noqa: T201
975
+ if temp_context is not None:
976
+ temp_context.cleanup()
529
977
  return Exit.ERROR
530
978
 
979
+ if config.check and config.output is not None and generate_output is not None:
980
+ has_differences = False
981
+
982
+ if is_directory_output:
983
+ diffs, missing_files, extra_files = _compare_directories(generate_output, config.output, config.encoding)
984
+ if diffs:
985
+ print("".join(diffs), end="") # noqa: T201
986
+ has_differences = True
987
+ for missing in missing_files:
988
+ print(f"MISSING: {missing} (should be generated)") # noqa: T201
989
+ has_differences = True
990
+ for extra in extra_files:
991
+ print(f"EXTRA: {extra} (no longer generated)") # noqa: T201
992
+ has_differences = True
993
+ else:
994
+ diff_found, diff_lines = _compare_single_file(generate_output, config.output, config.encoding)
995
+ if diff_found:
996
+ print("".join(diff_lines), end="") # noqa: T201
997
+ has_differences = True
998
+
999
+ if temp_context is not None: # pragma: no branch
1000
+ temp_context.cleanup()
1001
+
1002
+ return Exit.DIFF if has_differences else Exit.OK
1003
+
1004
+ if config.watch:
1005
+ try:
1006
+ from datamodel_code_generator.watch import watch_and_regenerate # noqa: PLC0415
1007
+
1008
+ return watch_and_regenerate(config, extra_template_data, aliases, custom_formatters_kwargs)
1009
+ except Exception as e: # noqa: BLE001
1010
+ print(str(e), file=sys.stderr) # noqa: T201
1011
+ return Exit.ERROR
1012
+
1013
+ return Exit.OK
1014
+
531
1015
 
532
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)