datamodel-code-generator 0.11.12__py3-none-any.whl → 0.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. datamodel_code_generator/__init__.py +654 -185
  2. datamodel_code_generator/__main__.py +872 -388
  3. datamodel_code_generator/arguments.py +798 -0
  4. datamodel_code_generator/cli_options.py +295 -0
  5. datamodel_code_generator/format.py +292 -54
  6. datamodel_code_generator/http.py +85 -10
  7. datamodel_code_generator/imports.py +152 -43
  8. datamodel_code_generator/model/__init__.py +138 -1
  9. datamodel_code_generator/model/base.py +531 -120
  10. datamodel_code_generator/model/dataclass.py +211 -0
  11. datamodel_code_generator/model/enum.py +133 -12
  12. datamodel_code_generator/model/imports.py +22 -0
  13. datamodel_code_generator/model/msgspec.py +462 -0
  14. datamodel_code_generator/model/pydantic/__init__.py +30 -25
  15. datamodel_code_generator/model/pydantic/base_model.py +304 -100
  16. datamodel_code_generator/model/pydantic/custom_root_type.py +11 -2
  17. datamodel_code_generator/model/pydantic/dataclass.py +15 -4
  18. datamodel_code_generator/model/pydantic/imports.py +40 -27
  19. datamodel_code_generator/model/pydantic/types.py +188 -96
  20. datamodel_code_generator/model/pydantic_v2/__init__.py +51 -0
  21. datamodel_code_generator/model/pydantic_v2/base_model.py +268 -0
  22. datamodel_code_generator/model/pydantic_v2/imports.py +15 -0
  23. datamodel_code_generator/model/pydantic_v2/root_model.py +35 -0
  24. datamodel_code_generator/model/pydantic_v2/types.py +143 -0
  25. datamodel_code_generator/model/scalar.py +124 -0
  26. datamodel_code_generator/model/template/Enum.jinja2 +15 -2
  27. datamodel_code_generator/model/template/ScalarTypeAliasAnnotation.jinja2 +6 -0
  28. datamodel_code_generator/model/template/ScalarTypeAliasType.jinja2 +6 -0
  29. datamodel_code_generator/model/template/ScalarTypeStatement.jinja2 +6 -0
  30. datamodel_code_generator/model/template/TypeAliasAnnotation.jinja2 +20 -0
  31. datamodel_code_generator/model/template/TypeAliasType.jinja2 +20 -0
  32. datamodel_code_generator/model/template/TypeStatement.jinja2 +20 -0
  33. datamodel_code_generator/model/template/TypedDict.jinja2 +5 -0
  34. datamodel_code_generator/model/template/TypedDictClass.jinja2 +25 -0
  35. datamodel_code_generator/model/template/TypedDictFunction.jinja2 +24 -0
  36. datamodel_code_generator/model/template/UnionTypeAliasAnnotation.jinja2 +10 -0
  37. datamodel_code_generator/model/template/UnionTypeAliasType.jinja2 +10 -0
  38. datamodel_code_generator/model/template/UnionTypeStatement.jinja2 +10 -0
  39. datamodel_code_generator/model/template/dataclass.jinja2 +50 -0
  40. datamodel_code_generator/model/template/msgspec.jinja2 +55 -0
  41. datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 +17 -4
  42. datamodel_code_generator/model/template/pydantic/BaseModel_root.jinja2 +12 -4
  43. datamodel_code_generator/model/template/pydantic/Config.jinja2 +1 -1
  44. datamodel_code_generator/model/template/pydantic/dataclass.jinja2 +15 -2
  45. datamodel_code_generator/model/template/pydantic_v2/BaseModel.jinja2 +57 -0
  46. datamodel_code_generator/model/template/pydantic_v2/ConfigDict.jinja2 +5 -0
  47. datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2 +48 -0
  48. datamodel_code_generator/model/type_alias.py +70 -0
  49. datamodel_code_generator/model/typed_dict.py +161 -0
  50. datamodel_code_generator/model/types.py +106 -0
  51. datamodel_code_generator/model/union.py +105 -0
  52. datamodel_code_generator/parser/__init__.py +30 -12
  53. datamodel_code_generator/parser/_graph.py +67 -0
  54. datamodel_code_generator/parser/_scc.py +171 -0
  55. datamodel_code_generator/parser/base.py +2426 -380
  56. datamodel_code_generator/parser/graphql.py +652 -0
  57. datamodel_code_generator/parser/jsonschema.py +2518 -647
  58. datamodel_code_generator/parser/openapi.py +631 -222
  59. datamodel_code_generator/py.typed +0 -0
  60. datamodel_code_generator/pydantic_patch.py +28 -0
  61. datamodel_code_generator/reference.py +672 -290
  62. datamodel_code_generator/types.py +521 -145
  63. datamodel_code_generator/util.py +155 -0
  64. datamodel_code_generator/watch.py +65 -0
  65. datamodel_code_generator-0.45.0.dist-info/METADATA +301 -0
  66. datamodel_code_generator-0.45.0.dist-info/RECORD +69 -0
  67. {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info}/WHEEL +1 -1
  68. datamodel_code_generator-0.45.0.dist-info/entry_points.txt +2 -0
  69. datamodel_code_generator/version.py +0 -1
  70. datamodel_code_generator-0.11.12.dist-info/METADATA +0 -440
  71. datamodel_code_generator-0.11.12.dist-info/RECORD +0 -31
  72. datamodel_code_generator-0.11.12.dist-info/entry_points.txt +0 -3
  73. {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,60 +1,289 @@
+ """Abstract base parser and utilities for schema parsing.
+
+ Provides the Parser abstract base class that defines the parsing algorithm,
+ along with helper functions for model sorting, import resolution, and
+ code generation.
+ """
+
+ from __future__ import annotations
+
+ import operator
+ import os.path
  import re
+ import sys
  from abc import ABC, abstractmethod
- from collections import OrderedDict, defaultdict
+ from collections import Counter, OrderedDict, defaultdict
+ from collections.abc import Hashable, Sequence
  from itertools import groupby
  from pathlib import Path
- from typing import (
-     Any,
-     Callable,
-     DefaultDict,
-     Dict,
-     Iterable,
-     Iterator,
-     List,
-     Mapping,
-     Optional,
-     Sequence,
-     Set,
-     Tuple,
-     Type,
-     Union,
- )
+ from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, TypeVar, cast, runtime_checkable
  from urllib.parse import ParseResult
+ from warnings import warn

  from pydantic import BaseModel
-
- from datamodel_code_generator import Protocol, runtime_checkable
- from datamodel_code_generator.format import CodeFormatter, PythonVersion
- from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports
+ from typing_extensions import TypeAlias
+
+ from datamodel_code_generator import (
+     DEFAULT_SHARED_MODULE_NAME,
+     AllExportsCollisionStrategy,
+     AllExportsScope,
+     AllOfMergeMode,
+     Error,
+     ModuleSplitMode,
+     ReadOnlyWriteOnlyModelType,
+     ReuseScope,
+ )
+ from datamodel_code_generator.format import (
+     DEFAULT_FORMATTERS,
+     CodeFormatter,
+     DatetimeClassType,
+     Formatter,
+     PythonVersion,
+     PythonVersionMin,
+ )
+ from datamodel_code_generator.imports import (
+     IMPORT_ANNOTATIONS,
+     IMPORT_LITERAL,
+     IMPORT_OPTIONAL,
+     IMPORT_UNION,
+     Import,
+     Imports,
+ )
+ from datamodel_code_generator.model import dataclass as dataclass_model
+ from datamodel_code_generator.model import msgspec as msgspec_model
  from datamodel_code_generator.model import pydantic as pydantic_model
+ from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2
  from datamodel_code_generator.model.base import (
      ALL_MODEL,
+     UNDEFINED,
      BaseClassDataType,
+     ConstraintsBase,
      DataModel,
      DataModelFieldBase,
+     WrappedDefault,
  )
- from datamodel_code_generator.model.enum import Enum
+ from datamodel_code_generator.model.enum import Enum, Member
+ from datamodel_code_generator.model.type_alias import TypeAliasBase, TypeStatement
  from datamodel_code_generator.parser import DefaultPutDict, LiteralType
- from datamodel_code_generator.reference import ModelResolver, Reference
+ from datamodel_code_generator.parser._graph import stable_toposort
+ from datamodel_code_generator.parser._scc import find_circular_sccs, strongly_connected_components
+ from datamodel_code_generator.reference import ModelResolver, ModelType, Reference
  from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
+ from datamodel_code_generator.util import camel_to_snake

- escape_characters = str.maketrans(
-     {
-         "\\": r"\\",
-         "'": r"\'",
-         '\b': r'\b',
-         '\f': r'\f',
-         '\n': r'\n',
-         '\r': r'\r',
-         '\t': r'\t',
-     }
- )
+ if TYPE_CHECKING:
+     from collections.abc import Iterable, Iterator, Mapping, Sequence

+     from datamodel_code_generator import DataclassArguments

- def to_hashable(item: Any) -> Any:
-     if isinstance(item, list):
-         return tuple(to_hashable(i) for i in item)
-     elif isinstance(item, dict):
+
+ @runtime_checkable
+ class HashableComparable(Hashable, Protocol):
+     """Protocol for types that are both hashable and support comparison."""
+
+     def __lt__(self, value: Any, /) -> bool: ...  # noqa: D105
+     def __le__(self, value: Any, /) -> bool: ...  # noqa: D105
+     def __gt__(self, value: Any, /) -> bool: ...  # noqa: D105
+     def __ge__(self, value: Any, /) -> bool: ...  # noqa: D105
+
+
+ ModelName: TypeAlias = str
+ ModelNames: TypeAlias = set[ModelName]
+ ModelDeps: TypeAlias = dict[ModelName, set[ModelName]]
+ OrderIndex: TypeAlias = dict[ModelName, int]
+
+ ComponentId: TypeAlias = int
+ Components: TypeAlias = list[list[ModelName]]
+ ComponentOf: TypeAlias = dict[ModelName, ComponentId]
+ ComponentEdges: TypeAlias = dict[ComponentId, set[ComponentId]]
+
+ ClassNode: TypeAlias = tuple[ModelName, ...]
+ ClassGraph: TypeAlias = dict[ClassNode, set[ClassNode]]
+
+
+ class _KeepModelOrderDeps(NamedTuple):
+     strong: ModelDeps
+     all: ModelDeps
+
+
+ class _KeepModelOrderComponents(NamedTuple):
+     components: Components
+     comp_of: ComponentOf
+
+
+ def _collect_keep_model_order_deps(
+     model: DataModel,
+     *,
+     model_names: ModelNames,
+     imported: ModelNames,
+     use_deferred_annotations: bool,
+ ) -> tuple[set[ModelName], set[ModelName]]:
+     """Collect (strong_deps, all_deps) used by keep_model_order sorting.
+
+     - strong_deps: base class references (within-module, non-imported)
+     - all_deps: base class refs + (optionally) field refs (within-module, non-imported)
+     """
+     class_name = model.class_name
+     base_class_refs = {b.reference.short_name for b in model.base_classes if b.reference}
+     field_refs = {t.reference.short_name for f in model.fields for t in f.data_type.all_data_types if t.reference}
+
+     if use_deferred_annotations and not isinstance(model, (TypeAliasBase, pydantic_model_v2.RootModel)):
+         field_refs = set()
+
+     strong = {r for r in base_class_refs if r in model_names and r not in imported and r != class_name}
+     deps = {r for r in (base_class_refs | field_refs) if r in model_names and r not in imported and r != class_name}
+     return strong, deps
+
+
+ def _build_keep_model_order_dependency_maps(
+     models: list[DataModel],
+     *,
+     model_names: ModelNames,
+     imported: ModelNames,
+     use_deferred_annotations: bool,
+ ) -> _KeepModelOrderDeps:
+     strong_deps: ModelDeps = {}
+     all_deps: ModelDeps = {}
+     for model in models:
+         strong, deps = _collect_keep_model_order_deps(
+             model,
+             model_names=model_names,
+             imported=imported,
+             use_deferred_annotations=use_deferred_annotations,
+         )
+         strong_deps[model.class_name] = strong
+         all_deps[model.class_name] = deps
+     return _KeepModelOrderDeps(strong=strong_deps, all=all_deps)
+
+
+ def _build_keep_model_order_components(
+     all_deps: ModelDeps,
+     order_index: OrderIndex,
+ ) -> _KeepModelOrderComponents:
+     graph: ClassGraph = {(name,): {(dep,) for dep in deps} for name, deps in all_deps.items()}
+     sccs = strongly_connected_components(graph)
+     components: Components = [sorted((node[0] for node in scc), key=order_index.__getitem__) for scc in sccs]
+     components.sort(key=lambda members: min(order_index[n] for n in members))
+     comp_of: ComponentOf = {name: i for i, members in enumerate(components) for name in members}
+     return _KeepModelOrderComponents(components=components, comp_of=comp_of)
+
+
+ def _build_keep_model_order_component_edges(
+     all_deps: ModelDeps,
+     comp_of: ComponentOf,
+     num_components: int,
+ ) -> ComponentEdges:
+     comp_edges: ComponentEdges = {i: set() for i in range(num_components)}
+     for name, deps in all_deps.items():
+         name_comp = comp_of[name]
+         for dep in deps:
+             if (dep_comp := comp_of[dep]) != name_comp:
+                 comp_edges[dep_comp].add(name_comp)
+     return comp_edges
+
+
+ def _build_keep_model_order_component_order(
+     components: Components,
+     comp_edges: ComponentEdges,
+     order_index: OrderIndex,
+ ) -> list[ComponentId]:
+     comp_key = [min(order_index[n] for n in members) for members in components]
+     return stable_toposort(
+         list(range(len(components))),
+         comp_edges,
+         key=lambda component_id: comp_key[component_id],
+     )
+
+
+ def _build_keep_model_ordered_names(
+     ordered_comp_ids: list[ComponentId],
+     components: Components,
+     strong_deps: ModelDeps,
+     order_index: OrderIndex,
+ ) -> list[ModelName]:
+     ordered_names: list[ModelName] = []
+     for component_id in ordered_comp_ids:
+         members = components[component_id]
+         if len(members) > 1:
+             strong_edges: dict[ModelName, set[ModelName]] = {n: set() for n in members}
+             member_set = set(members)
+             for base in members:
+                 derived_members = {member for member in members if base in strong_deps.get(member, set()) & member_set}
+                 strong_edges[base].update(derived_members)
+             members = stable_toposort(members, strong_edges, key=order_index.__getitem__)
+         ordered_names.extend(members)
+     return ordered_names
+
+
+ def _reorder_models_keep_model_order(
+     models: list[DataModel],
+     imports: Imports,
+     *,
+     use_deferred_annotations: bool,
+ ) -> None:
+     """Reorder models deterministically based on their dependencies.
+
+     Starts from class_name order and only moves models when required to satisfy dependencies.
+     Cycles are kept as SCC groups; within each SCC, base-class dependencies are prioritized.
+     """
+     models.sort(key=lambda x: x.class_name)
+     imported: ModelNames = {i for v in imports.values() for i in v}
+     model_by_name = {m.class_name: m for m in models}
+     model_names: ModelNames = set(model_by_name)
+     order_index: OrderIndex = {m.class_name: i for i, m in enumerate(models)}
+
+     deps = _build_keep_model_order_dependency_maps(
+         models,
+         model_names=model_names,
+         imported=imported,
+         use_deferred_annotations=use_deferred_annotations,
+     )
+     comps = _build_keep_model_order_components(deps.all, order_index)
+     comp_edges = _build_keep_model_order_component_edges(deps.all, comps.comp_of, len(comps.components))
+     ordered_comp_ids = _build_keep_model_order_component_order(comps.components, comp_edges, order_index)
+     ordered_names = _build_keep_model_ordered_names(ordered_comp_ids, comps.components, deps.strong, order_index)
+     models[:] = [model_by_name[name] for name in ordered_names]
+
+
+ SPECIAL_PATH_FORMAT: str = "#-datamodel-code-generator-#-{}-#-special-#"
+
+
+ def get_special_path(keyword: str, path: list[str]) -> list[str]:
+     """Create a special path marker for internal reference tracking."""
+     return [*path, SPECIAL_PATH_FORMAT.format(keyword)]
+
+
+ escape_characters = str.maketrans({
+     "\u0000": r"\x00",  # Null byte
+     "\\": r"\\",
+     "'": r"\'",
+     "\b": r"\b",
+     "\f": r"\f",
+     "\n": r"\n",
+     "\r": r"\r",
+     "\t": r"\t",
+ })
+
+
+ def to_hashable(item: Any) -> HashableComparable:  # noqa: PLR0911
+     """Convert an item to a hashable and comparable representation.
+
+     Returns a value that is both hashable and supports comparison operators.
+     Used for caching and deduplication of models.
+     """
+     if isinstance(
+         item,
+         (
+             list,
+             tuple,
+         ),
+     ):
+         try:
+             return tuple(sorted((to_hashable(i) for i in item), key=lambda v: (str(type(v)), v)))
+         except TypeError:
+             # Fallback when mixed, non-comparable types are present; preserve original order
+             return tuple(to_hashable(i) for i in item)
+     if isinstance(item, dict):
          return tuple(
              sorted(
                  (
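The rewritten escape table above now also escapes null bytes. A minimal sketch of its behavior (the assert is ours for illustration, not part of the package):

    escape_characters = str.maketrans({
        "\u0000": r"\x00",  # null byte
        "\\": r"\\",
        "'": r"\'",
        "\n": r"\n",
    })
    # a real newline and a quote both become their escaped spellings
    assert "it's\n".translate(escape_characters) == r"it\'s\n"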
@@ -64,53 +293,88 @@ def to_hashable(item: Any) -> Any:
                  for k, v in item.items()
              )
          )
-     elif isinstance(item, set):  # pragma: no cover
-         return frozenset(to_hashable(i) for i in item)
-     elif isinstance(item, BaseModel):
+     if isinstance(item, set):  # pragma: no cover
+         return frozenset(to_hashable(i) for i in item)  # type: ignore[return-value]
+     if isinstance(item, BaseModel):
          return to_hashable(item.dict())
-     return item
+     if item is None:
+         return ""
+     return item  # type: ignore[return-value]
+
+
+ def dump_templates(templates: list[DataModel]) -> str:
+     """Join model templates into a single code string."""
+     return "\n\n\n".join(str(m) for m in templates)
+
+
+ def iter_models_field_data_types(
+     models: Iterable[DataModel],
+ ) -> Iterator[tuple[DataModel, DataModelFieldBase, DataType]]:
+     """Yield (model, field, data_type) for all models, fields, and nested data types."""
+     for model in models:
+         for field in model.fields:
+             for data_type in field.data_type.all_data_types:
+                 yield model, field, data_type


- def dump_templates(templates: List[DataModel]) -> str:
-     return '\n\n\n'.join(str(m) for m in templates)
+ ReferenceMapSet = dict[str, set[str]]
+ SortedDataModels = dict[str, DataModel]

+ MAX_RECURSION_COUNT: int = sys.getrecursionlimit()

- ReferenceMapSet = Dict[str, Set[str]]
- SortedDataModels = Dict[str, DataModel]

- MAX_RECURSION_COUNT: int = 100
+ def add_model_path_to_list(
+     paths: list[str] | None,
+     model: DataModel,
+     /,
+ ) -> list[str]:
+     """
+     Auxiliary method which adds model path to list, provided the following hold.

+     - model is not a type alias
+     - path is not already in the list.

- def sort_data_models(
-     unsorted_data_models: List[DataModel],
-     sorted_data_models: Optional[SortedDataModels] = None,
-     require_update_action_models: Optional[List[str]] = None,
+     """
+     if paths is None:
+         paths = []
+     if model.is_alias:
+         return paths
+     if (path := model.path) in paths:
+         return paths
+     paths.append(path)
+     return paths
+
+
+ def sort_data_models(  # noqa: PLR0912, PLR0915
+     unsorted_data_models: list[DataModel],
+     sorted_data_models: SortedDataModels | None = None,
+     require_update_action_models: list[str] | None = None,
      recursion_count: int = MAX_RECURSION_COUNT,
- ) -> Tuple[List[DataModel], SortedDataModels, List[str]]:
+ ) -> tuple[list[DataModel], SortedDataModels, list[str]]:
+     """Sort data models by dependency order for correct forward references."""
      if sorted_data_models is None:
          sorted_data_models = OrderedDict()
      if require_update_action_models is None:
          require_update_action_models = []
+     sorted_model_count: int = len(sorted_data_models)

-     unresolved_references: List[DataModel] = []
+     unresolved_references: list[DataModel] = []
      for model in unsorted_data_models:
          if not model.reference_classes:
              sorted_data_models[model.path] = model
-         elif (
-             model.path in model.reference_classes and len(model.reference_classes) == 1
-         ):  # only self-referencing
+         elif model.path in model.reference_classes and len(model.reference_classes) == 1:  # only self-referencing
              sorted_data_models[model.path] = model
-             require_update_action_models.append(model.path)
+             add_model_path_to_list(require_update_action_models, model)
          elif (
              not model.reference_classes - {model.path} - set(sorted_data_models)
          ):  # reference classes have been resolved
              sorted_data_models[model.path] = model
              if model.path in model.reference_classes:
-                 require_update_action_models.append(model.path)
+                 add_model_path_to_list(require_update_action_models, model)
          else:
              unresolved_references.append(model)
      if unresolved_references:
-         if recursion_count:
+         if sorted_model_count != len(sorted_data_models) and recursion_count:
              try:
                  return sort_data_models(
                      unresolved_references,
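For orientation, to_hashable() now sorts list and tuple contents by (type name, value), falls back to the original order for non-comparable mixes, and maps None to an empty string, so structurally equal schemas hash identically. A hedged sketch of the observable behavior (hypothetical inputs):

    assert to_hashable([3, 1, 2]) == (1, 2, 3)
    assert to_hashable((None,)) == ("",)
    # mixed types order by type name rather than raising TypeError
    assert to_hashable([1, "a"]) == (1, "a")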
@@ -118,37 +382,38 @@ def sort_data_models(
                      require_update_action_models,
                      recursion_count - 1,
                  )
-             except RecursionError:
+             except RecursionError:  # pragma: no cover
                  pass

          # sort on base_class dependency
          while True:
-             ordered_models: List[Tuple[int, DataModel]] = []
+             ordered_models: list[tuple[int, DataModel]] = []
              unresolved_reference_model_names = [m.path for m in unresolved_references]
              for model in unresolved_references:
-                 indexes = [
-                     unresolved_reference_model_names.index(b.reference.path)
-                     for b in model.base_classes
-                     if b.reference
-                     and b.reference.path in unresolved_reference_model_names
-                 ]
+                 if isinstance(model, pydantic_model_v2.RootModel):
+                     indexes = [
+                         unresolved_reference_model_names.index(ref_path)
+                         for f in model.fields
+                         for t in f.data_type.all_data_types
+                         if t.reference and (ref_path := t.reference.path) in unresolved_reference_model_names
+                     ]
+                 else:
+                     indexes = [
+                         unresolved_reference_model_names.index(b.reference.path)
+                         for b in model.base_classes
+                         if b.reference and b.reference.path in unresolved_reference_model_names
+                     ]
                  if indexes:
-                     ordered_models.append(
-                         (
-                             min(indexes),
-                             model,
-                         )
-                     )
+                     ordered_models.append((
+                         max(indexes),
+                         model,
+                     ))
                  else:
-                     ordered_models.append(
-                         (
-                             -1,
-                             model,
-                         )
-                     )
-             sorted_unresolved_models = [
-                 m[1] for m in sorted(ordered_models, key=lambda m: m[0])
-             ]
+                     ordered_models.append((
+                         -1,
+                         model,
+                     ))
+             sorted_unresolved_models = [m[1] for m in sorted(ordered_models, key=operator.itemgetter(0))]
              if sorted_unresolved_models == unresolved_references:
                  break
              unresolved_references = sorted_unresolved_models
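The reworked loop above now keys each unresolved model on the highest index among its in-flight dependencies (max rather than min) and repeats until a fixed point. A stripped-down sketch of one pass, with plain names standing in for models (hypothetical data):

    import operator

    unresolved = ["Dog", "Animal", "Cat"]
    deps = {"Dog": ["Animal"], "Cat": ["Animal"], "Animal": []}
    ordered = []
    for name in unresolved:
        indexes = [unresolved.index(d) for d in deps[name] if d in unresolved]
        ordered.append((max(indexes) if indexes else -1, name))
    # stable sort on the key only, as with key=operator.itemgetter(0)
    assert [n for _, n in sorted(ordered, key=operator.itemgetter(0))] == ["Animal", "Dog", "Cat"]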
@@ -156,33 +421,58 @@
          # circular reference
          unsorted_data_model_names = set(unresolved_reference_model_names)
          for model in unresolved_references:
-             unresolved_model = (
-                 model.reference_classes - {model.path} - set(sorted_data_models)
-             )
+             unresolved_model = model.reference_classes - {model.path} - set(sorted_data_models)
+             base_models = [getattr(s.reference, "path", None) for s in model.base_classes]
+             update_action_parent = set(require_update_action_models).intersection(base_models)
              if not unresolved_model:
                  sorted_data_models[model.path] = model
+                 if update_action_parent:
+                     add_model_path_to_list(require_update_action_models, model)
                  continue
              if not unresolved_model - unsorted_data_model_names:
                  sorted_data_models[model.path] = model
-                 require_update_action_models.append(model.path)
+                 add_model_path_to_list(require_update_action_models, model)
                  continue
              # unresolved
-             unresolved_classes = ', '.join(
-                 f"[class: {item.path} references: {item.reference_classes}]"
-                 for item in unresolved_references
+             unresolved_classes = ", ".join(
+                 f"[class: {item.path} references: {item.reference_classes}]" for item in unresolved_references
              )
-             raise Exception(f'A Parser can not resolve classes: {unresolved_classes}.')
+             msg = f"A Parser can not resolve classes: {unresolved_classes}."
+             raise Exception(msg)  # noqa: TRY002
      return unresolved_references, sorted_data_models, require_update_action_models


- def relative(current_module: str, reference: str) -> Tuple[str, str]:
-     """Find relative module path."""
-
-     current_module_path = current_module.split('.') if current_module else []
-     *reference_path, name = reference.split('.')
+ def relative(
+     current_module: str,
+     reference: str,
+     *,
+     reference_is_module: bool = False,
+     current_is_init: bool = False,
+ ) -> tuple[str, str]:
+     """Find relative module path.
+
+     Args:
+         current_module: Current module path (e.g., "foo.bar")
+         reference: Reference path (e.g., "foo.baz.ClassName" or "foo.baz" if reference_is_module)
+         reference_is_module: If True, treat reference as a module path (not module.class)
+         current_is_init: If True, treat current_module as a package __init__.py (adds depth)
+
+     Returns:
+         Tuple of (from_path, import_name) for constructing import statements
+     """
+     if current_is_init:
+         current_module_path = [*current_module.split("."), "__init__"] if current_module else ["__init__"]
+     else:
+         current_module_path = current_module.split(".") if current_module else []
+
+     if reference_is_module:
+         reference_path = reference.split(".") if reference else []
+         name = reference_path[-1] if reference_path else ""
+     else:
+         *reference_path, name = reference.split(".")

      if current_module_path == reference_path:
-         return '', ''
+         return "", ""

      i = 0
      for x, y in zip(current_module_path, reference_path):
@@ -190,50 +480,159 @@ def relative(current_module: str, reference: str) -> Tuple[str, str]:
              break
          i += 1

-     left = '.' * (len(current_module_path) - i)
-     right = '.'.join(reference_path[i:])
+     left = "." * (len(current_module_path) - i)
+     right = ".".join(reference_path[i:])

      if not left:
-         left = '.'
+         left = "."
      if not right:
          right = name
-     elif '.' in right:
-         extra, right = right.rsplit('.', 1)
+     elif "." in right:
+         extra, right = right.rsplit(".", 1)
          left += extra

      return left, right


+ def is_ancestor_package_reference(current_module: str, reference: str) -> bool:
+     """Check if reference is in an ancestor package (__init__.py).
+
+     When the reference's module path is an ancestor (prefix) of the current module,
+     the reference is in an ancestor package's __init__.py file.
+
+     Args:
+         current_module: The current module path (e.g., "v0.mammal.canine")
+         reference: The full reference path (e.g., "v0.Animal")
+
+     Returns:
+         True if the reference is in an ancestor package, False otherwise.
+
+     Examples:
+         - current="v0.animal", ref="v0.Animal" -> True (immediate parent)
+         - current="v0.mammal.canine", ref="v0.Animal" -> True (grandparent)
+         - current="v0.animal", ref="v0.animal.Dog" -> False (same or child)
+         - current="pets", ref="Animal" -> True (root package is immediate parent)
+     """
+     current_path = current_module.split(".") if current_module else []
+     *reference_path, _ = reference.split(".")
+
+     if not current_path:
+         return False
+
+     # Case 1: Direct parent package (includes root package when reference_path is empty)
+     # e.g., current="pets", ref="Animal" -> current_path[:-1]=[] == reference_path=[]
+     if current_path[:-1] == reference_path:
+         return True
+
+     # Case 2: Deeper ancestor package (reference_path must be non-empty proper prefix)
+     # e.g., current="v0.mammal.canine", ref="v0.Animal" -> ["v0"] is prefix of ["v0","mammal","canine"]
+     return (
+         len(reference_path) > 0
+         and len(reference_path) < len(current_path)
+         and current_path[: len(reference_path)] == reference_path
+     )
+
+
+ def exact_import(from_: str, import_: str, short_name: str) -> tuple[str, str]:
+     """Create exact import path to avoid relative import issues."""
+     if from_ == len(from_) * ".":
+         # Prevents "from . import foo" becoming "from ..foo import Foo"
+         # or "from .. import foo" becoming "from ...foo import Foo"
+         # when our imported module has the same parent
+         return f"{from_}{import_}", short_name
+     return f"{from_}.{import_}", short_name
+
+
+ def get_module_directory(module: tuple[str, ...]) -> tuple[str, ...]:
+     """Get the directory portion of a module tuple.
+
+     Note: Module tuples in module_models do NOT include .py extension.
+     The last element is either the module name (e.g., "issuing") or empty for root.
+
+     Examples:
+         ("pkg",) -> ("pkg",) - root module
+         ("pkg", "issuing") -> ("pkg",) - submodule
+         ("foo", "bar", "baz") -> ("foo", "bar") - deeply nested module
+     """
+     if not module:
+         return ()
+     if len(module) == 1:
+         return module
+     return module[:-1]
+
+
  @runtime_checkable
  class Child(Protocol):
+     """Protocol for objects with a parent reference."""
+
      @property
-     def parent(self) -> Optional[Any]:
+     def parent(self) -> Any | None:
+         """Get the parent object reference."""
          raise NotImplementedError


- def get_most_of_parent(value: Any) -> Optional[Any]:
-     if isinstance(value, Child):
-         return get_most_of_parent(value.parent)
+ T = TypeVar("T")
+
+
+ def get_most_of_parent(value: Any, type_: type[T] | None = None) -> T | None:
+     """Traverse parent chain to find the outermost matching parent."""
+     if isinstance(value, Child) and (type_ is None or not isinstance(value, type_)):
+         return get_most_of_parent(value.parent, type_)
      return value


  def title_to_class_name(title: str) -> str:
-     classname = re.sub('[^A-Za-z0-9]+', ' ', title)
-     classname = ''.join(x for x in classname.title() if not x.isspace())
-     return classname
+     """Convert a schema title to a valid Python class name."""
+     classname = re.sub(r"[^A-Za-z0-9]+", " ", title)
+     return "".join(x for x in classname.title() if not x.isspace())
+
+
+ def _find_base_classes(model: DataModel) -> list[DataModel]:
+     """Get direct base class DataModels."""
+     return [b.reference.source for b in model.base_classes if b.reference and isinstance(b.reference.source, DataModel)]
+
+
+ def _find_field(original_name: str, models: list[DataModel]) -> DataModelFieldBase | None:
+     """Find a field by original_name in the models and their base classes."""
+     for model in models:
+         for field in model.iter_all_fields():  # pragma: no cover
+             if field.original_name == original_name:
+                 return field
+     return None  # pragma: no cover
+
+
+ def _copy_data_types(data_types: list[DataType]) -> list[DataType]:
+     """Deep copy a list of DataType objects, preserving references."""
+     copied_data_types: list[DataType] = []
+     for data_type_ in data_types:
+         if data_type_.reference:
+             copied_data_types.append(data_type_.__class__(reference=data_type_.reference))
+         elif data_type_.data_types:  # pragma: no cover
+             copied_data_type = data_type_.copy()
+             copied_data_type.data_types = _copy_data_types(data_type_.data_types)
+             copied_data_types.append(copied_data_type)
+         else:
+             copied_data_types.append(data_type_.copy())
+     return copied_data_types


  class Result(BaseModel):
+     """Generated code result with optional source file reference."""
+
      body: str
-     source: Optional[Path]
+     future_imports: str = ""
+     source: Optional[Path] = None  # noqa: UP045


  class Source(BaseModel):
+     """Schema source file with path and content."""
+
      path: Path
      text: str

      @classmethod
-     def from_path(cls, path: Path, base_path: Path, encoding: str) -> 'Source':
+     def from_path(cls, path: Path, base_path: Path, encoding: str) -> Source:
+         """Create a Source from a file path relative to base_path."""
          return cls(
              path=path.relative_to(base_path),
              text=path.read_text(encoding=encoding),
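Taken together, relative(), is_ancestor_package_reference(), and exact_import() turn absolute reference paths into import statements. A worked sketch against the signatures above (module names are hypothetical):

    # sibling module: "from . import baz", class reached as baz.ClassName
    assert relative("foo.bar", "foo.baz.ClassName") == (".", "baz")

    # ancestor package: "from .. import Animal"
    assert relative("v0.mammal.canine", "v0.Animal") == ("..", "Animal")
    assert is_ancestor_package_reference("v0.mammal.canine", "v0.Animal") is True

    # exact imports rewrite "from . import baz" into "from .baz import ClassName"
    assert exact_import(".", "baz", "ClassName") == (".baz", "ClassName")

    assert title_to_class_name("user profile v2") == "UserProfileV2"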
@@ -241,144 +640,295 @@ class Source(BaseModel):


  class Parser(ABC):
-     def __init__(
+     """Abstract base class for schema parsers.
+
+     Provides the parsing algorithm and code generation. Subclasses implement
+     parse_raw() to handle specific schema formats.
+     """
+
+     def __init__(  # noqa: PLR0913, PLR0915
          self,
-         source: Union[str, Path, List[Path], ParseResult],
+         source: str | Path | list[Path] | ParseResult,
          *,
-         data_model_type: Type[DataModel] = pydantic_model.BaseModel,
-         data_model_root_type: Type[DataModel] = pydantic_model.CustomRootType,
-         data_type_manager_type: Type[DataTypeManager] = pydantic_model.DataTypeManager,
-         data_model_field_type: Type[DataModelFieldBase] = pydantic_model.DataModelField,
-         base_class: Optional[str] = None,
-         custom_template_dir: Optional[Path] = None,
-         extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
-         target_python_version: PythonVersion = PythonVersion.PY_37,
-         dump_resolve_reference_action: Optional[Callable[[Iterable[str]], str]] = None,
+         data_model_type: type[DataModel] = pydantic_model.BaseModel,
+         data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
+         data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
+         data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
+         base_class: str | None = None,
+         additional_imports: list[str] | None = None,
+         custom_template_dir: Path | None = None,
+         extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
+         target_python_version: PythonVersion = PythonVersionMin,
+         dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
          validation: bool = False,
          field_constraints: bool = False,
          snake_case_field: bool = False,
          strip_default_none: bool = False,
-         aliases: Optional[Mapping[str, str]] = None,
+         aliases: Mapping[str, str] | None = None,
          allow_population_by_field_name: bool = False,
          apply_default_values_for_required_fields: bool = False,
+         allow_extra_fields: bool = False,
+         extra_fields: str | None = None,
          force_optional_for_required_fields: bool = False,
-         class_name: Optional[str] = None,
+         class_name: str | None = None,
          use_standard_collections: bool = False,
-         base_path: Optional[Path] = None,
+         base_path: Path | None = None,
          use_schema_description: bool = False,
+         use_field_description: bool = False,
+         use_attribute_docstrings: bool = False,
+         use_inline_field_description: bool = False,
+         use_default_kwarg: bool = False,
          reuse_model: bool = False,
-         encoding: str = 'utf-8',
-         enum_field_as_literal: Optional[LiteralType] = None,
+         reuse_scope: ReuseScope | None = None,
+         shared_module_name: str = DEFAULT_SHARED_MODULE_NAME,
+         encoding: str = "utf-8",
+         enum_field_as_literal: LiteralType | None = None,
          set_default_enum_member: bool = False,
+         use_subclass_enum: bool = False,
+         use_specialized_enum: bool = True,
          strict_nullable: bool = False,
          use_generic_container_types: bool = False,
          enable_faux_immutability: bool = False,
-         remote_text_cache: Optional[DefaultPutDict[str, str]] = None,
+         remote_text_cache: DefaultPutDict[str, str] | None = None,
          disable_appending_item_suffix: bool = False,
-         strict_types: Optional[Sequence[StrictTypes]] = None,
-         empty_enum_field_name: Optional[str] = None,
-         custom_class_name_generator: Optional[
-             Callable[[str], str]
-         ] = title_to_class_name,
-         field_extra_keys: Optional[Set[str]] = None,
+         strict_types: Sequence[StrictTypes] | None = None,
+         empty_enum_field_name: str | None = None,
+         custom_class_name_generator: Callable[[str], str] | None = title_to_class_name,
+         field_extra_keys: set[str] | None = None,
          field_include_all_keys: bool = False,
-         wrap_string_literal: Optional[bool] = None,
+         field_extra_keys_without_x_prefix: set[str] | None = None,
+         wrap_string_literal: bool | None = None,
          use_title_as_name: bool = False,
-         http_headers: Optional[Sequence[Tuple[str, str]]] = None,
+         use_operation_id_as_name: bool = False,
+         use_unique_items_as_set: bool = False,
+         allof_merge_mode: AllOfMergeMode = AllOfMergeMode.Constraints,
+         http_headers: Sequence[tuple[str, str]] | None = None,
+         http_ignore_tls: bool = False,
          use_annotated: bool = False,
-     ):
+         use_serialize_as_any: bool = False,
+         use_non_positive_negative_number_constrained_types: bool = False,
+         use_decimal_for_multiple_of: bool = False,
+         original_field_name_delimiter: str | None = None,
+         use_double_quotes: bool = False,
+         use_union_operator: bool = False,
+         allow_responses_without_content: bool = False,
+         collapse_root_models: bool = False,
+         skip_root_model: bool = False,
+         use_type_alias: bool = False,
+         special_field_name_prefix: str | None = None,
+         remove_special_field_name_prefix: bool = False,
+         capitalise_enum_members: bool = False,
+         keep_model_order: bool = False,
+         use_one_literal_as_default: bool = False,
+         use_enum_values_in_discriminator: bool = False,
+         known_third_party: list[str] | None = None,
+         custom_formatters: list[str] | None = None,
+         custom_formatters_kwargs: dict[str, Any] | None = None,
+         use_pendulum: bool = False,
+         http_query_parameters: Sequence[tuple[str, str]] | None = None,
+         treat_dot_as_module: bool = False,
+         use_exact_imports: bool = False,
+         default_field_extras: dict[str, Any] | None = None,
+         target_datetime_class: DatetimeClassType | None = None,
+         keyword_only: bool = False,
+         frozen_dataclasses: bool = False,
+         no_alias: bool = False,
+         use_frozen_field: bool = False,
+         formatters: list[Formatter] = DEFAULT_FORMATTERS,
+         parent_scoped_naming: bool = False,
+         dataclass_arguments: DataclassArguments | None = None,
+         type_mappings: list[str] | None = None,
+         read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = None,
+     ) -> None:
+         """Initialize the Parser with configuration options."""
+         self.keyword_only = keyword_only
+         self.frozen_dataclasses = frozen_dataclasses
          self.data_type_manager: DataTypeManager = data_type_manager_type(
-             target_python_version,
-             use_standard_collections,
-             use_generic_container_types,
-             strict_types,
+             python_version=target_python_version,
+             use_standard_collections=use_standard_collections,
+             use_generic_container_types=use_generic_container_types,
+             use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
+             use_decimal_for_multiple_of=use_decimal_for_multiple_of,
+             strict_types=strict_types,
+             use_union_operator=use_union_operator,
+             use_pendulum=use_pendulum,
+             target_datetime_class=target_datetime_class,
+             treat_dot_as_module=treat_dot_as_module,
+             use_serialize_as_any=use_serialize_as_any,
          )
-         self.data_model_type: Type[DataModel] = data_model_type
-         self.data_model_root_type: Type[DataModel] = data_model_root_type
-         self.data_model_field_type: Type[DataModelFieldBase] = data_model_field_type
-         self.imports: Imports = Imports()
-         self.base_class: Optional[str] = base_class
+         self.data_model_type: type[DataModel] = data_model_type
+         self.data_model_root_type: type[DataModel] = data_model_root_type
+         self.data_model_field_type: type[DataModelFieldBase] = data_model_field_type
+
+         self.imports: Imports = Imports(use_exact_imports)
+         self.use_exact_imports: bool = use_exact_imports
+         self._append_additional_imports(additional_imports=additional_imports)
+
+         self.base_class: str | None = base_class
          self.target_python_version: PythonVersion = target_python_version
-         self.results: List[DataModel] = []
-         self.dump_resolve_reference_action: Optional[
-             Callable[[Iterable[str]], str]
-         ] = dump_resolve_reference_action
+         self.results: list[DataModel] = []
+         self.dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = dump_resolve_reference_action
          self.validation: bool = validation
          self.field_constraints: bool = field_constraints
          self.snake_case_field: bool = snake_case_field
         self.strip_default_none: bool = strip_default_none
-         self.apply_default_values_for_required_fields: bool = (
-             apply_default_values_for_required_fields
-         )
-         self.force_optional_for_required_fields: bool = (
-             force_optional_for_required_fields
-         )
+         self.apply_default_values_for_required_fields: bool = apply_default_values_for_required_fields
+         self.force_optional_for_required_fields: bool = force_optional_for_required_fields
          self.use_schema_description: bool = use_schema_description
+         self.use_field_description: bool = use_field_description
+         self.use_inline_field_description: bool = use_inline_field_description
+         self.use_default_kwarg: bool = use_default_kwarg
          self.reuse_model: bool = reuse_model
+         self.reuse_scope: ReuseScope | None = reuse_scope
+         self.shared_module_name: str = shared_module_name
          self.encoding: str = encoding
-         self.enum_field_as_literal: Optional[LiteralType] = enum_field_as_literal
+         self.enum_field_as_literal: LiteralType | None = enum_field_as_literal
          self.set_default_enum_member: bool = set_default_enum_member
+         self.use_subclass_enum: bool = use_subclass_enum
+         self.use_specialized_enum: bool = use_specialized_enum
          self.strict_nullable: bool = strict_nullable
          self.use_generic_container_types: bool = use_generic_container_types
+         self.use_union_operator: bool = use_union_operator
          self.enable_faux_immutability: bool = enable_faux_immutability
-         self.custom_class_name_generator: Optional[
-             Callable[[str], str]
-         ] = custom_class_name_generator
-         self.field_extra_keys: Set[str] = field_extra_keys or set()
+         self.custom_class_name_generator: Callable[[str], str] | None = custom_class_name_generator
+         self.field_extra_keys: set[str] = field_extra_keys or set()
+         self.field_extra_keys_without_x_prefix: set[str] = field_extra_keys_without_x_prefix or set()
          self.field_include_all_keys: bool = field_include_all_keys

-         self.remote_text_cache: DefaultPutDict[str, str] = (
-             remote_text_cache or DefaultPutDict()
-         )
-         self.current_source_path: Optional[Path] = None
+         self.remote_text_cache: DefaultPutDict[str, str] = remote_text_cache or DefaultPutDict()
+         self.current_source_path: Path | None = None
          self.use_title_as_name: bool = use_title_as_name
+         self.use_operation_id_as_name: bool = use_operation_id_as_name
+         self.use_unique_items_as_set: bool = use_unique_items_as_set
+         self.allof_merge_mode: AllOfMergeMode = allof_merge_mode
+         self.dataclass_arguments = dataclass_arguments

          if base_path:
              self.base_path = base_path
          elif isinstance(source, Path):
-             self.base_path = (
-                 source.absolute() if source.is_dir() else source.absolute().parent
-             )
+             self.base_path = source.absolute() if source.is_dir() else source.absolute().parent
          else:
              self.base_path = Path.cwd()

-         self.source: Union[str, Path, List[Path], ParseResult] = source
+         self.source: str | Path | list[Path] | ParseResult = source
          self.custom_template_dir = custom_template_dir
-         self.extra_template_data: DefaultDict[
-             str, Any
-         ] = extra_template_data or defaultdict(dict)
+         self.extra_template_data: defaultdict[str, Any] = extra_template_data or defaultdict(dict)

          if allow_population_by_field_name:
-             self.extra_template_data[ALL_MODEL]['allow_population_by_field_name'] = True
+             self.extra_template_data[ALL_MODEL]["allow_population_by_field_name"] = True
+
+         if allow_extra_fields:
+             self.extra_template_data[ALL_MODEL]["allow_extra_fields"] = True
+
+         if extra_fields:
+             self.extra_template_data[ALL_MODEL]["extra_fields"] = extra_fields

          if enable_faux_immutability:
-             self.extra_template_data[ALL_MODEL]['allow_mutation'] = False
+             self.extra_template_data[ALL_MODEL]["allow_mutation"] = False
+
+         if use_attribute_docstrings:
+             self.extra_template_data[ALL_MODEL]["use_attribute_docstrings"] = True

          self.model_resolver = ModelResolver(
              base_url=source.geturl() if isinstance(source, ParseResult) else None,
-             singular_name_suffix='' if disable_appending_item_suffix else None,
+             singular_name_suffix="" if disable_appending_item_suffix else None,
              aliases=aliases,
              empty_field_name=empty_enum_field_name,
              snake_case_field=snake_case_field,
              custom_class_name_generator=custom_class_name_generator,
              base_path=self.base_path,
+             original_field_name_delimiter=original_field_name_delimiter,
+             special_field_name_prefix=special_field_name_prefix,
+             remove_special_field_name_prefix=remove_special_field_name_prefix,
+             capitalise_enum_members=capitalise_enum_members,
+             no_alias=no_alias,
+             parent_scoped_naming=parent_scoped_naming,
+             treat_dot_as_module=treat_dot_as_module,
          )
-         self.class_name: Optional[str] = class_name
-         self.wrap_string_literal: Optional[bool] = wrap_string_literal
-         self.http_headers: Optional[Sequence[Tuple[str, str]]] = http_headers
+         self.class_name: str | None = class_name
+         self.wrap_string_literal: bool | None = wrap_string_literal
+         self.http_headers: Sequence[tuple[str, str]] | None = http_headers
+         self.http_query_parameters: Sequence[tuple[str, str]] | None = http_query_parameters
+         self.http_ignore_tls: bool = http_ignore_tls
          self.use_annotated: bool = use_annotated
          if self.use_annotated and not self.field_constraints:  # pragma: no cover
-             raise Exception(
-                 '`use_annotated=True` has to be used with `field_constraints=True`'
-             )
+             msg = "`use_annotated=True` has to be used with `field_constraints=True`"
+             raise Exception(msg)  # noqa: TRY002
+         self.use_serialize_as_any: bool = use_serialize_as_any
+         self.use_non_positive_negative_number_constrained_types = use_non_positive_negative_number_constrained_types
+         self.use_double_quotes = use_double_quotes
+         self.allow_responses_without_content = allow_responses_without_content
+         self.collapse_root_models = collapse_root_models
+         self.skip_root_model = skip_root_model
+         self.use_type_alias = use_type_alias
+         self.capitalise_enum_members = capitalise_enum_members
+         self.keep_model_order = keep_model_order
+         self.use_one_literal_as_default = use_one_literal_as_default
+         self.use_enum_values_in_discriminator = use_enum_values_in_discriminator
+         self.known_third_party = known_third_party
+         self.custom_formatter = custom_formatters
+         self.custom_formatters_kwargs = custom_formatters_kwargs
+         self.treat_dot_as_module = treat_dot_as_module
+         self.default_field_extras: dict[str, Any] | None = default_field_extras
+         self.formatters: list[Formatter] = formatters
+         self.type_mappings: dict[tuple[str, str], str] = Parser._parse_type_mappings(type_mappings)
+         self.read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = read_only_write_only_model_type
+         self.use_frozen_field: bool = use_frozen_field
+
+     @property
+     def field_name_model_type(self) -> ModelType:
+         """Get the ModelType for field name validation based on data_model_type.
+
+         Returns ModelType.PYDANTIC for Pydantic models (which have reserved attributes
+         like 'schema', 'model_fields', etc.), and ModelType.CLASS for other model types
+         (TypedDict, dataclass, msgspec) which don't have such constraints.
+         """
+         if issubclass(
+             self.data_model_type,
+             (pydantic_model.BaseModel, pydantic_model_v2.BaseModel),
+         ):
+             return ModelType.PYDANTIC
+         return ModelType.CLASS
+
+     @staticmethod
+     def _parse_type_mappings(type_mappings: list[str] | None) -> dict[tuple[str, str], str]:
+         """Parse type mappings from CLI format to internal format.
+
+         Supports two formats:
+         - "type+format=target" (e.g., "string+binary=string")
+         - "format=target" (e.g., "binary=string", assumes type="string")
+
+         Returns a dict mapping (type, format) tuples to target type names.
+         """
+         if not type_mappings:
+             return {}
+
+         result: dict[tuple[str, str], str] = {}
+         for mapping in type_mappings:
+             if "=" not in mapping:
+                 msg = f"Invalid type mapping format: {mapping!r}. Expected 'type+format=target' or 'format=target'."
+                 raise ValueError(msg)
+
+             source, target = mapping.split("=", 1)
+             if "+" in source:
+                 type_, format_ = source.split("+", 1)
+             else:
+                 # Default to "string" type if only format is specified
+                 type_ = "string"
+                 format_ = source
+
+             result[type_, format_] = target
+
+         return result

      @property
      def iter_source(self) -> Iterator[Source]:
+         """Iterate over all source files to be parsed."""
          if isinstance(self.source, str):
              yield Source(path=Path(), text=self.source)
          elif isinstance(self.source, Path):  # pragma: no cover
              if self.source.is_dir():
-                 for path in self.source.rglob('*'):
+                 for path in sorted(self.source.rglob("*"), key=lambda p: p.name):
                      if path.is_file():
                          yield Source.from_path(path, self.base_path, self.encoding)
              else:
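The two mapping formats accepted by _parse_type_mappings() above, exercised (inputs are hypothetical; the staticmethod is callable on the abstract class):

    assert Parser._parse_type_mappings(None) == {}
    assert Parser._parse_type_mappings(["binary=string"]) == {("string", "binary"): "string"}
    assert Parser._parse_type_mappings(["string+binary=string"]) == {("string", "binary"): "string"}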
@@ -389,266 +939,1762 @@ class Parser(ABC):
          else:
              yield Source(
                  path=Path(self.source.path),
-                 text=self.remote_text_cache.get_or_put(
-                     self.source.geturl(), default_factory=self._get_text_from_url
-                 ),
+                 text=self.remote_text_cache.get_or_put(self.source.geturl(), default_factory=self._get_text_from_url),
              )

+     def _append_additional_imports(self, additional_imports: list[str] | None) -> None:
+         if additional_imports is None:
+             additional_imports = []
+
+         for additional_import_string in additional_imports:
+             if additional_import_string is None:  # pragma: no cover
+                 continue
+             new_import = Import.from_full_path(additional_import_string)
+             self.imports.append(new_import)
+
      def _get_text_from_url(self, url: str) -> str:
-         from datamodel_code_generator.http import get_body
+         from datamodel_code_generator.http import get_body  # noqa: PLC0415

          return self.remote_text_cache.get_or_put(
-             url, default_factory=lambda url_: get_body(url, self.http_headers)
+             url,
+             default_factory=lambda _url: get_body(
+                 url, self.http_headers, self.http_ignore_tls, self.http_query_parameters
+             ),
          )

      @classmethod
-     def get_url_path_parts(cls, url: ParseResult) -> List[str]:
+     def get_url_path_parts(cls, url: ParseResult) -> list[str]:
+         """Split URL into scheme/host and path components."""
          return [
-             f'{url.scheme}://{url.hostname}',
-             *url.path.split('/')[1:],
+             f"{url.scheme}://{url.hostname}",
+             *url.path.split("/")[1:],
          ]

      @property
-     def data_type(self) -> Type[DataType]:
+     def data_type(self) -> type[DataType]:
+         """Get the DataType class from the type manager."""
          return self.data_type_manager.data_type

      @abstractmethod
      def parse_raw(self) -> None:
+         """Parse the raw schema source. Must be implemented by subclasses."""
          raise NotImplementedError

-     def parse(
+     @classmethod
+     def _replace_model_in_list(
+         cls,
+         models: list[DataModel],
+         original: DataModel,
+         replacement: DataModel,
+     ) -> None:
+         """Replace model at its position in list."""
+         models.insert(models.index(original), replacement)
+         models.remove(original)
+
+     def __delete_duplicate_models(self, models: list[DataModel]) -> None:
+         model_class_names: dict[str, DataModel] = {}
+         model_to_duplicate_models: defaultdict[DataModel, list[DataModel]] = defaultdict(list)
+         for model in models.copy():
+             if isinstance(model, self.data_model_root_type):
+                 root_data_type = model.fields[0].data_type
+
+                 # backward compatible
+                 # Remove duplicated root model
+                 if (
+                     root_data_type.reference
+                     and not root_data_type.is_dict
+                     and not root_data_type.is_list
+                     and root_data_type.reference.source in models
+                     and root_data_type.reference.name
+                     == self.model_resolver.get_class_name(model.reference.original_name, unique=False).name
+                 ):
+                     model.reference.replace_children_references(root_data_type.reference)
+                     models.remove(model)
+                     for data_type in model.all_data_types:
+                         if data_type.reference:
+                             data_type.remove_reference()
+                     continue
+
+                 # Remove self from all DataModel children's base_classes
+                 for child in model.reference.iter_data_model_children():
+                     child.base_classes = [bc for bc in child.base_classes if bc.reference != model.reference]
+                     if not child.base_classes:  # pragma: no cover
+                         child.set_base_class()
+
+             class_name = model.duplicate_class_name or model.class_name
+             if class_name in model_class_names:
+                 original_model = model_class_names[class_name]
+                 if model.get_dedup_key(model.duplicate_class_name, use_default=False) == original_model.get_dedup_key(
+                     original_model.duplicate_class_name, use_default=False
+                 ):
+                     model_to_duplicate_models[original_model].append(model)
+                     continue
+             model_class_names[class_name] = model
+         for model, duplicate_models in model_to_duplicate_models.items():
+             for duplicate_model in duplicate_models:
+                 duplicate_model.reference.replace_children_references(model.reference)
+                 # Deduplicate base_classes in all DataModel children
+                 for child in duplicate_model.reference.iter_data_model_children():
+                     child.base_classes = list(
+                         {f"{c.module_name}.{c.type_hint}": c for c in child.base_classes}.values()
+                     )
+                 models.remove(duplicate_model)
+
+     @classmethod
+     def __replace_duplicate_name_in_module(cls, models: list[DataModel]) -> None:
+         scoped_model_resolver = ModelResolver(
+             exclude_names={i.alias or i.import_ for m in models for i in m.imports},
+             duplicate_name_suffix="Model",
+         )
+
+         model_names: dict[str, DataModel] = {}
+         for model in models:
+             class_name: str = model.class_name
+             generated_name: str = scoped_model_resolver.add([model.path], class_name, unique=True, class_name=True).name
+             if class_name != generated_name:
+                 model.class_name = generated_name
+             model_names[model.class_name] = model
+
+         for model in models:
+             duplicate_name = model.duplicate_class_name
+             # check only first desired name
+             if duplicate_name and duplicate_name not in model_names:
+                 del model_names[model.class_name]
+                 model.class_name = duplicate_name
+                 model_names[duplicate_name] = model
+
+     def __change_from_import(  # noqa: PLR0913, PLR0914
          self,
-         with_import: Optional[bool] = True,
-         format_: Optional[bool] = True,
-         settings_path: Optional[Path] = None,
-     ) -> Union[str, Dict[Tuple[str, ...], Result]]:
+         models: list[DataModel],
+         imports: Imports,
+         scoped_model_resolver: ModelResolver,
+         *,
+         init: bool,
+         internal_modules: set[tuple[str, ...]] | None = None,
+         model_path_to_module_name: dict[str, str] | None = None,
+     ) -> None:
+         model_paths = {model.path for model in models}
+         internal_modules = internal_modules or set()
+         model_path_to_module_name = model_path_to_module_name or {}
+
+         for model in models:
+             scoped_model_resolver.add([model.path], model.class_name)
+         for model in models:
+             before_import = model.imports
+             imports.append(before_import)
+             current_module_name = model_path_to_module_name.get(model.path, model.module_name)
+             for data_type in model.all_data_types:
+                 if not data_type.reference or data_type.reference.path in model_paths:
+                     continue
+
+                 ref_module_name = model_path_to_module_name.get(
+                     data_type.reference.path,
+                     data_type.full_name.rsplit(".", 1)[0] if "." in data_type.full_name else "",
+                 )
+                 target_full_name = (
+                     f"{ref_module_name}.{data_type.reference.short_name}"
+                     if ref_module_name
+                     else data_type.reference.short_name
+                 )

-         self.parse_raw()
+                 if isinstance(data_type, BaseClassDataType):
+                     left, right = relative(current_module_name, target_full_name)
+                     is_ancestor = is_ancestor_package_reference(current_module_name, target_full_name)
+                     from_ = left if is_ancestor else (f"{left}{right}" if left.endswith(".") else f"{left}.{right}")
+                     import_ = data_type.reference.short_name
+                     full_path = from_, import_
+                 else:
+                     from_, import_ = full_path = relative(current_module_name, target_full_name)
+                     if imports.use_exact:
+                         from_, import_ = exact_import(from_, import_, data_type.reference.short_name)
+                 import_ = import_.replace("-", "_")
+                 current_module_path = tuple(current_module_name.split(".")) if current_module_name else ()
+                 if (  # pragma: no cover
+                     len(current_module_path) > 1
+                     and current_module_path[-1].count(".") > 0
+                     and not self.treat_dot_as_module
+                 ):
+                     rel_path_depth = current_module_path[-1].count(".")
+                     from_ = from_[rel_path_depth:]

-         if with_import:
-             if self.target_python_version != PythonVersion.PY_36:
-                 self.imports.append(IMPORT_ANNOTATIONS)
+                 ref_module = tuple(target_full_name.split(".")[:-1])

-         if format_:
-             code_formatter: Optional[CodeFormatter] = CodeFormatter(
-                 self.target_python_version, settings_path, self.wrap_string_literal
-             )
+                 is_module_class_collision = (
+                     ref_module and import_ == data_type.reference.short_name and ref_module[-1] == import_
+                 )
+
+                 if from_ and (ref_module in internal_modules or is_module_class_collision):
+                     from_ = f"{from_}{import_}" if from_.endswith(".") else f"{from_}.{import_}"
+                     import_ = data_type.reference.short_name
+                     full_path = from_, import_
+
+                 alias = scoped_model_resolver.add(full_path, import_).name
+
+                 name = data_type.reference.short_name
1134
+ if from_ and import_ and alias != name:
1135
+ data_type.alias = alias if data_type.reference.short_name == import_ else f"{alias}.{name}"
1136
+
1137
+ if init and not target_full_name.startswith(current_module_name + "."):
1138
+ from_ = "." + from_
1139
+ imports.append(
1140
+ Import(
1141
+ from_=from_,
1142
+ import_=import_,
1143
+ alias=alias,
1144
+ reference_path=data_type.reference.path,
1145
+ ),
1146
+ )
1147
+ after_import = model.imports
1148
+ if before_import != after_import:
1149
+ imports.append(after_import)
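
Note: the heart of `__change_from_import` is turning an absolute reference such as `api.shared.types.UserId` into a relative `from ... import ...` pair. The package's actual `relative()` helper is not shown in this hunk; the following is a minimal sketch of the idea under that assumption (module names invented):

```python
def relative(current_module: str, reference: str) -> tuple[str, str]:
    """Simplified sketch: derive a (from_, import_) pair for `reference`."""
    current_parts = current_module.split(".") if current_module else []
    *ref_parts, name = reference.split(".")
    # Length of the shared package prefix.
    common = 0
    for left, right in zip(current_parts, ref_parts):
        if left != right:
            break
        common += 1
    # One dot per package level between the current module and the shared root.
    dots = "." * max(len(current_parts) - common, 1)
    return dots + ".".join(ref_parts[common:]), name

# e.g. from module api.v1.models:
assert relative("api.v1.models", "api.shared.types.UserId") == ("..shared.types", "UserId")
```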
+
+    @classmethod
+    def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
+        for model in models.copy():
+            if model.fields:
+                continue
+            enums: list[Enum] = []
+            for base_model in model.base_classes:
+                if not base_model.reference:
+                    continue
+                source_model = base_model.reference.source
+                if isinstance(source_model, Enum):
+                    enums.append(source_model)
+            if enums:
+                merged_enum = enums[0].__class__(
+                    fields=[f for e in enums for f in e.fields],
+                    description=model.description,
+                    reference=model.reference,
+                )
+                cls._replace_model_in_list(models, model, merged_enum)
+
+    def _create_discriminator_data_type(
+        self,
+        enum_source: Enum | None,
+        type_names: list[str],
+        discriminator_model: DataModel,
+        imports: Imports,
+    ) -> DataType:
+        """Create a data type for the discriminator field, using enum literals if available."""
+        if enum_source:
+            enum_class_name = enum_source.reference.short_name
+            enum_member_literals: list[tuple[str, str]] = []
+            for value in type_names:
+                member = enum_source.find_member(value)
+                if member and member.field.name:
+                    enum_member_literals.append((enum_class_name, member.field.name))
+                else:  # pragma: no cover
+                    enum_member_literals.append((enum_class_name, value))
+            data_type = self.data_type(enum_member_literals=enum_member_literals)
+            if enum_source.module_path != discriminator_model.module_path:  # pragma: no cover
+                imports.append(Import.from_full_path(enum_source.name))
         else:
-            code_formatter = None
+            data_type = self.data_type(literals=type_names)
+        return data_type
 
-        _, sorted_data_models, require_update_action_models = sort_data_models(
-            self.results
-        )
+    def __apply_discriminator_type(  # noqa: PLR0912, PLR0914, PLR0915
+        self,
+        models: list[DataModel],
+        imports: Imports,
+    ) -> None:
+        for model in models:  # noqa: PLR1702
+            for field in model.fields:
+                discriminator = field.extras.get("discriminator")
+                if not discriminator or not isinstance(discriminator, dict):
+                    continue
+                property_name = discriminator.get("propertyName")
+                if not property_name:  # pragma: no cover
+                    continue
+                field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
+                    field_name=property_name, model_type=self.field_name_model_type
+                )
+                discriminator["propertyName"] = field_name
+                mapping = discriminator.get("mapping", {})
+                for data_type in field.data_type.data_types:
+                    if not data_type.reference:  # pragma: no cover
+                        continue
+                    discriminator_model = data_type.reference.source
+
+                    if not isinstance(  # pragma: no cover
+                        discriminator_model,
+                        (
+                            pydantic_model.BaseModel,
+                            pydantic_model_v2.BaseModel,
+                            dataclass_model.DataClass,
+                            msgspec_model.Struct,
+                        ),
+                    ):
+                        continue  # pragma: no cover
+
+                    type_names: list[str] = []
+
+                    def check_paths(
+                        model: pydantic_model.BaseModel | pydantic_model_v2.BaseModel | Reference,
+                        mapping: dict[str, str],
+                        type_names: list[str] = type_names,
+                    ) -> None:
+                        """Validate discriminator mapping paths for a model."""
+                        for name, path in mapping.items():
+                            if (model.path.split("#/")[-1] != path.split("#/")[-1]) and (
+                                path.startswith("#/") or model.path[:-1] != path.split("/")[-1]
+                            ):
+                                t_path = path[str(path).find("/") + 1 :]
+                                t_disc = model.path[: str(model.path).find("#")].lstrip("../")  # noqa: B005
+                                t_disc_2 = "/".join(t_disc.split("/")[1:])
+                                if t_path not in {t_disc, t_disc_2}:
+                                    continue
+                            type_names.append(name)
+
+                    # First try to get the discriminator value from the const field
+                    for discriminator_field in discriminator_model.fields:
+                        if field_name not in {discriminator_field.original_name, discriminator_field.name}:
+                            continue
+                        if discriminator_field.extras.get("const"):
+                            type_names = [discriminator_field.extras["const"]]
+                            break
+
+                    # If no const value was found, try to get it from the mapping
+                    if not type_names:
+                        # Check the main discriminator model path
+                        if mapping:
+                            check_paths(discriminator_model, mapping)  # pyright: ignore[reportArgumentType]
+
+                            # Check the base_classes if they exist
+                            if len(type_names) == 0:
+                                for base_class in discriminator_model.base_classes:
+                                    check_paths(base_class.reference, mapping)  # pyright: ignore[reportArgumentType]
+                        else:
+                            for discriminator_field in discriminator_model.fields:
+                                if field_name not in {discriminator_field.original_name, discriminator_field.name}:
+                                    continue
+
+                                literals = discriminator_field.data_type.literals
+                                if literals and len(literals) == 1:  # pragma: no cover
+                                    type_names = [str(v) for v in literals]
+                                    break
+
+                                enum_source = discriminator_field.data_type.find_source(Enum)
+                                if enum_source and len(enum_source.fields) == 1:
+                                    first_field = enum_source.fields[0]
+                                    raw_default = first_field.default
+                                    if isinstance(raw_default, str):
+                                        type_names = [raw_default.strip("'\"")]
+                                    else:  # pragma: no cover
+                                        type_names = [str(raw_default)]
+                                    break
+
+                        if not type_names:
+                            type_names = [discriminator_model.path.split("/")[-1]]
+
+                    if not type_names:  # pragma: no cover
+                        msg = f"Discriminator type is not found. {data_type.reference.path}"
+                        raise RuntimeError(msg)
+
+                    enum_from_base: Enum | None = None
+                    if self.use_enum_values_in_discriminator:
+                        for base_class in discriminator_model.base_classes:
+                            if not base_class.reference or not base_class.reference.source:  # pragma: no cover
+                                continue
+                            base_model = base_class.reference.source
+                            if not isinstance(  # pragma: no cover
+                                base_model,
+                                (
+                                    pydantic_model.BaseModel,
+                                    pydantic_model_v2.BaseModel,
+                                    dataclass_model.DataClass,
+                                    msgspec_model.Struct,
+                                ),
+                            ):
+                                continue
+                            for base_field in base_model.fields:  # pragma: no branch
+                                if field_name not in {base_field.original_name, base_field.name}:  # pragma: no cover
+                                    continue
+                                enum_from_base = base_field.data_type.find_source(Enum)
+                                if enum_from_base:  # pragma: no branch
+                                    break
+                            if enum_from_base:  # pragma: no branch
+                                break
+
+                    has_one_literal = False
+                    for discriminator_field in discriminator_model.fields:
+                        if field_name not in {discriminator_field.original_name, discriminator_field.name}:
+                            continue
+                        literals = discriminator_field.data_type.literals
+                        const_value = discriminator_field.extras.get("const")
+                        expected_value = type_names[0] if type_names else None
+
+                        # Check if literals match (existing behavior)
+                        literals_match = len(literals) == 1 and literals[0] == expected_value
+                        # Check if const value matches (for msgspec with type: string + const)
+                        const_match = const_value is not None and const_value == expected_value
+
+                        if literals_match:
+                            has_one_literal = True
+                            if isinstance(discriminator_model, msgspec_model.Struct):  # pragma: no cover
+                                discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
+                                discriminator_model.add_base_class_kwarg("tag", discriminator_field.represented_default)
+                                discriminator_field.extras["is_classvar"] = True
+                            # Found the discriminator field, no need to keep looking
+                            break
+
+                        # For msgspec with const value but no literal (type: string + const case)
+                        if const_match and isinstance(discriminator_model, msgspec_model.Struct):  # pragma: no cover
+                            has_one_literal = True
+                            discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
+                            discriminator_model.add_base_class_kwarg("tag", repr(const_value))
+                            discriminator_field.extras["is_classvar"] = True
+                            break
+
+                        enum_source: Enum | None = None
+                        if self.use_enum_values_in_discriminator:
+                            enum_source = (  # pragma: no cover
+                                discriminator_field.data_type.find_source(Enum) or enum_from_base
+                            )
+
+                        for field_data_type in discriminator_field.data_type.all_data_types:
+                            if field_data_type.reference:  # pragma: no cover
+                                field_data_type.remove_reference()
+
+                        discriminator_field.data_type = self._create_discriminator_data_type(
+                            enum_source, type_names, discriminator_model, imports
+                        )
+                        discriminator_field.data_type.parent = discriminator_field
+                        discriminator_field.required = True
+                        imports.append(discriminator_field.imports)
+                        has_one_literal = True
+                    if not has_one_literal:
+                        new_data_type = self._create_discriminator_data_type(
+                            enum_from_base, type_names, discriminator_model, imports
+                        )
+                        discriminator_model.fields.append(
+                            self.data_model_field_type(
+                                name=field_name,
+                                data_type=new_data_type,
+                                required=True,
+                                alias=alias,
+                            )
+                        )
+                has_imported_literal = any(import_ == IMPORT_LITERAL for import_ in imports)
+                if has_imported_literal:  # pragma: no cover
+                    imports.append(IMPORT_LITERAL)
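
Note: the net effect of `__apply_discriminator_type` is that each mapped model's tag field becomes a single-value `Literal`, which typed unions can dispatch on. A hedged illustration of the kind of output this enables, with invented schema names and pydantic v1-style syntax:

```python
# Given an OpenAPI-style discriminator:
#   discriminator:
#     propertyName: pet_type
#     mapping:
#       cat: '#/components/schemas/Cat'
#       dog: '#/components/schemas/Dog'
from typing import Literal, Union

from pydantic import BaseModel, Field

class Cat(BaseModel):
    pet_type: Literal["cat"]

class Dog(BaseModel):
    pet_type: Literal["dog"]

class Owner(BaseModel):
    pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
```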
 
-        results: Dict[Tuple[str, ...], Result] = {}
+    @classmethod
+    def _create_set_from_list(cls, data_type: DataType) -> DataType | None:
+        if data_type.is_list:
+            new_data_type = data_type.copy()
+            new_data_type.is_list = False
+            new_data_type.is_set = True
+            for data_type_ in new_data_type.data_types:
+                data_type_.parent = new_data_type
+            return new_data_type
+        if data_type.data_types:  # pragma: no cover
+            for nested_data_type in data_type.data_types[:]:
+                set_data_type = cls._create_set_from_list(nested_data_type)
+                if set_data_type:  # pragma: no cover
+                    nested_data_type.swap_with(set_data_type)
+            return data_type
+        return None  # pragma: no cover
+
+    def __replace_unique_list_to_set(self, models: list[DataModel]) -> None:
+        for model in models:
+            for model_field in model.fields:
+                if not self.use_unique_items_as_set:
+                    continue
+
+                if not (model_field.constraints and model_field.constraints.unique_items):
+                    continue
+                set_data_type = self._create_set_from_list(model_field.data_type)
+                if set_data_type:  # pragma: no cover
+                    # Check if default list elements are hashable before converting type
+                    if isinstance(model_field.default, list):
+                        try:
+                            converted_default = set(model_field.default)
+                        except TypeError:
+                            # Elements are not hashable (e.g., contains dicts)
+                            # Skip both type and default conversion to keep consistency
+                            continue
+                        model_field.default = converted_default
+                    model_field.replace_data_type(set_data_type)
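
Note: the hashability guard above exists because `set()` raises `TypeError` for unhashable elements, in which case both the type and the default must stay a list to remain consistent. A small illustration of the two paths:

```python
default = [{"a": 1}]            # e.g. uniqueItems with object items
try:
    converted = set(default)
except TypeError:
    converted = None            # keep list[...] and the list default
assert converted is None

default = ["x", "y"]
assert set(default) == {"x", "y"}   # safe: the field can become set[str]
```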
 
-        module_key = lambda x: x.module_path
+    @classmethod
+    def __set_reference_default_value_to_field(cls, models: list[DataModel]) -> None:
+        for model in models:
+            for model_field in model.fields:
+                if not model_field.data_type.reference or model_field.has_default:
+                    continue
+                if (
+                    isinstance(model_field.data_type.reference.source, DataModel)
+                    and model_field.data_type.reference.source.default != UNDEFINED
+                ):  # pragma: no cover
+                    model_field.default = model_field.data_type.reference.source.default
+
+    def __reuse_model(self, models: list[DataModel], require_update_action_models: list[str]) -> None:
+        if not self.reuse_model or self.reuse_scope == ReuseScope.Tree:
+            return
+        model_cache: dict[tuple[HashableComparable, ...], Reference] = {}
+        duplicates = []
+        for model in models.copy():
+            model_key = model.get_dedup_key()
+            cached_model_reference = model_cache.get(model_key)
+            if cached_model_reference:
+                if isinstance(model, Enum):
+                    model.replace_children_in_models(models, cached_model_reference)
+                    duplicates.append(model)
+                else:
+                    inherited_model = model.create_reuse_model(cached_model_reference)
+                    if cached_model_reference.path in require_update_action_models:
+                        add_model_path_to_list(require_update_action_models, inherited_model)
+                    self._replace_model_in_list(models, model, inherited_model)
+            else:
+                model_cache[model_key] = model.reference
 
-        # process in reverse order to correctly establish module levels
-        grouped_models = groupby(
-            sorted(sorted_data_models.values(), key=module_key, reverse=True),
-            key=module_key,
+        for duplicate in duplicates:
+            models.remove(duplicate)
+
+    def __find_duplicate_models_across_modules(  # noqa: PLR6301
+        self,
+        module_models: list[tuple[tuple[str, ...], list[DataModel]]],
+    ) -> list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]]:
+        """Find duplicate models across all modules by comparing render output and imports."""
+        all_models: list[tuple[tuple[str, ...], DataModel]] = []
+        for module, models in module_models:
+            all_models.extend((module, model) for model in models)
+
+        model_cache: dict[tuple[HashableComparable, ...], tuple[tuple[str, ...], DataModel]] = {}
+        duplicates: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]] = []
+
+        for module, model in all_models:
+            model_key = model.get_dedup_key()
+            cached = model_cache.get(model_key)
+            if cached:
+                canonical_module, canonical_model = cached
+                duplicates.append((module, model, canonical_module, canonical_model))
+            else:
+                model_cache[model_key] = (module, model)
+
+        return duplicates
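
Note: both reuse passes rely on the same first-wins canonicalisation: the first model seen for a dedup key becomes canonical, and every later hit is reported against it. A minimal sketch of that pattern with invented names:

```python
def find_duplicates(items: list[tuple[str, str]]) -> list[tuple[str, str, str]]:
    """items: (module, dedup_key); returns (module, key, canonical_module)."""
    cache: dict[str, str] = {}
    duplicates: list[tuple[str, str, str]] = []
    for module, key in items:
        if key in cache:
            duplicates.append((module, key, cache[key]))   # later hit -> duplicate
        else:
            cache[key] = module                            # first hit -> canonical
    return duplicates

assert find_duplicates([("a", "User"), ("b", "User")]) == [("b", "User", "a")]
```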
+
+    def __validate_shared_module_name(
+        self,
+        module_models: list[tuple[tuple[str, ...], list[DataModel]]],
+    ) -> None:
+        """Validate that the shared module name doesn't conflict with existing modules."""
+        shared_module = self.shared_module_name
+        existing_module_names = {module[0] for module, _ in module_models}
+        if shared_module in existing_module_names:
+            msg = (
+                f"Schema file or directory '{shared_module}' conflicts with the shared module name. "
+                f"Use --shared-module-name to specify a different name."
+            )
+            raise Error(msg)
+
+    def __create_shared_module_from_duplicates(  # noqa: PLR0912
+        self,
+        module_models: list[tuple[tuple[str, ...], list[DataModel]]],
+        duplicates: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]],
+        require_update_action_models: list[str],
+    ) -> tuple[tuple[str, ...], list[DataModel]]:
+        """Create a shared module with canonical models and replace duplicates with inherited models."""
+        shared_module = self.shared_module_name
+
+        shared_models: list[DataModel] = []
+        canonical_to_shared_ref: dict[DataModel, Reference] = {}
+        canonical_models_seen: set[DataModel] = set()
+
+        # Process in order of first appearance in duplicates to ensure stable ordering
+        for _, _, _, canonical in duplicates:
+            if canonical in canonical_models_seen:
+                continue
+            canonical_models_seen.add(canonical)
+            canonical.file_path = Path(f"{shared_module}.py")
+            canonical_to_shared_ref[canonical] = canonical.reference
+            shared_models.append(canonical)
+
+        supports_inheritance = issubclass(
+            self.data_model_type,
+            (
+                pydantic_model.BaseModel,
+                pydantic_model_v2.BaseModel,
+                dataclass_model.DataClass,
+            ),
         )
 
-        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []
+        for duplicate_module, duplicate_model, _, canonical_model in duplicates:
+            shared_ref = canonical_to_shared_ref[canonical_model]
+            for module, models in module_models:
+                if module != duplicate_module or duplicate_model not in models:
+                    continue
+                if isinstance(duplicate_model, Enum) or not supports_inheritance:
+                    duplicate_model.replace_children_in_models(models, shared_ref)
+                    models.remove(duplicate_model)
+                else:
+                    inherited_model = duplicate_model.create_reuse_model(shared_ref)
+                    if shared_ref.path in require_update_action_models:
+                        add_model_path_to_list(require_update_action_models, inherited_model)
+                    self._replace_model_in_list(models, duplicate_model, inherited_model)
+                break
+            else:  # pragma: no cover
+                msg = f"Duplicate model {duplicate_model.name} not found in module {duplicate_module}"
+                raise RuntimeError(msg)
+
+        for canonical in canonical_models_seen:
+            for _module, models in module_models:
+                if canonical in models:
+                    models.remove(canonical)
+                    break
+            else:  # pragma: no cover
+                msg = f"Canonical model {canonical.name} not found in any module"
+                raise RuntimeError(msg)
+
+        return (shared_module,), shared_models
+
+    def __reuse_model_tree_scope(
+        self,
+        module_models: list[tuple[tuple[str, ...], list[DataModel]]],
+        require_update_action_models: list[str],
+    ) -> tuple[tuple[str, ...], list[DataModel]] | None:
+        """Deduplicate models across all modules, placing shared models in shared.py."""
+        if not self.reuse_model or self.reuse_scope != ReuseScope.Tree:
+            return None
 
-        for module, models in (
-            (k, [*v]) for k, v in grouped_models
-        ):  # type: Tuple[str, ...], List[DataModel]
+        duplicates = self.__find_duplicate_models_across_modules(module_models)
+        if not duplicates:
+            return None
 
-            for model in models:
-                if isinstance(model, self.data_model_root_type):
-                    root_data_type = model.fields[0].data_type
+        self.__validate_shared_module_name(module_models)
+        return self.__create_shared_module_from_duplicates(module_models, duplicates, require_update_action_models)
+
+    def __collapse_root_models(  # noqa: PLR0912
+        self,
+        models: list[DataModel],
+        unused_models: list[DataModel],
+        imports: Imports,
+        scoped_model_resolver: ModelResolver,
+    ) -> None:
+        if not self.collapse_root_models:
+            return
+
+        for model in models:  # noqa: PLR1702
+            for model_field in model.fields:
+                for data_type in model_field.data_type.all_data_types:
+                    reference = data_type.reference
+                    if not reference or not isinstance(reference.source, self.data_model_root_type):
+                        # Nothing to collapse unless this is a reference to a root model type.
+                        continue
+
+                    # Use the root type as the model_field type
+                    root_type_model = reference.source
+                    root_type_field = root_type_model.fields[0]
 
-                    # backward compatible
-                    # Remove duplicated root model
                     if (
-                        root_data_type.reference
-                        and not root_data_type.is_dict
-                        and not root_data_type.is_list
-                        and root_data_type.reference.source in models
-                        and root_data_type.reference.name
-                        == self.model_resolver.get_class_name(
-                            model.reference.original_name, unique=False
-                        )
+                        self.field_constraints
+                        and isinstance(root_type_field.constraints, ConstraintsBase)
+                        and root_type_field.constraints.has_constraints
+                        and any(d for d in model_field.data_type.all_data_types if d.is_dict or d.is_union or d.is_list)
                     ):
-                        # Replace referenced duplicate model to original model
-                        for child in model.reference.children[:]:
-                            child.replace_reference(root_data_type.reference)
-                        models.remove(model)
+                        continue  # pragma: no cover
+
+                    if root_type_field.data_type.reference:
+                        # If the root type field is itself a reference, we can't collapse it yet.
                         continue
 
-                    # Custom root model can't be inherited on restriction of Pydantic
-                    for child in model.reference.children:
-                        # inheritance model
-                        if isinstance(child, DataModel):
-                            for base_class in child.base_classes:
-                                if base_class.reference == model.reference:
-                                    child.base_classes.remove(base_class)
-                            if not child.base_classes:
-                                child.set_base_class()
-
-            module_models.append(
-                (
-                    module,
-                    models,
-                )
-            )
+                    # set the copied data_type
+                    copied_data_type = root_type_field.data_type.copy()
+                    if isinstance(data_type.parent, self.data_model_field_type):
+                        # for a field:
+                        # override the empty field with the root-type field
+                        model_field.extras = {
+                            **root_type_field.extras,
+                            **model_field.extras,
+                        }
+                        model_field.process_const()
+
+                        if self.field_constraints:
+                            model_field.constraints = ConstraintsBase.merge_constraints(
+                                root_type_field.constraints, model_field.constraints
+                            )
 
-            scoped_model_resolver = ModelResolver(
-                exclude_names={i.alias or i.import_ for m in models for i in m.imports},
-                duplicate_name_suffix='Model',
-            )
+                        data_type.parent.data_type = copied_data_type
 
-            for model in models:
-                class_name: str = model.class_name
-                generated_name: str = scoped_model_resolver.add(
-                    model.path, class_name, unique=True, class_name=True
-                ).name
-                if class_name != generated_name:
-                    if '.' in model.reference.name:
-                        model.reference.name = (
-                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
-                        )
+                    elif isinstance(data_type.parent, DataType) and data_type.parent.is_list:
+                        if self.field_constraints:
+                            model_field.constraints = ConstraintsBase.merge_constraints(
+                                root_type_field.constraints, model_field.constraints
+                            )
+                        if (  # pragma: no cover
+                            isinstance(
+                                root_type_field,
+                                pydantic_model.DataModelField,
+                            )
+                            and not model_field.extras.get("discriminator")
+                            and not any(t.is_list for t in model_field.data_type.data_types)
+                        ):
+                            discriminator = root_type_field.extras.get("discriminator")
+                            if discriminator:
+                                model_field.extras["discriminator"] = discriminator
+                        assert isinstance(data_type.parent, DataType)
+                        data_type.parent.data_types.remove(data_type)  # pragma: no cover
+                        data_type.parent.data_types.append(copied_data_type)
+
+                    elif isinstance(data_type.parent, DataType):
+                        # for a data_type
+                        data_type_id = id(data_type)
+                        data_type.parent.data_types = [
+                            d for d in (*data_type.parent.data_types, copied_data_type) if id(d) != data_type_id
+                        ]
+                    else:  # pragma: no cover
+                        continue
+
+                    for d in copied_data_type.all_data_types:
+                        if d.reference is None:
+                            continue
+                        from_, import_ = full_path = relative(model.module_name, d.full_name)
+                        if from_ and import_:
+                            alias = scoped_model_resolver.add(full_path, import_)
+                            d.alias = (
+                                alias.name
+                                if d.reference.short_name == import_
+                                else f"{alias.name}.{d.reference.short_name}"
+                            )
+                            imports.append([
+                                Import(
+                                    from_=from_,
+                                    import_=import_,
+                                    alias=alias.name,
+                                    reference_path=d.reference.path,
+                                )
+                            ])
+
+                    original_field = get_most_of_parent(data_type, DataModelFieldBase)
+                    if original_field:  # pragma: no cover
+                        # TODO: Improve detection of reference type
+                        # Use a list instead of a set because Import is not hashable
+                        excluded_imports = [IMPORT_OPTIONAL, IMPORT_UNION]
+                        field_imports = [i for i in original_field.imports if i not in excluded_imports]
+                        imports.append(field_imports)
+
+                    data_type.remove_reference()
+
+                    assert isinstance(root_type_model, DataModel)
+                    root_type_model.reference.children = [
+                        c for c in root_type_model.reference.children if getattr(c, "parent", None)
+                    ]
+
+                    imports.remove_referenced_imports(root_type_model.path)
+                    if not root_type_model.reference.children:
+                        unused_models.append(root_type_model)
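
Note: a hedged before/after sketch of what collapsing a root model means for the generated output (model names invented):

```python
# before: the wrapper root model is referenced by the field
#     class UserId(BaseModel):
#         __root__: str
#
#     class User(BaseModel):
#         id: UserId
#
# after --collapse-root-models: the root type is inlined; once nothing else
# references UserId, it is appended to unused_models and dropped
#     class User(BaseModel):
#         id: str
```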
+
+    def __set_default_enum_member(
+        self,
+        models: list[DataModel],
+    ) -> None:
+        if not self.set_default_enum_member:
+            return
+        for _, model_field, data_type in iter_models_field_data_types(models):
+            if not model_field.default:
+                continue
+            if data_type.reference and isinstance(data_type.reference.source, Enum):  # pragma: no cover
+                if isinstance(model_field.default, list):
+                    enum_member: list[Member] | Member | None = [
+                        e for e in (data_type.reference.source.find_member(d) for d in model_field.default) if e
+                    ]
+                else:
+                    enum_member = data_type.reference.source.find_member(model_field.default)
+                if not enum_member:
+                    continue
+                model_field.default = enum_member
+                if data_type.alias:
+                    if isinstance(enum_member, list):
+                        for enum_member_ in enum_member:
+                            enum_member_.alias = data_type.alias
                     else:
-                        model.reference.name = generated_name
+                        enum_member.alias = data_type.alias
 
-        for module, models in module_models:
-            init = False
-            if module:
-                parent = (*module[:-1], '__init__.py')
-                if parent not in results:
-                    results[parent] = Result(body='')
-                if (*module, '__init__.py') in results:
-                    module = (*module, '__init__.py')
-                    init = True
+    def __wrap_root_model_default_values(
+        self,
+        models: list[DataModel],
+    ) -> None:
+        """Wrap RootModel reference default values with their type constructors."""
+        if not self.use_annotated:
+            return
+        for model, model_field, data_type in iter_models_field_data_types(models):
+            if isinstance(model, (Enum, self.data_model_root_type)):
+                continue
+            if model_field.default is None:
+                continue
+            if isinstance(model_field.default, (WrappedDefault, Member)):
+                continue
+            if isinstance(model_field.default, list):
+                continue
+            if data_type.reference and isinstance(data_type.reference.source, pydantic_model_v2.RootModel):
+                # Use the alias if available (handles import collisions)
+                type_name = data_type.alias or data_type.reference.short_name
+                model_field.default = WrappedDefault(
+                    value=model_field.default,
+                    type_name=type_name,
+                )
+
+    def __override_required_field(
+        self,
+        models: list[DataModel],
+    ) -> None:
+        for model in models:
+            if isinstance(model, (Enum, self.data_model_root_type)):
+                continue
+            for index, model_field in enumerate(model.fields[:]):
+                data_type = model_field.data_type
+                if (
+                    not model_field.original_name  # noqa: PLR0916
+                    or data_type.data_types
+                    or data_type.reference
+                    or data_type.type
+                    or data_type.literals
+                    or data_type.dict_key
+                ):
+                    continue
+
+                original_field = _find_field(model_field.original_name, _find_base_classes(model))
+                if not original_field:  # pragma: no cover
+                    model.fields.remove(model_field)
+                    continue
+                copied_original_field = original_field.copy()
+                if original_field.data_type.reference:
+                    data_type = self.data_type_manager.data_type(
+                        reference=original_field.data_type.reference,
+                    )
+                elif original_field.data_type.data_types:
+                    data_type = original_field.data_type.copy()
+                    data_type.data_types = _copy_data_types(original_field.data_type.data_types)
+                    for data_type_ in data_type.data_types:
+                        data_type_.parent = data_type
                 else:
-                    module = (*module[:-1], f'{module[-1]}.py')
+                    data_type = original_field.data_type.copy()
+                data_type.parent = copied_original_field
+                copied_original_field.data_type = data_type
+                copied_original_field.parent = model
+                copied_original_field.required = True
+                model.fields.insert(index, copied_original_field)
+                model.fields.remove(model_field)
+
+    def __sort_models(
+        self,
+        models: list[DataModel],
+        imports: Imports,
+        *,
+        use_deferred_annotations: bool,
+    ) -> None:
+        if not self.keep_model_order:
+            return
+
+        _reorder_models_keep_model_order(models, imports, use_deferred_annotations=use_deferred_annotations)
+
+    def __change_field_name(
+        self,
+        models: list[DataModel],
+    ) -> None:
+        if not issubclass(self.data_model_type, pydantic_model_v2.BaseModel):
+            return
+        for model in models:
+            if "Enum" in model.base_class:
+                continue
+
+            for field in model.fields:
+                field_name = field.name
+                field_name_resolver = ModelResolver(snake_case_field=self.snake_case_field, remove_suffix_number=True)
+                for data_type in field.data_type.all_data_types:
+                    if data_type.reference:
+                        field_name_resolver.exclude_names.add(data_type.reference.short_name)
+                new_field_name = field_name_resolver.add(["field"], cast("str", field_name)).name
+                if field_name != new_field_name:
+                    field.alias = field_name
+                    field.name = new_field_name
+
+    def __set_one_literal_on_default(self, models: list[DataModel]) -> None:
+        if not self.use_one_literal_as_default:
+            return
+        for model in models:
+            for model_field in model.fields:
+                if not model_field.required or len(model_field.data_type.literals) != 1:
+                    continue
+                model_field.default = model_field.data_type.literals[0]
+                model_field.required = False
+                if model_field.nullable is not True:  # pragma: no cover
+                    model_field.nullable = False
+
+    def __fix_dataclass_field_ordering(self, models: list[DataModel]) -> None:
+        """Fix field ordering for dataclasses with inheritance after defaults are set."""
+        for model in models:
+            if (inherited := self.__get_dataclass_inherited_info(model)) is None:
+                continue
+            inherited_names, has_default = inherited
+            if not has_default or not any(self.__is_new_required_field(f, inherited_names) for f in model.fields):
+                continue
+
+            if self.target_python_version.has_kw_only_dataclass:
+                for field in model.fields:
+                    if self.__is_new_required_field(field, inherited_names):
+                        field.extras["kw_only"] = True
             else:
-                module = ('__init__.py',)
+                warn(
+                    f"Dataclass '{model.class_name}' has a field ordering conflict due to inheritance. "
+                    f"An inherited field has a default value, but new required fields are added. "
+                    f"This will cause a TypeError at runtime. Consider using --target-python-version 3.10 "
+                    f"or higher to enable the automatic field(kw_only=True) fix.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+                model.fields = sorted(model.fields, key=dataclass_model.has_field_assignment)
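
Note: the conflict this pass works around is standard dataclass behaviour. A self-contained illustration of the failure and the Python 3.10+ `kw_only` escape hatch (class names invented):

```python
from dataclasses import dataclass, field

@dataclass
class Base:
    name: str = "anonymous"        # inherited field with a default

# Without kw_only, adding a new required field in a subclass raises at class
# creation time: "non-default argument 'age' follows default argument".
@dataclass
class Child(Base):
    age: int = field(kw_only=True)  # Python 3.10+: excluded from positional order

assert Child(age=3).name == "anonymous"
```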
 
-            result: List[str] = []
-            imports = Imports()
-            scoped_model_resolver = ModelResolver()
+    @classmethod
+    def __get_dataclass_inherited_info(cls, model: DataModel) -> tuple[set[str], bool] | None:
+        """Get inherited field names and whether any has a default. Returns None if not applicable."""
+        if not isinstance(model, dataclass_model.DataClass):
+            return None
+        if not model.base_classes or model.dataclass_arguments.get("kw_only"):
+            return None
+
+        inherited_names: set[str] = set()
+        has_default = False
+        for base in model.base_classes:
+            if not base.reference or not isinstance(base.reference.source, DataModel):
+                continue  # pragma: no cover
+            for f in base.reference.source.iter_all_fields():
+                if not f.name or f.extras.get("init") is False:
+                    continue  # pragma: no cover
+                inherited_names.add(f.name)
+                if dataclass_model.has_field_assignment(f):
+                    has_default = True
+
+        for f in model.fields:
+            if f.name not in inherited_names or f.extras.get("init") is False:
+                continue
+            if dataclass_model.has_field_assignment(f):  # pragma: no branch
+                has_default = True
+        return (inherited_names, has_default) if inherited_names else None
+
+    def __is_new_required_field(self, field: DataModelFieldBase, inherited: set[str]) -> bool:  # noqa: PLR6301
+        """Check whether the field is a new required init field."""
+        return (
+            field.name not in inherited
+            and field.extras.get("init") is not False
+            and not dataclass_model.has_field_assignment(field)
+        )
 
-            for model in models:
-                imports.append(model.imports)
-                for data_type in model.all_data_types:
-                    # To change from/import
+    @classmethod
+    def __update_type_aliases(cls, models: list[DataModel]) -> None:
+        """Update type aliases to properly handle forward references per PEP 484."""
+        model_index: dict[str, int] = {m.class_name: i for i, m in enumerate(models)}
+
+        for i, model in enumerate(models):
+            if not isinstance(model, TypeAliasBase):
+                continue
+            if isinstance(model, TypeStatement):
+                continue
 
-                    if not data_type.reference or data_type.reference.source in models:
-                        # No need to import non-reference model.
-                        # Or, the referenced model is in the same file, so we don't need to import it.
+            for field in model.fields:
+                for data_type in field.data_type.all_data_types:
+                    if not data_type.reference:
+                        continue
+                    source = data_type.reference.source
+                    if not isinstance(source, DataModel):
+                        continue  # pragma: no cover
+                    if isinstance(source, TypeStatement):
+                        continue  # pragma: no cover
+                    if source.module_path != model.module_path:
                         continue
+                    name = data_type.reference.short_name
+                    source_index = model_index.get(name)
+                    if source_index is not None and source_index >= i:
+                        data_type.alias = f'"{name}"'
 
-                    if isinstance(data_type, BaseClassDataType):
-                        from_ = ''.join(
-                            relative(model.module_name, data_type.full_name)
-                        )
-                        import_ = data_type.reference.short_name
-                        full_path = from_, import_
-                    else:
-                        from_, import_ = full_path = relative(
-                            model.module_name, data_type.full_name
-                        )
+    @classmethod
+    def __postprocess_result_modules(cls, results: dict[tuple[str, ...], Result]) -> dict[tuple[str, ...], Result]:
+        def process(input_tuple: tuple[str, ...]) -> tuple[str, ...]:
+            r = []
+            for item in input_tuple:
+                p = item.split(".")
+                if len(p) > 1:
+                    r.extend(p[:-1])
+                    r.append(p[-1])
+                else:
+                    r.append(item)
 
-                    alias = scoped_model_resolver.add(full_path, import_).name
+            if len(r) >= 2:  # noqa: PLR2004
+                r = [*r[:-2], f"{r[-2]}.{r[-1]}"]
+            return tuple(r)
 
-                    name = data_type.reference.short_name
-                    if from_ and import_ and alias != name:
-                        data_type.alias = f'{alias}.{name}'
+        results = {process(k): v for k, v in results.items()}
 
-                    if init:
-                        from_ += "."
-                    imports.append(Import(from_=from_, import_=import_, alias=alias))
+        init_result = next(v for k, v in results.items() if k[-1] == "__init__.py")
+        folders = {t[:-1] if t[-1].endswith(".py") else t for t in results}
+        for folder in folders:
+            for i in range(len(folder)):
+                subfolder = folder[: i + 1]
+                init_file = (*subfolder, "__init__.py")
+                results.update({init_file: init_result})
+        return results
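
Note: a worked example of `process()` above: dotted segments are split into package levels, then the final two parts are re-joined so the file keeps its `.py` suffix.

```python
def process(input_tuple: tuple[str, ...]) -> tuple[str, ...]:
    r: list[str] = []
    for item in input_tuple:
        p = item.split(".")
        r.extend(p[:-1])     # intermediate parts become package levels
        r.append(p[-1])
    if len(r) >= 2:
        r = [*r[:-2], f"{r[-2]}.{r[-1]}"]   # restore the file suffix
    return tuple(r)

# ("foo.bar", "baz.py") -> [foo, bar, baz, py] -> ("foo", "bar", "baz.py")
assert process(("foo.bar", "baz.py")) == ("foo", "bar", "baz.py")
```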
 
-            if self.reuse_model:
-                model_cache: Dict[Tuple[str, ...], Reference] = {}
-                duplicates = []
-                for model in models:
-                    model_key = tuple(
-                        to_hashable(v)
-                        for v in (
-                            model.base_classes,
-                            model.extra_template_data,
-                            model.fields,
-                        )
+    def __change_imported_model_name(  # noqa: PLR6301
+        self,
+        models: list[DataModel],
+        imports: Imports,
+        scoped_model_resolver: ModelResolver,
+    ) -> None:
+        imported_names = {
+            imports.alias[from_][i] if i in imports.alias[from_] and i != imports.alias[from_][i] else i
+            for from_, import_ in imports.items()
+            for i in import_
+        }
+        for model in models:
+            if model.class_name not in imported_names:  # pragma: no cover
+                continue
+
+            model.reference.name = scoped_model_resolver.add(  # pragma: no cover
+                path=get_special_path("imported_name", model.path.split("/")),
+                original_name=model.reference.name,
+                unique=True,
+                class_name=True,
+            ).name
+
+    def __alias_shadowed_imports(  # noqa: PLR6301
+        self,
+        models: list[DataModel],
+        all_model_field_names: set[str],
+    ) -> None:
+        for _, model_field, data_type in iter_models_field_data_types(models):
+            if data_type and data_type.type in all_model_field_names and data_type.type == model_field.name:
+                alias = data_type.type + "_aliased"
+                data_type.type = alias
+                if data_type.import_:  # pragma: no cover
+                    data_type.import_ = Import(
+                        from_=data_type.import_.from_,
+                        import_=data_type.import_.import_,
+                        alias=alias,
+                        reference_path=data_type.import_.reference_path,
                     )
-                    cached_model_reference = model_cache.get(model_key)
-                    if cached_model_reference:
-                        if isinstance(model, Enum):
-                            for child in model.reference.children[:]:
-                                # child is resolved data_type by reference
-                                data_model = get_most_of_parent(child)
-                                # TODO: replace reference in all modules
-                                if data_model in models:  # pragma: no cover
-                                    child.replace_reference(cached_model_reference)
-                            duplicates.append(model)
-                        else:
-                            index = models.index(model)
-                            inherited_model = model.__class__(
-                                fields=[],
-                                base_classes=[cached_model_reference],
-                                description=model.description,
-                                reference=Reference(
-                                    name=model.name,
-                                    path=model.reference.path + '/reuse',
-                                ),
-                            )
-                            if (
-                                cached_model_reference.path
-                                in require_update_action_models
-                            ):
-                                require_update_action_models.append(
-                                    inherited_model.path
-                                )
-                            models.insert(index, inherited_model)
-                            models.remove(model)
 
-                    else:
-                        model_cache[model_key] = model.reference
+    @classmethod
+    def _collect_exports_for_init(
+        cls,
+        module: tuple[str, ...],
+        processed_models: Sequence[
+            tuple[tuple[str, ...], tuple[str, ...], Sequence[DataModel], bool, Imports, ModelResolver]
+        ],
+        scope: AllExportsScope,
+    ) -> list[tuple[str, tuple[str, ...], str]]:
+        """Collect exports for __init__.py based on scope."""
+        exports: list[tuple[str, tuple[str, ...], str]] = []
+        base = module[:-1] if module[-1] == "__init__.py" else module
+        base_len = len(base)
+
+        for proc_module, _, proc_models, _, _, _ in processed_models:
+            if not proc_models or proc_module == module:
+                continue
+            last = proc_module[-1]
+            prefix = proc_module[:-1] if last == "__init__.py" else (*proc_module[:-1], last[:-3])
+            if prefix[:base_len] != base or (depth := len(prefix) - base_len) < 1:
+                continue
+            if scope == AllExportsScope.Children and depth != 1:
+                continue
+            rel = prefix[base_len:]
+            exports.extend(
+                (ref.short_name, rel, ".".join(rel))
+                for m in proc_models
+                if (ref := m.reference) and not ref.short_name.startswith("_")
+            )
+        return exports
+
+    @classmethod
+    def _resolve_export_collisions(
+        cls,
+        exports: list[tuple[str, tuple[str, ...], str]],
+        strategy: AllExportsCollisionStrategy | None,
+        reserved: set[str] | None = None,
+    ) -> dict[str, list[tuple[str, tuple[str, ...], str]]]:
+        """Resolve name collisions in exports based on strategy."""
+        reserved = reserved or set()
+        by_name: dict[str, list[tuple[str, tuple[str, ...], str]]] = {}
+        for item in exports:
+            by_name.setdefault(item[0], []).append(item)
+
+        if not (colliding := {n for n, items in by_name.items() if len(items) > 1 or n in reserved}):
+            return dict(by_name)
+        if (effective := strategy or AllExportsCollisionStrategy.Error) == AllExportsCollisionStrategy.Error:
+            cls._raise_collision_error(by_name, colliding)
+
+        used: set[str] = {n for n in by_name if n not in colliding} | reserved
+        result = {n: items for n, items in by_name.items() if n not in colliding}
+
+        for name in sorted(colliding):
+            for item in sorted(by_name[name], key=lambda x: len(x[1])):
+                new_name = cls._make_prefixed_name(
+                    item[0], item[1], used, minimal=effective == AllExportsCollisionStrategy.MinimalPrefix
+                )
+                if new_name in reserved:
+                    msg = (
+                        f"Cannot resolve collision: '{new_name}' conflicts with an __init__.py model. "
+                        "Please rename one of the models."
+                    )
+                    raise Error(msg)
+                result[new_name] = [item]
+                used.add(new_name)
+        return result
+
+    @classmethod
+    def _raise_collision_error(
+        cls,
+        by_name: dict[str, list[tuple[str, tuple[str, ...], str]]],
+        colliding: set[str],
+    ) -> None:
+        """Raise an error with collision details."""
+        details = []
+        for n in colliding:
+            if len(items := by_name[n]) > 1:
+                details.append(f"  '{n}' is defined in: {', '.join(f'.{s}' for _, _, s in items)}")
+            else:
+                details.append(f"  '{n}' conflicts with a model in __init__.py")
+        raise Error(
+            "Name collision detected with --all-exports-scope:\n"
+            + "\n".join(details)
+            + "\n\nUse --all-exports-collision-strategy to specify how to handle collisions."
+        )
+
+    @staticmethod
+    def _make_prefixed_name(name: str, path: tuple[str, ...], used: set[str], *, minimal: bool) -> str:
+        """Generate a prefixed name, using a minimal or full prefix."""
+        if minimal:
+            for depth in range(1, len(path) + 1):
+                if (candidate := "".join(p.title().replace("_", "") for p in path[-depth:]) + name) not in used:
+                    return candidate
+        return "".join(p.title().replace("_", "") for p in path) + name
+
+    @classmethod
+    def _build_all_exports_code(
+        cls,
+        resolved: dict[str, list[tuple[str, tuple[str, ...], str]]],
+    ) -> Imports:
+        """Build import statements from resolved exports."""
+        export_imports = Imports()
+        for export_name, items in resolved.items():
+            for orig, _, short in items:
+                export_imports.append(
+                    Import(from_=f".{short}", import_=orig, alias=export_name if export_name != orig else None)
+                )
+        return export_imports
+
+    @classmethod
+    def _collect_used_names_from_models(cls, models: list[DataModel]) -> set[str]:
+        """Collect identifiers referenced by models before rendering."""
+        names: set[str] = set()
+
+        def add(name: str | None) -> None:
+            if not name:
+                return
+            # The first segment is sufficient to match an import target or alias.
+            names.add(name.split(".")[0])
+
+        def walk_data_type(data_type: DataType) -> None:
+            add(data_type.alias or data_type.type)
+            if data_type.reference:
+                add(data_type.reference.short_name)
+            for child in data_type.data_types:
+                walk_data_type(child)
+            if data_type.dict_key:
+                walk_data_type(data_type.dict_key)
+
+        for model in models:
+            add(model.class_name)
+            add(model.duplicate_class_name)
+            for base in model.base_classes:
+                add(base.type_hint)
+            for import_ in model.imports:
+                add(import_.alias or import_.import_.split(".")[-1])
+            for field in model.fields:
+                if field.extras.get("is_classvar"):
+                    continue
+                add(field.name)
+                add(field.alias)
+                walk_data_type(field.data_type)
+        return names
+
+    def __generate_forwarder_content(  # noqa: PLR6301
+        self,
+        original_module: tuple[str, ...],
+        internal_module: tuple[str, ...],
+        class_mappings: list[tuple[str, str]],
+        *,
+        is_init: bool = False,
+    ) -> str:
+        """Generate forwarder module content that re-exports classes from _internal.
+
+        Args:
+            original_module: The original module tuple (e.g., ("issuing",) or ())
+            internal_module: The _internal module tuple (e.g., ("_internal",))
+            class_mappings: List of (original_name, new_name) tuples, sorted by original_name
+            is_init: True if this is a package __init__.py, False for regular .py files
+
+        Returns:
+            The forwarder module content as a string
+        """
+        original_str = ".".join(original_module)
+        internal_str = ".".join(internal_module)
+        from_dots, module_name = relative(original_str, internal_str, reference_is_module=True, current_is_init=is_init)
+        relative_import = f"{from_dots}{module_name}"
+
+        imports = Imports()
+        for original_name, new_name in class_mappings:
+            if original_name == new_name:
+                imports.append(Import(from_=relative_import, import_=new_name))
+            else:
+                imports.append(Import(from_=relative_import, import_=new_name, alias=original_name))
+
+        return f"{imports.dump()}\n\n{imports.dump_all()}\n"
+
+    def __compute_internal_module_path(  # noqa: PLR6301
+        self,
+        scc_modules: set[tuple[str, ...]],
+        existing_modules: set[tuple[str, ...]],
+        *,
+        base_name: str = "_internal",
+    ) -> tuple[str, ...]:
+        """Compute the internal module path for an SCC."""
+        directories = [get_module_directory(m) for m in sorted(scc_modules)]
+
+        if not directories or any(not d for d in directories):
+            prefix: tuple[str, ...] = ()
+        else:
+            path_strings = ["/".join(d) for d in directories]
+            common = os.path.commonpath(path_strings)
+            prefix = tuple(common.split("/")) if common else ()
+
+        base_module = (base_name,) if not prefix else (*prefix, base_name)
+
+        if base_module in existing_modules:
+            counter = 1
+            while True:
+                candidate = (*prefix, f"{base_name}_{counter}") if prefix else (f"{base_name}_{counter}",)
+                if candidate not in existing_modules:
+                    return candidate
+                counter += 1
 
-            for duplicate in duplicates:
-                models.remove(duplicate)
+        return base_module
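
Note: the `_internal` module is placed at the deepest directory shared by the SCC's modules, found via `os.path.commonpath`. A small worked example with invented paths:

```python
import os.path

directories = ["api/v1", "api/v2"]
assert os.path.commonpath(directories) == "api"
# -> internal module ("api", "_internal"); a numeric suffix like
#    ("api", "_internal_1") is appended if that name already exists.
```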
 
-            if self.set_default_enum_member:
+    def __collect_scc_models(  # noqa: PLR6301
+        self,
+        scc: set[tuple[str, ...]],
+        result_modules: dict[tuple[str, ...], list[DataModel]],
+    ) -> tuple[list[DataModel], dict[int, tuple[str, ...]]]:
+        """Collect all models from SCC modules.
+
+        Returns:
+            - List of all models in the SCC
+            - Mapping from model id to its original module
+        """
+        all_models: list[DataModel] = []
+        model_to_module: dict[int, tuple[str, ...]] = {}
+        for scc_module in sorted(scc):
+            for model in result_modules[scc_module]:
+                all_models.append(model)
+                model_to_module[id(model)] = scc_module
+        return all_models, model_to_module
+
+    def __rename_and_relocate_scc_models(  # noqa: PLR6301
+        self,
+        all_scc_models: list[DataModel],
+        model_to_original_module: dict[int, tuple[str, ...]],
+        internal_module: tuple[str, ...],
+        internal_path: Path,
+    ) -> tuple[defaultdict[tuple[str, ...], list[tuple[str, str]]], dict[str, str]]:
+        """Rename duplicate classes and relocate models to the internal module.
+
+        Returns:
+            Tuple of:
+            - Mapping from original module to list of (original_name, new_name) tuples.
+            - Mapping from old reference paths to new reference paths.
+        """
+        class_name_counts = Counter(model.class_name for model in all_scc_models)
+        class_name_seen: dict[str, int] = {}
+        internal_module_str = ".".join(internal_module)
+        module_class_mappings: defaultdict[tuple[str, ...], list[tuple[str, str]]] = defaultdict(list)
+        path_mapping: dict[str, str] = {}
+
+        for model in all_scc_models:
+            original_class_name = model.class_name
+            original_module = model_to_original_module[id(model)]
+            old_path = model.path  # Save the old path before updating
+
+            if class_name_counts[original_class_name] > 1:
+                seen_count = class_name_seen.get(original_class_name, 0)
+                new_class_name = f"{original_class_name}_{seen_count}" if seen_count > 0 else original_class_name
+                class_name_seen[original_class_name] = seen_count + 1
+            else:
+                new_class_name = original_class_name
+
+            model.reference.name = new_class_name
+            new_path = f"{internal_module_str}.{new_class_name}"
+            model.set_reference_path(new_path)
+            model.file_path = internal_path
+
+            module_class_mappings[original_module].append((original_class_name, new_class_name))
+            path_mapping[old_path] = new_path
+
+        return module_class_mappings, path_mapping
+
+    def __build_module_dependency_graph(  # noqa: PLR6301
+        self,
+        module_models_list: list[tuple[tuple[str, ...], list[DataModel]]],
+    ) -> dict[tuple[str, ...], set[tuple[str, ...]]]:
+        """Build a directed graph of module dependencies."""
+        path_to_module: dict[str, tuple[str, ...]] = {}
+        for module, models in module_models_list:
+            for model in models:
+                path_to_module[model.path] = module
+
+        graph: dict[tuple[str, ...], set[tuple[str, ...]]] = {}
+
+        def add_cross_module_edge(ref_path: str, source_module: tuple[str, ...]) -> None:
+            """Add an edge if ref_path points to a different module."""
+            if ref_path in path_to_module:
+                target_module = path_to_module[ref_path]
+                if target_module != source_module:
+                    graph[source_module].add(target_module)
+
+        for module, models in module_models_list:
+            graph[module] = set()
+
+            for model in models:
+                for data_type in model.all_data_types:
+                    if data_type.reference and data_type.reference.source:
+                        add_cross_module_edge(data_type.reference.path, module)
+
+                for base_class in model.base_classes:
+                    if base_class.reference and base_class.reference.source:
+                        add_cross_module_edge(base_class.reference.path, module)
+
+        return graph
+
+    def __resolve_circular_imports(  # noqa: PLR0914
+        self,
+        module_models_list: list[tuple[tuple[str, ...], list[DataModel]]],
+    ) -> tuple[
+        list[tuple[tuple[str, ...], list[DataModel]]],
+        set[tuple[str, ...]],
+        dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]],
+        dict[str, str],
+    ]:
+        """Resolve circular imports by merging all SCCs into _internal.py modules.
+
+        Uses Tarjan's algorithm to find strongly connected components (SCCs) in the
+        module dependency graph. All modules in each SCC are merged into a single
+        _internal.py module to break import cycles. Original modules become thin
+        forwarders that re-export their classes from _internal.
+
+        Returns:
+            - Updated module_models_list with models moved to _internal modules
+            - Set of _internal modules created
+            - Forwarder map: original_module -> (internal_module, [(original_name, new_name)])
+            - Path mapping: old_reference_path -> new_reference_path
+        """
+        graph = self.__build_module_dependency_graph(module_models_list)
+
+        circular_sccs = find_circular_sccs(graph)
+
+        forwarder_map: dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]] = {}
+        all_path_mappings: dict[str, str] = {}
+
+        if not circular_sccs:
+            return module_models_list, set(), forwarder_map, all_path_mappings
+
+        # All circular SCCs are problematic and should be merged into _internal.py
+        # to break the import cycles.
+        problematic_sccs = circular_sccs
+
+        existing_modules = {module for module, _ in module_models_list}
+        internal_modules_created: set[tuple[str, ...]] = set()
+
+        result_modules: dict[tuple[str, ...], list[DataModel]] = {
+            module: list(models) for module, models in module_models_list
+        }
+
+        for scc in problematic_sccs:
+            internal_module = self.__compute_internal_module_path(scc, existing_modules | internal_modules_created)
+            internal_modules_created.add(internal_module)
+            internal_path = Path("/".join(internal_module))
+
+            all_scc_models, model_to_original_module = self.__collect_scc_models(scc, result_modules)
+            module_class_mappings, path_mapping = self.__rename_and_relocate_scc_models(
+                all_scc_models, model_to_original_module, internal_module, internal_path
+            )
+            all_path_mappings.update(path_mapping)
+
+            for scc_module in scc:
+                if scc_module in result_modules:  # pragma: no branch
+                    result_modules[scc_module] = []
+                if scc_module in module_class_mappings:  # pragma: no branch
+                    sorted_mappings = sorted(module_class_mappings[scc_module], key=operator.itemgetter(0))
+                    forwarder_map[scc_module] = (internal_module, sorted_mappings)
+            result_modules[internal_module] = all_scc_models
+
+        new_module_models: list[tuple[tuple[str, ...], list[DataModel]]] = [
+            (internal_module, result_modules[internal_module])
+            for internal_module in sorted(internal_modules_created)
+            if internal_module in result_modules  # pragma: no branch
+        ]
+
+        for module, _ in module_models_list:
+            if module not in internal_modules_created:  # pragma: no branch
+                new_module_models.append((module, result_modules.get(module, [])))
+
+        return new_module_models, internal_modules_created, forwarder_map, all_path_mappings
2351
+
2352
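The docstring names Tarjan's algorithm; find_circular_sccs itself is defined elsewhere in the package. As a minimal sketch of the contract the code above relies on (a classic recursive Tarjan that keeps only cyclic components; illustrative, not the shipped helper):

from __future__ import annotations

def find_circular_sccs_sketch(
    graph: dict[tuple[str, ...], set[tuple[str, ...]]],
) -> list[set[tuple[str, ...]]]:
    index: dict[tuple[str, ...], int] = {}
    low: dict[tuple[str, ...], int] = {}
    on_stack: set[tuple[str, ...]] = set()
    stack: list[tuple[str, ...]] = []
    sccs: list[set[tuple[str, ...]]] = []
    counter = 0

    def strongconnect(v: tuple[str, ...]) -> None:
        nonlocal counter
        index[v] = low[v] = counter
        counter += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph.get(v, ()):
            if w not in index:
                strongconnect(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:
            scc: set[tuple[str, ...]] = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                scc.add(w)
                if w == v:
                    break
            # Only components with more than one module, or a self-loop, are cycles.
            if len(scc) > 1 or v in graph.get(v, ()):
                sccs.append(scc)

    for v in graph:
        if v not in index:
            strongconnect(v)
    return sccs

print(find_circular_sccs_sketch({("a",): {("b",)}, ("b",): {("a",)}, ("c",): set()}))
# -> [{('a',), ('b',)}]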
+    def __get_resolve_reference_action_parts(
+        self,
+        models: list[DataModel],
+        require_update_action_models: list[str],
+        *,
+        use_deferred_annotations: bool,
+    ) -> list[str]:
+        """Return the trailing rebuild/update calls for the given module's models."""
+        if self.dump_resolve_reference_action is None:
+            return []
+
+        require_update_action_model_paths = set(require_update_action_models)
+        required_paths_in_module = {m.path for m in models if m.path in require_update_action_model_paths}
+
+        if (
+            use_deferred_annotations
+            and required_paths_in_module
+            and self.dump_resolve_reference_action is pydantic_model_v2.dump_resolve_reference_action
+        ):
+            module_positions = {m.reference.short_name: i for i, m in enumerate(models) if m.reference}
+            module_model_names = set(module_positions)
+
+            forward_needed: set[str] = set()
+            for model in models:
+                if model.path not in required_paths_in_module or not model.reference:
+                    continue
+                name = model.reference.short_name
+                pos = module_positions[name]
+                refs = {
+                    t.reference.short_name
+                    for f in model.fields
+                    for t in f.data_type.all_data_types
+                    if t.reference and t.reference.short_name in module_model_names
+                }
+                if name in refs or any(module_positions.get(r, -1) > pos for r in refs):
+                    forward_needed.add(model.path)
+
+            # Propagate requirement through inheritance.
+            changed = True
+            required_filtered = set(forward_needed)
+            while changed:
+                changed = False
                 for model in models:
-                    for model_field in model.fields:
-                        if not model_field.default:
-                            continue
-                        for data_type in model_field.data_type.all_data_types:
-                            if data_type.reference and isinstance(
-                                data_type.reference.source, Enum
-                            ):  # pragma: no cover
-                                enum_member = data_type.reference.source.find_member(
-                                    model_field.default
-                                )
-                                if enum_member:
-                                    model_field.default = enum_member
-        if with_import:
-            result += [str(self.imports), str(imports), '\n']
-
-        code = dump_templates(models)
-        result += [code]
-
-        if self.dump_resolve_reference_action is not None:
-            result += [
-                '\n',
-                self.dump_resolve_reference_action(
+                    if not model.reference or model.path in required_filtered:
+                        continue
+                    base_paths = {b.reference.path for b in model.base_classes if b.reference}
+                    if base_paths & required_filtered:
+                        required_filtered.add(model.path)
+                        changed = True
+
+            required_paths_in_module = required_filtered
+
+        return [
+            "\n",
+            self.dump_resolve_reference_action(
+                m.reference.short_name for m in models if m.reference and m.path in required_paths_in_module
+            ),
+        ]
+
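When deferred annotations are active, only models that actually contain a self or forward reference (or inherit from one that does) keep their trailing rebuild call. A hedged example of the kind of tail this produces, assuming the pydantic v2 action emits model_rebuild() calls (model name invented):

# Appended after a module body where Owner's field refers to Pet,
# and Pet is defined later in the same module:
Owner.model_rebuild()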
+    def parse(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
+        self,
+        with_import: bool | None = True,  # noqa: FBT001, FBT002
+        format_: bool | None = True,  # noqa: FBT001, FBT002
+        settings_path: Path | None = None,
+        disable_future_imports: bool = False,  # noqa: FBT001, FBT002
+        all_exports_scope: AllExportsScope | None = None,
+        all_exports_collision_strategy: AllExportsCollisionStrategy | None = None,
+        module_split_mode: ModuleSplitMode | None = None,
+    ) -> str | dict[tuple[str, ...], Result]:
+        """Parse schema and generate code, returning single file or module dict."""
+        self.parse_raw()
+
+        use_deferred_annotations = bool(
+            self.target_python_version.has_native_deferred_annotations or (with_import and not disable_future_imports)
+        )
+
+        if (
+            with_import
+            and not disable_future_imports
+            and not self.target_python_version.has_native_deferred_annotations
+        ):
+            self.imports.append(IMPORT_ANNOTATIONS)
+
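The appended future import lets generated modules reference names before they are defined. A minimal sketch (toy classes, not generated output):

from __future__ import annotations

class Owner:
    pet: Pet  # forward reference; resolved lazily thanks to the future import

class Pet:
    name: str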
+        if format_:
+            code_formatter: CodeFormatter | None = CodeFormatter(
+                self.target_python_version,
+                settings_path,
+                self.wrap_string_literal,
+                skip_string_normalization=not self.use_double_quotes,
+                known_third_party=self.known_third_party,
+                custom_formatters=self.custom_formatter,
+                custom_formatters_kwargs=self.custom_formatters_kwargs,
+                encoding=self.encoding,
+                formatters=self.formatters,
+            )
+        else:
+            code_formatter = None
+
+        _, sorted_data_models, require_update_action_models = sort_data_models(self.results)
+
+        results: dict[tuple[str, ...], Result] = {}
+
+        def module_key(data_model: DataModel) -> tuple[str, ...]:
+            if module_split_mode == ModuleSplitMode.Single:
+                file_name = camel_to_snake(data_model.class_name)
+                return (*data_model.module_path, file_name)
+            return tuple(data_model.module_path)
+
+        def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
+            key = module_key(data_model)
+            return (len(key), key)
+
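Under ModuleSplitMode.Single each model is routed to its own file named after the snake_case form of its class name. camel_to_snake is the package's own helper; a rough stand-in with the same observable behaviour for simple names (an assumption, not the shipped implementation):

import re

def camel_to_snake_sketch(name: str) -> str:
    # Naive conversion; the real helper may treat acronyms differently.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

print(camel_to_snake_sketch("UserProfile"))  # user_profile
# So a UserProfile model in package "api" gets the module key ("api", "user_profile").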
+        # process in reverse order to correctly establish module levels
+        grouped_models = groupby(
+            sorted(sorted_data_models.values(), key=sort_key, reverse=True),
+            key=module_key,
+        )
+
+        module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
+        unused_models: list[DataModel] = []
+        model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {}
+        module_to_import: dict[tuple[str, ...], Imports] = {}
+        model_path_to_module_name: dict[str, str] = {}
+
+        previous_module: tuple[str, ...] = ()
+        for module, models in ((k, [*v]) for k, v in grouped_models):
+            for model in models:
+                model_to_module_models[model] = module, models
+                if module_split_mode == ModuleSplitMode.Single:
+                    model_path_to_module_name[model.path] = ".".join(module)
+            self.__delete_duplicate_models(models)
+            self.__replace_duplicate_name_in_module(models)
+            if len(previous_module) - len(module) > 1:
+                module_models.extend(
+                    (
+                        previous_module[:parts],
+                        [],
+                    )
+                    for parts in range(len(previous_module) - 1, len(module), -1)
+                )
+            module_models.append((
+                module,
+                models,
+            ))
+            previous_module = module
+
+        shared_module_entry = self.__reuse_model_tree_scope(module_models, require_update_action_models)
+        if shared_module_entry:
+            module_models.insert(0, shared_module_entry)
+
+        # Resolve circular imports by moving models to _internal.py modules
+        module_models, internal_modules, forwarder_map, path_mapping = self.__resolve_circular_imports(module_models)
+
+        # Update require_update_action_models with new paths for relocated models
+        if path_mapping:
+            require_update_action_models[:] = [path_mapping.get(path, path) for path in require_update_action_models]
+
+        class Processed(NamedTuple):
+            module: tuple[str, ...]
+            module_key: tuple[str, ...]  # Original module tuple (without file extension)
+            models: list[DataModel]
+            init: bool
+            imports: Imports
+            scoped_model_resolver: ModelResolver
+
+        processed_models: list[Processed] = []
+
+        for module_, models in module_models:
+            imports = module_to_import[module_] = Imports(self.use_exact_imports)
+            init = False
+            if module_:
+                if len(module_) == 1:
+                    parent = ("__init__.py",)
+                    if parent not in results:
+                        results[parent] = Result(body="")
+                else:
+                    for i in range(1, len(module_)):
+                        parent = (*module_[:i], "__init__.py")
+                        if parent not in results:
+                            results[parent] = Result(body="")
+                if (*module_, "__init__.py") in results:
+                    module = (*module_, "__init__.py")
+                    init = True
+                else:
+                    module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
+            else:
+                module = ("__init__.py",)
+
+            all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
+            scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
+
+            self.__alias_shadowed_imports(models, all_module_fields)
+            self.__override_required_field(models)
+            self.__replace_unique_list_to_set(models)
+            self.__change_from_import(
+                models,
+                imports,
+                scoped_model_resolver,
+                init=init,
+                internal_modules=internal_modules,
+                model_path_to_module_name=model_path_to_module_name,
+            )
+            self.__extract_inherited_enum(models)
+            self.__set_reference_default_value_to_field(models)
+            self.__reuse_model(models, require_update_action_models)
+            self.__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
+            self.__set_default_enum_member(models)
+            self.__wrap_root_model_default_values(models)
+            self.__sort_models(
+                models,
+                imports,
+                use_deferred_annotations=bool(
+                    self.target_python_version.has_native_deferred_annotations
+                    or (with_import and not disable_future_imports)
+                ),
+            )
+            self.__change_field_name(models)
+            self.__apply_discriminator_type(models, imports)
+            self.__set_one_literal_on_default(models)
+            self.__fix_dataclass_field_ordering(models)
+
+            processed_models.append(Processed(module, module_, models, init, imports, scoped_model_resolver))
+
+        for processed_model in processed_models:
+            for model in processed_model.models:
+                processed_model.imports.append(model.imports)
+
+        for unused_model in unused_models:
+            module, models = model_to_module_models[unused_model]
+            if unused_model in models:  # pragma: no cover
+                imports = module_to_import[module]
+                imports.remove(unused_model.imports)
+                models.remove(unused_model)
+
+        for processed_model in processed_models:
+            # postprocess imports to remove unused imports.
+            used_names = self._collect_used_names_from_models(processed_model.models)
+            unused_imports = [
+                (from_, import_)
+                for from_, imports_ in processed_model.imports.items()
+                for import_ in imports_
+                if not {processed_model.imports.alias.get(from_, {}).get(import_, import_), import_}.intersection(
+                    used_names
+                )
+            ]
+            for from_, import_ in unused_imports:
+                import_obj = Import(from_=from_, import_=import_)
+                while processed_model.imports.counter.get((from_, import_), 0) > 0:
+                    processed_model.imports.remove(import_obj)
+
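The pruning condition treats an aliased import as used when either the alias or the original name occurs in the rendered models. A toy re-statement of that check (values invented):

# An import kept under an alias counts as used if either name is referenced.
used_names = {"AwareDatetime"}
alias = {"pydantic": {"AwareDatetime": "AwareDatetime_1"}}
from_, import_ = "pydantic", "AwareDatetime"
is_unused = not {alias.get(from_, {}).get(import_, import_), import_} & used_names
print(is_unused)  # False: the original name is still referenced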
+        for module, mod_key, models, init, imports, scoped_model_resolver in processed_models:  # noqa: B007
+            # process after removing unused models
+            self.__change_imported_model_name(models, imports, scoped_model_resolver)
+
+        future_imports = self.imports.extract_future()
+        future_imports_str = str(future_imports)
+
+        for module, mod_key, models, init, imports, scoped_model_resolver in processed_models:  # noqa: B007
+            result: list[str] = []
+            export_imports: Imports | None = None
+
+            if all_exports_scope is not None and module[-1] == "__init__.py":
+                child_exports = self._collect_exports_for_init(module, processed_models, all_exports_scope)
+                if child_exports:
+                    local_model_names = {
                         m.reference.short_name
                         for m in models
-                        if m.path in require_update_action_models
-                    ),
-                ]
+                        if m.reference and not m.reference.short_name.startswith("_")
+                    }
+                    resolved_exports = self._resolve_export_collisions(
+                        child_exports, all_exports_collision_strategy, local_model_names
+                    )
+                    export_imports = self._build_all_exports_code(resolved_exports)
+
+            if models:
+                if with_import:
+                    import_parts = [s for s in [future_imports_str, str(self.imports), str(imports)] if s]
+                    result += [*import_parts, "\n"]

-        body = '\n'.join(result)
+                if export_imports:
+                    result += [str(export_imports), ""]
+                    for m in models:
+                        if m.reference and not m.reference.short_name.startswith("_"):  # pragma: no branch
+                            export_imports.add_export(m.reference.short_name)
+                    result += [export_imports.dump_all(multiline=True) + "\n"]
+
+                self.__update_type_aliases(models)
+                code = dump_templates(models)
+                result += [code]
+
+                result += self.__get_resolve_reference_action_parts(
+                    models,
+                    require_update_action_models,
+                    use_deferred_annotations=use_deferred_annotations,
+                )
+
+            # Generate forwarder content for modules that had models moved to _internal
+            if not result and mod_key in forwarder_map:
+                internal_module, class_mappings = forwarder_map[mod_key]
+                forwarder_content = self.__generate_forwarder_content(
+                    mod_key, internal_module, class_mappings, is_init=init
+                )
+                result = [forwarder_content]
+
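A module emptied by the SCC merge is replaced by a thin forwarder. The exact text comes from __generate_forwarder_content, defined elsewhere in this diff; one plausible shape, with invented names and a rename taken from class_mappings, is:

# models/user.py, after its classes were merged into models/_internal.py
from ._internal import User_1 as User

__all__ = ["User"]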
+            if not result and not init:
+                continue
+            body = "\n".join(result)
             if code_formatter:
                 body = code_formatter.format_code(body)

-        results[module] = Result(body=body, source=models[0].file_path)
+            results[module] = Result(
+                body=body,
+                future_imports=future_imports_str,
+                source=models[0].file_path if models else None,
+            )
+
+        if all_exports_scope is not None:
+            processed_init_modules = {m for m, _, _, _, _, _ in processed_models if m[-1] == "__init__.py"}
+            for init_module, init_result in list(results.items()):
+                if init_module[-1] != "__init__.py" or init_module in processed_init_modules or init_result.body:
+                    continue
+                if child_exports := self._collect_exports_for_init(
+                    init_module, processed_models, all_exports_scope
+                ):  # pragma: no branch
+                    resolved = self._resolve_export_collisions(child_exports, all_exports_collision_strategy, set())
+                    export_imports = self._build_all_exports_code(resolved)
+                    import_parts = [s for s in [future_imports_str, str(self.imports)] if s] if with_import else []
+                    parts = import_parts + (["\n"] if import_parts else [])
+                    parts += [str(export_imports), "", export_imports.dump_all(multiline=True)]
+                    body = "\n".join(parts)
+                    results[init_module] = Result(
+                        body=code_formatter.format_code(body) if code_formatter else body,
+                        future_imports=future_imports_str,
+                    )

         # retain existing behaviour
-        if [*results] == [('__init__.py',)]:
-            return results[('__init__.py',)].body
-
-        return results
+        if [*results] == [("__init__.py",)]:
+            single_result = results["__init__.py",]
+            return single_result.body
+
+        results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()}
+        return (
+            self.__postprocess_result_modules(results)
+            if self.treat_dot_as_module
+            else {
+                tuple((part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]) for part in k): v
+                for k, v in results.items()
+            }
+        )
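When treat_dot_as_module is disabled, dots inside a module name are folded into underscores while the file extension is preserved. The closing comprehension reduces to this per-part rewrite (illustrative value):

part = "v1.api.py"
rewritten = part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]
print(rewritten)  # v1_api.py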