datamodel-code-generator 0.11.12__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamodel_code_generator/__init__.py +654 -185
- datamodel_code_generator/__main__.py +872 -388
- datamodel_code_generator/arguments.py +798 -0
- datamodel_code_generator/cli_options.py +295 -0
- datamodel_code_generator/format.py +292 -54
- datamodel_code_generator/http.py +85 -10
- datamodel_code_generator/imports.py +152 -43
- datamodel_code_generator/model/__init__.py +138 -1
- datamodel_code_generator/model/base.py +531 -120
- datamodel_code_generator/model/dataclass.py +211 -0
- datamodel_code_generator/model/enum.py +133 -12
- datamodel_code_generator/model/imports.py +22 -0
- datamodel_code_generator/model/msgspec.py +462 -0
- datamodel_code_generator/model/pydantic/__init__.py +30 -25
- datamodel_code_generator/model/pydantic/base_model.py +304 -100
- datamodel_code_generator/model/pydantic/custom_root_type.py +11 -2
- datamodel_code_generator/model/pydantic/dataclass.py +15 -4
- datamodel_code_generator/model/pydantic/imports.py +40 -27
- datamodel_code_generator/model/pydantic/types.py +188 -96
- datamodel_code_generator/model/pydantic_v2/__init__.py +51 -0
- datamodel_code_generator/model/pydantic_v2/base_model.py +268 -0
- datamodel_code_generator/model/pydantic_v2/imports.py +15 -0
- datamodel_code_generator/model/pydantic_v2/root_model.py +35 -0
- datamodel_code_generator/model/pydantic_v2/types.py +143 -0
- datamodel_code_generator/model/scalar.py +124 -0
- datamodel_code_generator/model/template/Enum.jinja2 +15 -2
- datamodel_code_generator/model/template/ScalarTypeAliasAnnotation.jinja2 +6 -0
- datamodel_code_generator/model/template/ScalarTypeAliasType.jinja2 +6 -0
- datamodel_code_generator/model/template/ScalarTypeStatement.jinja2 +6 -0
- datamodel_code_generator/model/template/TypeAliasAnnotation.jinja2 +20 -0
- datamodel_code_generator/model/template/TypeAliasType.jinja2 +20 -0
- datamodel_code_generator/model/template/TypeStatement.jinja2 +20 -0
- datamodel_code_generator/model/template/TypedDict.jinja2 +5 -0
- datamodel_code_generator/model/template/TypedDictClass.jinja2 +25 -0
- datamodel_code_generator/model/template/TypedDictFunction.jinja2 +24 -0
- datamodel_code_generator/model/template/UnionTypeAliasAnnotation.jinja2 +10 -0
- datamodel_code_generator/model/template/UnionTypeAliasType.jinja2 +10 -0
- datamodel_code_generator/model/template/UnionTypeStatement.jinja2 +10 -0
- datamodel_code_generator/model/template/dataclass.jinja2 +50 -0
- datamodel_code_generator/model/template/msgspec.jinja2 +55 -0
- datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 +17 -4
- datamodel_code_generator/model/template/pydantic/BaseModel_root.jinja2 +12 -4
- datamodel_code_generator/model/template/pydantic/Config.jinja2 +1 -1
- datamodel_code_generator/model/template/pydantic/dataclass.jinja2 +15 -2
- datamodel_code_generator/model/template/pydantic_v2/BaseModel.jinja2 +57 -0
- datamodel_code_generator/model/template/pydantic_v2/ConfigDict.jinja2 +5 -0
- datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2 +48 -0
- datamodel_code_generator/model/type_alias.py +70 -0
- datamodel_code_generator/model/typed_dict.py +161 -0
- datamodel_code_generator/model/types.py +106 -0
- datamodel_code_generator/model/union.py +105 -0
- datamodel_code_generator/parser/__init__.py +30 -12
- datamodel_code_generator/parser/_graph.py +67 -0
- datamodel_code_generator/parser/_scc.py +171 -0
- datamodel_code_generator/parser/base.py +2426 -380
- datamodel_code_generator/parser/graphql.py +652 -0
- datamodel_code_generator/parser/jsonschema.py +2518 -647
- datamodel_code_generator/parser/openapi.py +631 -222
- datamodel_code_generator/py.typed +0 -0
- datamodel_code_generator/pydantic_patch.py +28 -0
- datamodel_code_generator/reference.py +672 -290
- datamodel_code_generator/types.py +521 -145
- datamodel_code_generator/util.py +155 -0
- datamodel_code_generator/watch.py +65 -0
- datamodel_code_generator-0.45.0.dist-info/METADATA +301 -0
- datamodel_code_generator-0.45.0.dist-info/RECORD +69 -0
- {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info}/WHEEL +1 -1
- datamodel_code_generator-0.45.0.dist-info/entry_points.txt +2 -0
- datamodel_code_generator/version.py +0 -1
- datamodel_code_generator-0.11.12.dist-info/METADATA +0 -440
- datamodel_code_generator-0.11.12.dist-info/RECORD +0 -31
- datamodel_code_generator-0.11.12.dist-info/entry_points.txt +0 -3
- {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,60 +1,289 @@
|
|
|
1
|
+
"""Abstract base parser and utilities for schema parsing.
|
|
2
|
+
|
|
3
|
+
Provides the Parser abstract base class that defines the parsing algorithm,
|
|
4
|
+
along with helper functions for model sorting, import resolution, and
|
|
5
|
+
code generation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import operator
|
|
11
|
+
import os.path
|
|
1
12
|
import re
|
|
13
|
+
import sys
|
|
2
14
|
from abc import ABC, abstractmethod
|
|
3
|
-
from collections import OrderedDict, defaultdict
|
|
15
|
+
from collections import Counter, OrderedDict, defaultdict
|
|
16
|
+
from collections.abc import Hashable, Sequence
|
|
4
17
|
from itertools import groupby
|
|
5
18
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
7
|
-
Any,
|
|
8
|
-
Callable,
|
|
9
|
-
DefaultDict,
|
|
10
|
-
Dict,
|
|
11
|
-
Iterable,
|
|
12
|
-
Iterator,
|
|
13
|
-
List,
|
|
14
|
-
Mapping,
|
|
15
|
-
Optional,
|
|
16
|
-
Sequence,
|
|
17
|
-
Set,
|
|
18
|
-
Tuple,
|
|
19
|
-
Type,
|
|
20
|
-
Union,
|
|
21
|
-
)
|
|
19
|
+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, TypeVar, cast, runtime_checkable
|
|
22
20
|
from urllib.parse import ParseResult
|
|
21
|
+
from warnings import warn
|
|
23
22
|
|
|
24
23
|
from pydantic import BaseModel
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
from datamodel_code_generator
|
|
28
|
-
|
|
24
|
+
from typing_extensions import TypeAlias
|
|
25
|
+
|
|
26
|
+
from datamodel_code_generator import (
|
|
27
|
+
DEFAULT_SHARED_MODULE_NAME,
|
|
28
|
+
AllExportsCollisionStrategy,
|
|
29
|
+
AllExportsScope,
|
|
30
|
+
AllOfMergeMode,
|
|
31
|
+
Error,
|
|
32
|
+
ModuleSplitMode,
|
|
33
|
+
ReadOnlyWriteOnlyModelType,
|
|
34
|
+
ReuseScope,
|
|
35
|
+
)
|
|
36
|
+
from datamodel_code_generator.format import (
|
|
37
|
+
DEFAULT_FORMATTERS,
|
|
38
|
+
CodeFormatter,
|
|
39
|
+
DatetimeClassType,
|
|
40
|
+
Formatter,
|
|
41
|
+
PythonVersion,
|
|
42
|
+
PythonVersionMin,
|
|
43
|
+
)
|
|
44
|
+
from datamodel_code_generator.imports import (
|
|
45
|
+
IMPORT_ANNOTATIONS,
|
|
46
|
+
IMPORT_LITERAL,
|
|
47
|
+
IMPORT_OPTIONAL,
|
|
48
|
+
IMPORT_UNION,
|
|
49
|
+
Import,
|
|
50
|
+
Imports,
|
|
51
|
+
)
|
|
52
|
+
from datamodel_code_generator.model import dataclass as dataclass_model
|
|
53
|
+
from datamodel_code_generator.model import msgspec as msgspec_model
|
|
29
54
|
from datamodel_code_generator.model import pydantic as pydantic_model
|
|
55
|
+
from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2
|
|
30
56
|
from datamodel_code_generator.model.base import (
|
|
31
57
|
ALL_MODEL,
|
|
58
|
+
UNDEFINED,
|
|
32
59
|
BaseClassDataType,
|
|
60
|
+
ConstraintsBase,
|
|
33
61
|
DataModel,
|
|
34
62
|
DataModelFieldBase,
|
|
63
|
+
WrappedDefault,
|
|
35
64
|
)
|
|
36
|
-
from datamodel_code_generator.model.enum import Enum
|
|
65
|
+
from datamodel_code_generator.model.enum import Enum, Member
|
|
66
|
+
from datamodel_code_generator.model.type_alias import TypeAliasBase, TypeStatement
|
|
37
67
|
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
|
|
38
|
-
from datamodel_code_generator.
|
|
68
|
+
from datamodel_code_generator.parser._graph import stable_toposort
|
|
69
|
+
from datamodel_code_generator.parser._scc import find_circular_sccs, strongly_connected_components
|
|
70
|
+
from datamodel_code_generator.reference import ModelResolver, ModelType, Reference
|
|
39
71
|
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes
|
|
72
|
+
from datamodel_code_generator.util import camel_to_snake
|
|
40
73
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"\\": r"\\",
|
|
44
|
-
"'": r"\'",
|
|
45
|
-
'\b': r'\b',
|
|
46
|
-
'\f': r'\f',
|
|
47
|
-
'\n': r'\n',
|
|
48
|
-
'\r': r'\r',
|
|
49
|
-
'\t': r'\t',
|
|
50
|
-
}
|
|
51
|
-
)
|
|
74
|
+
if TYPE_CHECKING:
|
|
75
|
+
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
|
52
76
|
|
|
77
|
+
from datamodel_code_generator import DataclassArguments
|
|
53
78
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
79
|
+
|
|
80
|
+
@runtime_checkable
class HashableComparable(Hashable, Protocol):
    """Structural type for values that can be both hashed and ordered.

    A value matching this protocol can serve as a dict/set key and can be passed to
    ``sorted``; ``to_hashable`` is annotated as returning such values.
    """

    def __lt__(self, other: Any, /) -> bool: ...  # noqa: D105
    def __le__(self, other: Any, /) -> bool: ...  # noqa: D105
    def __gt__(self, other: Any, /) -> bool: ...  # noqa: D105
    def __ge__(self, other: Any, /) -> bool: ...  # noqa: D105
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Type aliases for the keep_model_order helpers; dependency graphs are keyed by a
# model's class name.
ModelName: TypeAlias = str
ModelNames: TypeAlias = set[ModelName]
# class name -> names of in-module models it depends on
ModelDeps: TypeAlias = dict[ModelName, set[ModelName]]
# class name -> position in the class-name-sorted model list
OrderIndex: TypeAlias = dict[ModelName, int]

# Index into a ``Components`` list.
ComponentId: TypeAlias = int
# Strongly connected components, each a list of member class names.
Components: TypeAlias = list[list[ModelName]]
# class name -> id of the component that contains it
ComponentOf: TypeAlias = dict[ModelName, ComponentId]
# component id -> ids of the components that depend on it
ComponentEdges: TypeAlias = dict[ComponentId, set[ComponentId]]

# Graph node handed to strongly_connected_components (built here as a 1-tuple of a class name).
ClassNode: TypeAlias = tuple[ModelName, ...]
ClassGraph: TypeAlias = dict[ClassNode, set[ClassNode]]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class _KeepModelOrderDeps(NamedTuple):
    """Per-model dependency maps used by keep_model_order sorting, keyed by class name."""

    strong: ModelDeps  # base-class dependencies only (in-module, non-imported)
    all: ModelDeps  # base-class dependencies plus field-reference dependencies when those are collected
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class _KeepModelOrderComponents(NamedTuple):
    """Strongly connected components of the model dependency graph."""

    components: Components  # member lists, each sorted by order index; lists sorted by their first member
    comp_of: ComponentOf  # class name -> index into ``components``
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _collect_keep_model_order_deps(
|
|
115
|
+
model: DataModel,
|
|
116
|
+
*,
|
|
117
|
+
model_names: ModelNames,
|
|
118
|
+
imported: ModelNames,
|
|
119
|
+
use_deferred_annotations: bool,
|
|
120
|
+
) -> tuple[set[ModelName], set[ModelName]]:
|
|
121
|
+
"""Collect (strong_deps, all_deps) used by keep_model_order sorting.
|
|
122
|
+
|
|
123
|
+
- strong_deps: base class references (within-module, non-imported)
|
|
124
|
+
- all_deps: base class refs + (optionally) field refs (within-module, non-imported)
|
|
125
|
+
"""
|
|
126
|
+
class_name = model.class_name
|
|
127
|
+
base_class_refs = {b.reference.short_name for b in model.base_classes if b.reference}
|
|
128
|
+
field_refs = {t.reference.short_name for f in model.fields for t in f.data_type.all_data_types if t.reference}
|
|
129
|
+
|
|
130
|
+
if use_deferred_annotations and not isinstance(model, (TypeAliasBase, pydantic_model_v2.RootModel)):
|
|
131
|
+
field_refs = set()
|
|
132
|
+
|
|
133
|
+
strong = {r for r in base_class_refs if r in model_names and r not in imported and r != class_name}
|
|
134
|
+
deps = {r for r in (base_class_refs | field_refs) if r in model_names and r not in imported and r != class_name}
|
|
135
|
+
return strong, deps
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _build_keep_model_order_dependency_maps(
    models: list[DataModel],
    *,
    model_names: ModelNames,
    imported: ModelNames,
    use_deferred_annotations: bool,
) -> _KeepModelOrderDeps:
    """Collect strong and full dependency sets for every model, keyed by class name."""
    per_model = {
        model.class_name: _collect_keep_model_order_deps(
            model,
            model_names=model_names,
            imported=imported,
            use_deferred_annotations=use_deferred_annotations,
        )
        for model in models
    }
    return _KeepModelOrderDeps(
        strong={name: strong for name, (strong, _) in per_model.items()},
        all={name: deps for name, (_, deps) in per_model.items()},
    )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _build_keep_model_order_components(
    all_deps: ModelDeps,
    order_index: OrderIndex,
) -> _KeepModelOrderComponents:
    """Group models into strongly connected components and order them deterministically.

    Members inside a component are sorted by ``order_index``. Components are sorted by
    the smallest ``order_index`` among their members.
    """
    graph: ClassGraph = {}
    for name, deps in all_deps.items():
        graph[(name,)] = {(dep,) for dep in deps}

    def first_position(members: list[ModelName]) -> int:
        return min(order_index[member] for member in members)

    components: Components = sorted(
        (
            sorted((node[0] for node in scc), key=order_index.__getitem__)
            for scc in strongly_connected_components(graph)
        ),
        key=first_position,
    )

    comp_of: ComponentOf = {}
    for component_id, members in enumerate(components):
        for name in members:
            comp_of[name] = component_id
    return _KeepModelOrderComponents(components=components, comp_of=comp_of)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _build_keep_model_order_component_edges(
|
|
172
|
+
all_deps: ModelDeps,
|
|
173
|
+
comp_of: ComponentOf,
|
|
174
|
+
num_components: int,
|
|
175
|
+
) -> ComponentEdges:
|
|
176
|
+
comp_edges: ComponentEdges = {i: set() for i in range(num_components)}
|
|
177
|
+
for name, deps in all_deps.items():
|
|
178
|
+
name_comp = comp_of[name]
|
|
179
|
+
for dep in deps:
|
|
180
|
+
if (dep_comp := comp_of[dep]) != name_comp:
|
|
181
|
+
comp_edges[dep_comp].add(name_comp)
|
|
182
|
+
return comp_edges
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _build_keep_model_order_component_order(
    components: Components,
    comp_edges: ComponentEdges,
    order_index: OrderIndex,
) -> list[ComponentId]:
    """Topologically order component ids, breaking ties by each component's earliest member."""
    first_position = {
        component_id: min(order_index[name] for name in members)
        for component_id, members in enumerate(components)
    }
    return stable_toposort(list(first_position), comp_edges, key=first_position.__getitem__)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _build_keep_model_ordered_names(
|
|
199
|
+
ordered_comp_ids: list[ComponentId],
|
|
200
|
+
components: Components,
|
|
201
|
+
strong_deps: ModelDeps,
|
|
202
|
+
order_index: OrderIndex,
|
|
203
|
+
) -> list[ModelName]:
|
|
204
|
+
ordered_names: list[ModelName] = []
|
|
205
|
+
for component_id in ordered_comp_ids:
|
|
206
|
+
members = components[component_id]
|
|
207
|
+
if len(members) > 1:
|
|
208
|
+
strong_edges: dict[ModelName, set[ModelName]] = {n: set() for n in members}
|
|
209
|
+
member_set = set(members)
|
|
210
|
+
for base in members:
|
|
211
|
+
derived_members = {member for member in members if base in strong_deps.get(member, set()) & member_set}
|
|
212
|
+
strong_edges[base].update(derived_members)
|
|
213
|
+
members = stable_toposort(members, strong_edges, key=order_index.__getitem__)
|
|
214
|
+
ordered_names.extend(members)
|
|
215
|
+
return ordered_names
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _reorder_models_keep_model_order(
    models: list[DataModel],
    imports: Imports,
    *,
    use_deferred_annotations: bool,
) -> None:
    """Reorder *models* in place, moving as little as possible from class-name order.

    Models are first sorted by ``class_name``. A model is moved only when a dependency
    must come before it. Dependency cycles stay together as strongly connected
    components, and inside each cycle base-class dependencies are prioritized. Names
    found in *imports* are not treated as in-module dependencies.
    """
    models.sort(key=operator.attrgetter("class_name"))
    imported: ModelNames = set().union(*imports.values())
    model_by_name = {model.class_name: model for model in models}
    position: OrderIndex = {model.class_name: idx for idx, model in enumerate(models)}

    strong_deps, all_deps = _build_keep_model_order_dependency_maps(
        models,
        model_names=set(model_by_name),
        imported=imported,
        use_deferred_annotations=use_deferred_annotations,
    )
    components, comp_of = _build_keep_model_order_components(all_deps, position)
    comp_edges = _build_keep_model_order_component_edges(all_deps, comp_of, len(components))
    component_order = _build_keep_model_order_component_order(components, comp_edges, position)
    ordered_names = _build_keep_model_ordered_names(component_order, components, strong_deps, position)
    models[:] = [model_by_name[name] for name in ordered_names]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
SPECIAL_PATH_FORMAT: str = "#-datamodel-code-generator-#-{}-#-special-#"


def get_special_path(keyword: str, path: list[str]) -> list[str]:
    """Return a new list: *path* followed by the special marker segment for *keyword*.

    The input list is left untouched.
    """
    marker = SPECIAL_PATH_FORMAT.format(keyword)
    extended = list(path)
    extended.append(marker)
    return extended
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# ``str.translate`` table that backslash-escapes characters; only ``'`` (not ``"``) is
# escaped, so presumably the output is placed inside single-quoted literals.
escape_characters = str.maketrans(
    dict([
        ("\u0000", r"\x00"),  # Null byte
        ("\\", r"\\"),
        ("'", r"\'"),
        ("\b", r"\b"),
        ("\f", r"\f"),
        ("\n", r"\n"),
        ("\r", r"\r"),
        ("\t", r"\t"),
    ])
)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def to_hashable(item: Any) -> HashableComparable:  # noqa: PLR0911
    """Convert an item to a hashable and comparable representation.

    Returns a value that is both hashable and supports comparison operators.
    Used for caching and deduplication of models.

    - lists/tuples -> tuple of converted elements, sorted by ``(type name, value)``
    - dicts -> tuple of ``(key, converted value)`` pairs sorted by key
    - sets -> frozenset of converted elements
    - pydantic models -> converted ``.dict()``
    - ``None`` -> ``""``; anything else is returned unchanged

    When elements or keys cannot be ordered (mixed, non-comparable types), a
    deterministic fallback order is used instead of raising ``TypeError``.
    """
    if isinstance(item, (list, tuple)):
        try:
            return tuple(sorted((to_hashable(i) for i in item), key=lambda v: (str(type(v)), v)))
        except TypeError:
            # Fallback when mixed, non-comparable types are present; preserve original order
            return tuple(to_hashable(i) for i in item)
    if isinstance(item, dict):
        pairs = [(k, to_hashable(v)) for k, v in item.items()]
        try:
            return tuple(sorted(pairs))
        except TypeError:
            # Mixed key types (e.g. YAML ``{200: ..., "default": ...}``) cannot be compared
            # directly. Order by type name first, as the list branch does, and fall back to
            # insertion order if even that fails.
            try:
                return tuple(sorted(pairs, key=lambda kv: (str(type(kv[0])), kv[0])))
            except TypeError:
                return tuple(pairs)
    if isinstance(item, set):  # pragma: no cover
        return frozenset(to_hashable(i) for i in item)  # type: ignore[return-value]
    if isinstance(item, BaseModel):
        return to_hashable(item.dict())
    if item is None:
        return ""
    return item  # type: ignore[return-value]
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def dump_templates(templates: list[DataModel]) -> str:
    """Render each model with ``str`` and separate consecutive models by two blank lines."""
    rendered = [str(template) for template in templates]
    return "\n\n\n".join(rendered)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def iter_models_field_data_types(
    models: Iterable[DataModel],
) -> Iterator[tuple[DataModel, DataModelFieldBase, DataType]]:
    """Lazily yield ``(model, field, data_type)`` for every data type reachable from each field.

    Data types come from ``field.data_type.all_data_types``, in model, field, then data-type order.
    """
    yield from (
        (model, field, data_type)
        for model in models
        for field in model.fields
        for data_type in field.data_type.all_data_types
    )
|
|
72
318
|
|
|
73
319
|
|
|
74
|
-
|
|
75
|
-
|
|
320
|
+
# str -> set[str] mapping alias (key/value meaning is defined by its users, not visible here).
ReferenceMapSet = dict[str, set[str]]
# model path -> DataModel, as accumulated by sort_data_models.
SortedDataModels = dict[str, DataModel]

# Default recursion budget for sort_data_models: the interpreter's recursion limit at import time.
MAX_RECURSION_COUNT: int = sys.getrecursionlimit()
|
|
77
324
|
|
|
78
|
-
ReferenceMapSet = Dict[str, Set[str]]
|
|
79
|
-
SortedDataModels = Dict[str, DataModel]
|
|
80
325
|
|
|
81
|
-
|
|
326
|
+
def add_model_path_to_list(
    paths: list[str] | None,
    model: DataModel,
    /,
) -> list[str]:
    """Append ``model.path`` to *paths* unless the model is an alias or the path is already listed.

    When *paths* is ``None`` a new list is created; otherwise *paths* is mutated in
    place. The resulting list is returned in both cases.
    """
    result = [] if paths is None else paths
    if not model.is_alias and model.path not in result:
        result.append(model.path)
    return result
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def sort_data_models(  # noqa: PLR0912, PLR0915
    unsorted_data_models: list[DataModel],
    sorted_data_models: SortedDataModels | None = None,
    require_update_action_models: list[str] | None = None,
    recursion_count: int = MAX_RECURSION_COUNT,
) -> tuple[list[DataModel], SortedDataModels, list[str]]:
    """Sort data models by dependency order for correct forward references.

    Args:
        unsorted_data_models: Models still to be placed.
        sorted_data_models: Already-placed models keyed by ``model.path``; mutated in
            place. A new ``OrderedDict`` is created when omitted.
        require_update_action_models: Paths of models flagged as needing an update
            action (self-references, models placed during the circular-reference pass,
            and models whose base is already flagged); mutated in place.
        recursion_count: Remaining budget for recursive resolution passes.

    Returns:
        ``(unresolved_references, sorted_data_models, require_update_action_models)``.
        ``unresolved_references`` is empty when every model resolved directly, and
        otherwise lists the models handled by the circular-reference pass.

    Raises:
        Exception: If a model references classes that are neither placed nor part of
            the remaining unresolved group.
    """
    if sorted_data_models is None:
        sorted_data_models = OrderedDict()
    if require_update_action_models is None:
        require_update_action_models = []
    # Snapshot the size so we can tell below whether this pass placed any model.
    sorted_model_count: int = len(sorted_data_models)

    unresolved_references: list[DataModel] = []
    for model in unsorted_data_models:
        if not model.reference_classes:
            # No references at all: safe to place immediately.
            sorted_data_models[model.path] = model
        elif model.path in model.reference_classes and len(model.reference_classes) == 1:  # only self-referencing
            sorted_data_models[model.path] = model
            add_model_path_to_list(require_update_action_models, model)
        elif (
            not model.reference_classes - {model.path} - set(sorted_data_models)
        ):  # reference classes have been resolved
            sorted_data_models[model.path] = model
            if model.path in model.reference_classes:
                add_model_path_to_list(require_update_action_models, model)
        else:
            unresolved_references.append(model)
    if unresolved_references:
        # Retry the leftovers only if this pass placed something; an unchanged state
        # would just repeat the same pass.
        if sorted_model_count != len(sorted_data_models) and recursion_count:
            try:
                return sort_data_models(
                    unresolved_references,
                    sorted_data_models,
                    require_update_action_models,
                    recursion_count - 1,
                )
            except RecursionError:  # pragma: no cover
                pass

        # sort on base_class dependency
        while True:
            ordered_models: list[tuple[int, DataModel]] = []
            unresolved_reference_model_names = [m.path for m in unresolved_references]
            for model in unresolved_references:
                # RootModels are ordered by the unresolved models their field types
                # reference; other models by their unresolved base classes.
                if isinstance(model, pydantic_model_v2.RootModel):
                    indexes = [
                        unresolved_reference_model_names.index(ref_path)
                        for f in model.fields
                        for t in f.data_type.all_data_types
                        if t.reference and (ref_path := t.reference.path) in unresolved_reference_model_names
                    ]
                else:
                    indexes = [
                        unresolved_reference_model_names.index(b.reference.path)
                        for b in model.base_classes
                        if b.reference and b.reference.path in unresolved_reference_model_names
                    ]
                if indexes:
                    # Sort key: current position of the furthest such dependency.
                    ordered_models.append((
                        max(indexes),
                        model,
                    ))
                else:
                    # No unresolved dependency of that kind: goes to the front.
                    ordered_models.append((
                        -1,
                        model,
                    ))
            sorted_unresolved_models = [m[1] for m in sorted(ordered_models, key=operator.itemgetter(0))]
            # Stop once a pass leaves the order unchanged.
            if sorted_unresolved_models == unresolved_references:
                break
            unresolved_references = sorted_unresolved_models

        # circular reference
        unsorted_data_model_names = set(unresolved_reference_model_names)
        for model in unresolved_references:
            unresolved_model = model.reference_classes - {model.path} - set(sorted_data_models)
            base_models = [getattr(s.reference, "path", None) for s in model.base_classes]
            update_action_parent = set(require_update_action_models).intersection(base_models)
            if not unresolved_model:
                # All references are placed by now; inherit the flag from a flagged base.
                sorted_data_models[model.path] = model
                if update_action_parent:
                    add_model_path_to_list(require_update_action_models, model)
                continue
            if not unresolved_model - unsorted_data_model_names:
                # Only references other members of the unresolved group.
                sorted_data_models[model.path] = model
                add_model_path_to_list(require_update_action_models, model)
                continue
            # unresolved
            unresolved_classes = ", ".join(
                f"[class: {item.path} references: {item.reference_classes}]" for item in unresolved_references
            )
            msg = f"A Parser can not resolve classes: {unresolved_classes}."
            raise Exception(msg)  # noqa: TRY002
    return unresolved_references, sorted_data_models, require_update_action_models
|
|
176
443
|
|
|
177
444
|
|
|
178
|
-
def relative(
    current_module: str,
    reference: str,
    *,
    reference_is_module: bool = False,
    current_is_init: bool = False,
) -> tuple[str, str]:
    """Compute the relative ``from`` path and the imported name for *reference*.

    Args:
        current_module: Dotted path of the importing module (e.g. "foo.bar").
        reference: Dotted path of the target, "module.ClassName" by default, or a
            bare module path when ``reference_is_module`` is True.
        reference_is_module: Treat *reference* as a module path whose last segment is
            the name to import.
        current_is_init: Treat *current_module* as a package ``__init__`` module, i.e.
            one level below the package path.

    Returns:
        ``(from_path, import_name)`` such as ``(".", "c")`` or ``("..pkg", "mod")``.
        ``("", "")`` means the reference lives in the current module.
    """
    here = current_module.split(".") if current_module else []
    if current_is_init:
        here.append("__init__")

    if reference_is_module:
        target = reference.split(".") if reference else []
        name = target[-1] if target else ""
    else:
        *target, name = reference.split(".")

    if here == target:
        return "", ""

    # Length of the leading package path shared by both modules.
    shared = 0
    limit = min(len(here), len(target))
    while shared < limit and here[shared] == target[shared]:
        shared += 1

    # One dot per level to climb; a sibling import still needs a single ".".
    dots = "." * (len(here) - shared) or "."
    remainder = ".".join(target[shared:])

    if not remainder:
        return dots, name
    if "." in remainder:
        package, module = remainder.rsplit(".", 1)
        return dots + package, module
    return dots, remainder
|
|
205
495
|
|
|
206
496
|
|
|
497
|
+
def is_ancestor_package_reference(current_module: str, reference: str) -> bool:
    """Report whether *reference* is defined in an ancestor package's ``__init__.py``.

    The reference's module path (everything before its last segment) counts as an
    ancestor when it equals the current module's parent path, or when it is a non-empty
    proper prefix of the current module path.

    Args:
        current_module: Dotted path of the current module (e.g. "v0.mammal.canine").
        reference: Full dotted reference (e.g. "v0.Animal").

    Returns:
        True for an ancestor-package reference, otherwise False.

    Examples:
        - current="v0.animal", ref="v0.Animal" -> True (immediate parent)
        - current="v0.mammal.canine", ref="v0.Animal" -> True (grandparent)
        - current="v0.animal", ref="v0.animal.Dog" -> False (same or child)
        - current="pets", ref="Animal" -> True (root package is immediate parent)
    """
    if not current_module:
        return False

    module_parts = current_module.split(".")
    ref_module_parts = reference.split(".")[:-1]

    # Immediate parent (covers the root package when ref_module_parts is empty).
    if module_parts[:-1] == ref_module_parts:
        return True

    # Deeper ancestor: a non-empty, strictly shorter prefix of the current path.
    depth = len(ref_module_parts)
    return 0 < depth < len(module_parts) and module_parts[:depth] == ref_module_parts
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def exact_import(from_: str, import_: str, short_name: str) -> tuple[str, str]:
    """Fold *import_* into the ``from`` path and import *short_name* from it.

    A purely relative prefix ("", ".", "..", ...) is joined directly (". + foo" ->
    ".foo"); any other prefix gets a separating dot (".pkg + foo" -> ".pkg.foo").
    """
    # Prevents "from . import foo" becoming "from ..foo import Foo" (and the ".."
    # analogue) when the imported module shares our parent.
    separator = "." if from_.strip(".") else ""
    return f"{from_}{separator}{import_}", short_name
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def get_module_directory(module: tuple[str, ...]) -> tuple[str, ...]:
    """Return the package-directory part of a module tuple.

    Module tuples carry no ``.py`` suffix. A single-element tuple is a root module and
    is returned unchanged; longer tuples lose their last (module-name) element.

    Examples:
        ("pkg",) -> ("pkg",)
        ("pkg", "issuing") -> ("pkg",)
        ("foo", "bar", "baz") -> ("foo", "bar")
    """
    if len(module) > 1:
        return module[:-1]
    return module if module else ()
|
|
562
|
+
|
|
563
|
+
|
|
207
564
|
@runtime_checkable
class Child(Protocol):
    """Protocol for objects with a parent reference."""

    # runtime_checkable lets get_most_of_parent test ``isinstance(value, Child)``
    # structurally, i.e. by the presence of a ``parent`` attribute.
    @property
    def parent(self) -> Any | None:
        """Get the parent object reference."""
        raise NotImplementedError
|
|
212
572
|
|
|
213
573
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
574
|
+
T = TypeVar("T")


def get_most_of_parent(value: Any, type_: type[T] | None = None) -> T | None:
    """Climb the ``parent`` chain starting at ``value``.

    The climb continues while the current object satisfies the ``Child``
    protocol and, when ``type_`` is given, is not yet an instance of ``type_``.
    The object that stops the climb is returned:

    * with ``type_``: the nearest object in the chain (``value`` included) that
      is an instance of ``type_``, if such an object is reached first;
    * otherwise: the first object that is not a ``Child``. This is ``None`` when
      the topmost ``Child`` has a ``None`` parent.
    """
    if not isinstance(value, Child):
        return value
    if type_ is not None and isinstance(value, type_):
        return value
    return get_most_of_parent(value.parent, type_)
|
|
218
582
|
|
|
219
583
|
|
|
220
584
|
def title_to_class_name(title: str) -> str:
    """Turn a schema ``title`` into a CamelCase class name.

    Runs of characters outside ``[A-Za-z0-9]`` act as word separators. Each word
    is title-cased with ``str.title`` (so ``"HTTPResponse"`` becomes
    ``"Httpresponse"``), and the words are concatenated.
    """
    words = re.sub(r"[^A-Za-z0-9]+", " ", title).title().split()
    return "".join(words)
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def _find_base_classes(model: DataModel) -> list[DataModel]:
    """Return the ``DataModel`` objects that ``model``'s base classes refer to.

    Base classes without a reference, or whose reference source is not a
    ``DataModel``, are skipped. Order follows ``model.base_classes``.
    """
    found: list[DataModel] = []
    for base in model.base_classes:
        if not base.reference:
            continue
        candidate = base.reference.source
        if isinstance(candidate, DataModel):
            found.append(candidate)
    return found
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _find_field(original_name: str, models: list[DataModel]) -> DataModelFieldBase | None:
    """Return the first field whose ``original_name`` matches, or ``None``.

    Models are searched in order, each through ``iter_all_fields()``. Whether
    inherited fields are included depends on that method (presumably yes).
    """
    matches = (
        field
        for model in models
        for field in model.iter_all_fields()
        if field.original_name == original_name
    )
    return next(matches, None)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def _copy_data_types(data_types: list[DataType]) -> list[DataType]:
    """Return copies of ``data_types``, sharing referenced models.

    * An entry with a ``reference`` becomes a new instance of the same class that
      carries only that reference.
    * An entry with nested ``data_types`` is copied with ``.copy()``, and its
      nested list is copied recursively.
    * Any other entry is copied with ``.copy()``.
    """

    def _copy_one(data_type: DataType) -> DataType:
        if data_type.reference:
            return data_type.__class__(reference=data_type.reference)
        clone = data_type.copy()
        if data_type.data_types:
            clone.data_types = _copy_data_types(data_type.data_types)
        return clone

    return [_copy_one(item) for item in data_types]
|
|
224
617
|
|
|
225
618
|
|
|
226
619
|
class Result(BaseModel):
    """One rendered output module.

    Attributes:
        body: the generated source text.
        future_imports: future-import text kept apart from ``body`` (presumably
            ``from __future__`` lines); empty string by default.
        source: path of the input this module came from, when known.
    """

    body: str
    future_imports: str = ""
    # ``Optional`` is kept instead of ``Path | None`` (hence the UP045 waiver),
    # presumably because pydantic evaluates this annotation at runtime.
    source: Optional[Path] = None  # noqa: UP045
|
|
229
625
|
|
|
230
626
|
|
|
231
627
|
class Source(BaseModel):
|
|
628
|
+
"""Schema source file with path and content."""
|
|
629
|
+
|
|
232
630
|
path: Path
|
|
233
631
|
text: str
|
|
234
632
|
|
|
235
633
|
@classmethod
|
|
236
|
-
def from_path(cls, path: Path, base_path: Path, encoding: str) ->
|
|
634
|
+
def from_path(cls, path: Path, base_path: Path, encoding: str) -> Source:
|
|
635
|
+
"""Create a Source from a file path relative to base_path."""
|
|
237
636
|
return cls(
|
|
238
637
|
path=path.relative_to(base_path),
|
|
239
638
|
text=path.read_text(encoding=encoding),
|
|
@@ -241,144 +640,295 @@ class Source(BaseModel):
|
|
|
241
640
|
|
|
242
641
|
|
|
243
642
|
class Parser(ABC):
|
|
244
|
-
|
|
643
|
+
"""Abstract base class for schema parsers.
|
|
644
|
+
|
|
645
|
+
Provides the parsing algorithm and code generation. Subclasses implement
|
|
646
|
+
parse_raw() to handle specific schema formats.
|
|
647
|
+
"""
|
|
648
|
+
|
|
649
|
+
def __init__(  # noqa: PLR0913, PLR0915
    self,
    source: str | Path | list[Path] | ParseResult,
    *,
    data_model_type: type[DataModel] = pydantic_model.BaseModel,
    data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
    data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
    data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
    base_class: str | None = None,
    additional_imports: list[str] | None = None,
    custom_template_dir: Path | None = None,
    extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
    target_python_version: PythonVersion = PythonVersionMin,
    dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Mapping[str, str] | None = None,
    allow_population_by_field_name: bool = False,
    apply_default_values_for_required_fields: bool = False,
    allow_extra_fields: bool = False,
    extra_fields: str | None = None,
    force_optional_for_required_fields: bool = False,
    class_name: str | None = None,
    use_standard_collections: bool = False,
    base_path: Path | None = None,
    use_schema_description: bool = False,
    use_field_description: bool = False,
    use_attribute_docstrings: bool = False,
    use_inline_field_description: bool = False,
    use_default_kwarg: bool = False,
    reuse_model: bool = False,
    reuse_scope: ReuseScope | None = None,
    shared_module_name: str = DEFAULT_SHARED_MODULE_NAME,
    encoding: str = "utf-8",
    enum_field_as_literal: LiteralType | None = None,
    set_default_enum_member: bool = False,
    use_subclass_enum: bool = False,
    use_specialized_enum: bool = True,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    remote_text_cache: DefaultPutDict[str, str] | None = None,
    disable_appending_item_suffix: bool = False,
    strict_types: Sequence[StrictTypes] | None = None,
    empty_enum_field_name: str | None = None,
    custom_class_name_generator: Callable[[str], str] | None = title_to_class_name,
    field_extra_keys: set[str] | None = None,
    field_include_all_keys: bool = False,
    field_extra_keys_without_x_prefix: set[str] | None = None,
    wrap_string_literal: bool | None = None,
    use_title_as_name: bool = False,
    use_operation_id_as_name: bool = False,
    use_unique_items_as_set: bool = False,
    allof_merge_mode: AllOfMergeMode = AllOfMergeMode.Constraints,
    http_headers: Sequence[tuple[str, str]] | None = None,
    http_ignore_tls: bool = False,
    use_annotated: bool = False,
    use_serialize_as_any: bool = False,
    use_non_positive_negative_number_constrained_types: bool = False,
    use_decimal_for_multiple_of: bool = False,
    original_field_name_delimiter: str | None = None,
    use_double_quotes: bool = False,
    use_union_operator: bool = False,
    allow_responses_without_content: bool = False,
    collapse_root_models: bool = False,
    skip_root_model: bool = False,
    use_type_alias: bool = False,
    special_field_name_prefix: str | None = None,
    remove_special_field_name_prefix: bool = False,
    capitalise_enum_members: bool = False,
    keep_model_order: bool = False,
    use_one_literal_as_default: bool = False,
    use_enum_values_in_discriminator: bool = False,
    known_third_party: list[str] | None = None,
    custom_formatters: list[str] | None = None,
    custom_formatters_kwargs: dict[str, Any] | None = None,
    use_pendulum: bool = False,
    http_query_parameters: Sequence[tuple[str, str]] | None = None,
    treat_dot_as_module: bool = False,
    use_exact_imports: bool = False,
    default_field_extras: dict[str, Any] | None = None,
    target_datetime_class: DatetimeClassType | None = None,
    keyword_only: bool = False,
    frozen_dataclasses: bool = False,
    no_alias: bool = False,
    use_frozen_field: bool = False,
    formatters: list[Formatter] = DEFAULT_FORMATTERS,
    parent_scoped_naming: bool = False,
    dataclass_arguments: DataclassArguments | None = None,
    type_mappings: list[str] | None = None,
    read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = None,
) -> None:
    """Configure the parser.

    Most keyword arguments are copied onto attributes of the same name. Some are
    only forwarded: type switches go to ``data_type_manager_type``, naming
    switches go to ``ModelResolver``, and several template switches are recorded
    under ``extra_template_data[ALL_MODEL]``.

    ``base_path`` defaults as follows when not given:

    * if ``source`` is a ``Path``: ``source`` itself when it is a directory,
      otherwise its parent directory;
    * otherwise: the current working directory.

    Raises:
        Exception: ``use_annotated`` is set without ``field_constraints``.
        ValueError: an entry of ``type_mappings`` is malformed (raised by
            ``_parse_type_mappings``).
    """
    self.keyword_only = keyword_only
    self.frozen_dataclasses = frozen_dataclasses

    # Type selection and rendering are delegated to the data type manager.
    self.data_type_manager: DataTypeManager = data_type_manager_type(
        python_version=target_python_version,
        use_standard_collections=use_standard_collections,
        use_generic_container_types=use_generic_container_types,
        use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
        use_decimal_for_multiple_of=use_decimal_for_multiple_of,
        strict_types=strict_types,
        use_union_operator=use_union_operator,
        use_pendulum=use_pendulum,
        target_datetime_class=target_datetime_class,
        treat_dot_as_module=treat_dot_as_module,
        use_serialize_as_any=use_serialize_as_any,
    )

    # Classes used to build generated models and their fields.
    self.data_model_type: type[DataModel] = data_model_type
    self.data_model_root_type: type[DataModel] = data_model_root_type
    self.data_model_field_type: type[DataModelFieldBase] = data_model_field_type

    # Import collection; user-supplied imports are registered immediately.
    self.imports: Imports = Imports(use_exact_imports)
    self.use_exact_imports: bool = use_exact_imports
    self._append_additional_imports(additional_imports=additional_imports)

    self.base_class: str | None = base_class
    self.target_python_version: PythonVersion = target_python_version
    self.results: list[DataModel] = []
    self.dump_resolve_reference_action: Callable[[Iterable[str]], str] | None = dump_resolve_reference_action

    # Field generation switches.
    self.validation: bool = validation
    self.field_constraints: bool = field_constraints
    self.snake_case_field: bool = snake_case_field
    self.strip_default_none: bool = strip_default_none
    self.apply_default_values_for_required_fields: bool = apply_default_values_for_required_fields
    self.force_optional_for_required_fields: bool = force_optional_for_required_fields
    self.use_schema_description: bool = use_schema_description
    self.use_field_description: bool = use_field_description
    self.use_inline_field_description: bool = use_inline_field_description
    self.use_default_kwarg: bool = use_default_kwarg

    # Model reuse.
    self.reuse_model: bool = reuse_model
    self.reuse_scope: ReuseScope | None = reuse_scope
    self.shared_module_name: str = shared_module_name

    self.encoding: str = encoding

    # Enum handling.
    self.enum_field_as_literal: LiteralType | None = enum_field_as_literal
    self.set_default_enum_member: bool = set_default_enum_member
    self.use_subclass_enum: bool = use_subclass_enum
    self.use_specialized_enum: bool = use_specialized_enum

    self.strict_nullable: bool = strict_nullable
    self.use_generic_container_types: bool = use_generic_container_types
    self.use_union_operator: bool = use_union_operator
    self.enable_faux_immutability: bool = enable_faux_immutability
    self.custom_class_name_generator: Callable[[str], str] | None = custom_class_name_generator
    self.field_extra_keys: set[str] = field_extra_keys or set()
    self.field_extra_keys_without_x_prefix: set[str] = field_extra_keys_without_x_prefix or set()
    self.field_include_all_keys: bool = field_include_all_keys

    self.remote_text_cache: DefaultPutDict[str, str] = remote_text_cache or DefaultPutDict()
    self.current_source_path: Path | None = None
    self.use_title_as_name: bool = use_title_as_name
    self.use_operation_id_as_name: bool = use_operation_id_as_name
    self.use_unique_items_as_set: bool = use_unique_items_as_set
    self.allof_merge_mode: AllOfMergeMode = allof_merge_mode
    self.dataclass_arguments = dataclass_arguments

    # Directory that relative schema paths are resolved against.
    resolved_base_path = base_path
    if not resolved_base_path:
        if isinstance(source, Path):
            absolute_source = source.absolute()
            resolved_base_path = absolute_source if source.is_dir() else absolute_source.parent
        else:
            resolved_base_path = Path.cwd()
    self.base_path = resolved_base_path

    self.source: str | Path | list[Path] | ParseResult = source
    self.custom_template_dir = custom_template_dir
    self.extra_template_data: defaultdict[str, Any] = extra_template_data or defaultdict(dict)

    # Template switches recorded under ALL_MODEL; the entry is only touched
    # when at least one switch is enabled.
    all_model_settings: dict[str, Any] = {}
    if allow_population_by_field_name:
        all_model_settings["allow_population_by_field_name"] = True
    if allow_extra_fields:
        all_model_settings["allow_extra_fields"] = True
    if extra_fields:
        all_model_settings["extra_fields"] = extra_fields
    if enable_faux_immutability:
        all_model_settings["allow_mutation"] = False
    if use_attribute_docstrings:
        all_model_settings["use_attribute_docstrings"] = True
    if all_model_settings:
        self.extra_template_data[ALL_MODEL].update(all_model_settings)

    self.model_resolver = ModelResolver(
        base_url=source.geturl() if isinstance(source, ParseResult) else None,
        singular_name_suffix="" if disable_appending_item_suffix else None,
        aliases=aliases,
        empty_field_name=empty_enum_field_name,
        snake_case_field=snake_case_field,
        custom_class_name_generator=custom_class_name_generator,
        base_path=self.base_path,
        original_field_name_delimiter=original_field_name_delimiter,
        special_field_name_prefix=special_field_name_prefix,
        remove_special_field_name_prefix=remove_special_field_name_prefix,
        capitalise_enum_members=capitalise_enum_members,
        no_alias=no_alias,
        parent_scoped_naming=parent_scoped_naming,
        treat_dot_as_module=treat_dot_as_module,
    )
    self.class_name: str | None = class_name
    self.wrap_string_literal: bool | None = wrap_string_literal

    # Remote fetching options.
    self.http_headers: Sequence[tuple[str, str]] | None = http_headers
    self.http_query_parameters: Sequence[tuple[str, str]] | None = http_query_parameters
    self.http_ignore_tls: bool = http_ignore_tls

    self.use_annotated: bool = use_annotated
    if self.use_annotated and not self.field_constraints:  # pragma: no cover
        msg = "`use_annotated=True` has to be used with `field_constraints=True`"
        raise Exception(msg)  # noqa: TRY002
    self.use_serialize_as_any: bool = use_serialize_as_any
    self.use_non_positive_negative_number_constrained_types = use_non_positive_negative_number_constrained_types
    self.use_double_quotes = use_double_quotes
    self.allow_responses_without_content = allow_responses_without_content
    self.collapse_root_models = collapse_root_models
    self.skip_root_model = skip_root_model
    self.use_type_alias = use_type_alias
    self.capitalise_enum_members = capitalise_enum_members
    self.keep_model_order = keep_model_order
    self.use_one_literal_as_default = use_one_literal_as_default
    self.use_enum_values_in_discriminator = use_enum_values_in_discriminator
    self.known_third_party = known_third_party
    # Stored under the singular name ``custom_formatter``.
    self.custom_formatter = custom_formatters
    self.custom_formatters_kwargs = custom_formatters_kwargs
    self.treat_dot_as_module = treat_dot_as_module
    self.default_field_extras: dict[str, Any] | None = default_field_extras
    self.formatters: list[Formatter] = formatters
    self.type_mappings: dict[tuple[str, str], str] = Parser._parse_type_mappings(type_mappings)
    self.read_only_write_only_model_type: ReadOnlyWriteOnlyModelType | None = read_only_write_only_model_type
    self.use_frozen_field: bool = use_frozen_field
|
|
877
|
+
|
|
878
|
+
@property
def field_name_model_type(self) -> ModelType:
    """Pick the ``ModelType`` used when validating generated field names.

    Subclasses of the pydantic v1 or v2 ``BaseModel`` wrappers yield
    ``ModelType.PYDANTIC``, because those models reserve attribute names such as
    ``schema`` or ``model_fields``. Any other model type (TypedDict, dataclass,
    msgspec, ...) yields ``ModelType.CLASS``.
    """
    pydantic_bases = (pydantic_model.BaseModel, pydantic_model_v2.BaseModel)
    is_pydantic_model = issubclass(self.data_model_type, pydantic_bases)
    return ModelType.PYDANTIC if is_pydantic_model else ModelType.CLASS
|
|
892
|
+
|
|
893
|
+
@staticmethod
def _parse_type_mappings(type_mappings: list[str] | None) -> dict[tuple[str, str], str]:
    """Parse CLI-style type mapping entries into a lookup table.

    Accepted entry formats:

    * ``"type+format=target"`` (e.g. ``"string+binary=string"``)
    * ``"format=target"`` (e.g. ``"binary=string"``); the type defaults to ``"string"``.

    Whitespace around each component is ignored. Only the first ``=`` and the
    first ``+`` are significant. A later entry overrides an earlier one for the
    same ``(type, format)`` key.

    Returns:
        Mapping of ``(type, format)`` to the target type name; empty when
        ``type_mappings`` is ``None`` or empty.

    Raises:
        ValueError: an entry has no ``=``, or its type, format or target is empty.
    """
    if not type_mappings:
        return {}

    def _invalid(mapping: str) -> ValueError:
        msg = f"Invalid type mapping format: {mapping!r}. Expected 'type+format=target' or 'format=target'."
        return ValueError(msg)

    result: dict[tuple[str, str], str] = {}
    for mapping in type_mappings:
        if "=" not in mapping:
            raise _invalid(mapping)

        source, target = mapping.split("=", 1)
        if "+" in source:
            type_, format_ = source.split("+", 1)
        else:
            # Only a format was given: it applies to "string" values.
            type_, format_ = "string", source

        type_, format_, target = type_.strip(), format_.strip(), target.strip()
        # Empty components would create keys or targets that can never be
        # valid, so they are rejected like any other malformed entry.
        if not (type_ and format_ and target):
            raise _invalid(mapping)

        result[type_, format_] = target

    return result
|
|
374
923
|
|
|
375
924
|
@property
|
|
376
925
|
def iter_source(self) -> Iterator[Source]:
|
|
926
|
+
"""Iterate over all source files to be parsed."""
|
|
377
927
|
if isinstance(self.source, str):
|
|
378
928
|
yield Source(path=Path(), text=self.source)
|
|
379
929
|
elif isinstance(self.source, Path): # pragma: no cover
|
|
380
930
|
if self.source.is_dir():
|
|
381
|
-
for path in self.source.rglob(
|
|
931
|
+
for path in sorted(self.source.rglob("*"), key=lambda p: p.name):
|
|
382
932
|
if path.is_file():
|
|
383
933
|
yield Source.from_path(path, self.base_path, self.encoding)
|
|
384
934
|
else:
|
|
@@ -389,266 +939,1762 @@ class Parser(ABC):
|
|
|
389
939
|
else:
|
|
390
940
|
yield Source(
|
|
391
941
|
path=Path(self.source.path),
|
|
392
|
-
text=self.remote_text_cache.get_or_put(
|
|
393
|
-
self.source.geturl(), default_factory=self._get_text_from_url
|
|
394
|
-
),
|
|
942
|
+
text=self.remote_text_cache.get_or_put(self.source.geturl(), default_factory=self._get_text_from_url),
|
|
395
943
|
)
|
|
396
944
|
|
|
945
|
+
def _append_additional_imports(self, additional_imports: list[str] | None) -> None:
    """Register user-supplied dotted import paths in ``self.imports``.

    A ``None`` list, and ``None`` entries inside the list, are ignored. Every
    other entry is converted with ``Import.from_full_path`` and appended.
    """
    for full_path in additional_imports or []:
        if full_path is None:  # pragma: no cover
            continue
        self.imports.append(Import.from_full_path(full_path))
|
|
954
|
+
|
|
397
955
|
def _get_text_from_url(self, url: str) -> str:
    """Return the body of ``url``, going through ``self.remote_text_cache``.

    On a cache miss the text is fetched with ``get_body``. The call passes the
    parser's ``http_headers``, ``http_ignore_tls`` and ``http_query_parameters``.
    """
    from datamodel_code_generator.http import get_body  # noqa: PLC0415

    def _download(_url: str) -> str:
        return get_body(url, self.http_headers, self.http_ignore_tls, self.http_query_parameters)

    return self.remote_text_cache.get_or_put(url, default_factory=_download)
|
|
403
964
|
|
|
404
965
|
@classmethod
def get_url_path_parts(cls, url: ParseResult) -> list[str]:
    """Split ``url`` into its ``scheme://hostname`` origin plus path segments.

    The origin is built from ``url.hostname``, so any port or credentials are
    not included. The path is split on ``/`` and its first piece is discarded
    (empty for absolute paths). A trailing slash yields a trailing empty segment.
    """
    _, *segments = url.path.split("/")
    return [f"{url.scheme}://{url.hostname}", *segments]
|
|
410
972
|
|
|
411
973
|
@property
def data_type(self) -> type[DataType]:
    """The ``DataType`` class exposed by ``self.data_type_manager``."""
    manager = self.data_type_manager
    return manager.data_type
|
|
414
977
|
|
|
415
978
|
@abstractmethod
def parse_raw(self) -> None:
    """Parse the raw schema input.

    This is an abstract hook: each concrete parser implements it for its own
    schema format. The base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError
|
|
418
982
|
|
|
419
|
-
|
|
983
|
+
@classmethod
def _replace_model_in_list(
    cls,
    models: list[DataModel],
    original: DataModel,
    replacement: DataModel,
) -> None:
    """Replace ``original`` with ``replacement`` in ``models``, keeping its position.

    The list is modified in place. The slot is found with ``list.index`` (the
    first element equal to ``original``) and overwritten. Overwriting, rather
    than inserting the replacement and then removing the original, keeps the
    replacement even when it compares equal to ``original``. It also scans the
    list only once.

    Raises:
        ValueError: ``original`` is not in ``models``.
    """
    models[models.index(original)] = replacement
|
|
993
|
+
|
|
994
|
+
def __delete_duplicate_models(self, models: list[DataModel]) -> None:
    """Remove redundant models from ``models`` in place.

    Root models (instances of ``self.data_model_root_type``):

    * A root model is dropped when its first field's data type references, and
      is neither a dict nor a list, another model that is still in ``models``,
      and that model's name equals the class name the resolver would give the
      root model. Its children references are redirected to that target, and
      the references of its own data types are removed.
    * A root model that is kept is removed from the ``base_classes`` of its
      DataModel children. A child left with no base class gets its default
      base class back.

    Name-based duplicates: models are keyed by ``duplicate_class_name`` (or
    ``class_name``). A model whose ``get_dedup_key`` equals that of the model
    currently registered under the same key is merged into it: references are
    redirected, the children's base classes are deduplicated, and the duplicate
    is removed.
    """
    registered: dict[str, DataModel] = {}
    duplicates_of: defaultdict[DataModel, list[DataModel]] = defaultdict(list)

    for candidate in models.copy():
        if isinstance(candidate, self.data_model_root_type):
            field_type = candidate.fields[0].data_type
            # Backward compatibility: a root model that just wraps a model of
            # the same resolved name is redundant.
            is_redundant_root = bool(
                field_type.reference
                and not field_type.is_dict
                and not field_type.is_list
                and field_type.reference.source in models
                and field_type.reference.name
                == self.model_resolver.get_class_name(candidate.reference.original_name, unique=False).name
            )
            if is_redundant_root:
                candidate.reference.replace_children_references(field_type.reference)
                models.remove(candidate)
                for nested_type in candidate.all_data_types:
                    if nested_type.reference:
                        nested_type.remove_reference()
                continue

            # Detach this root model from every DataModel child that lists it
            # as a base class.
            for child in candidate.reference.iter_data_model_children():
                child.base_classes = [base for base in child.base_classes if base.reference != candidate.reference]
                if not child.base_classes:  # pragma: no cover
                    child.set_base_class()

        name = candidate.duplicate_class_name or candidate.class_name
        if name in registered:
            first = registered[name]
            candidate_key = candidate.get_dedup_key(candidate.duplicate_class_name, use_default=False)
            first_key = first.get_dedup_key(first.duplicate_class_name, use_default=False)
            if candidate_key == first_key:
                duplicates_of[first].append(candidate)
                continue
        registered[name] = candidate

    for survivor, redundant_models in duplicates_of.items():
        for redundant in redundant_models:
            redundant.reference.replace_children_references(survivor.reference)
            # Deduplicate children's base classes by "module_name.type_hint":
            # the last entry wins, at the position of the first one.
            for child in redundant.reference.iter_data_model_children():
                unique_bases = {f"{base.module_name}.{base.type_hint}": base for base in child.base_classes}
                child.base_classes = list(unique_bases.values())
            models.remove(redundant)
|
|
1042
|
+
|
|
1043
|
+
@classmethod
def __replace_duplicate_name_in_module(cls, models: list[DataModel]) -> None:
    """Make class names unique within one output module, in place.

    The names used by the models' imports (alias, else imported name) are
    passed to a fresh ``ModelResolver`` as ``exclude_names``, with ``"Model"``
    as the duplicate-name suffix. Each model is renamed to the resolver's unique
    name when it differs. Then each model whose ``duplicate_class_name`` is not
    currently used by any model is renamed to that name. Only that first
    preferred name is tried.
    """
    reserved_names = {imp.alias or imp.import_ for model in models for imp in model.imports}
    resolver = ModelResolver(exclude_names=reserved_names, duplicate_name_suffix="Model")

    models_by_name: dict[str, DataModel] = {}
    for model in models:
        current_name: str = model.class_name
        unique_name: str = resolver.add([model.path], current_name, unique=True, class_name=True).name
        if unique_name != current_name:
            model.class_name = unique_name
        models_by_name[model.class_name] = model

    for model in models:
        preferred = model.duplicate_class_name
        if not preferred or preferred in models_by_name:
            continue
        del models_by_name[model.class_name]
        model.class_name = preferred
        models_by_name[preferred] = model
|
|
1065
|
+
|
|
1066
|
+
def __change_from_import(  # noqa: PLR0913, PLR0914
    self,
    models: list[DataModel],
    imports: Imports,
    scoped_model_resolver: ModelResolver,
    *,
    init: bool,
    internal_modules: set[tuple[str, ...]] | None = None,
    model_path_to_module_name: dict[str, str] | None = None,
) -> None:
    """Collect into ``imports`` what ``models`` need from other modules.

    Each model's own ``imports`` are appended first. Then, for every referenced
    data type whose model is not part of ``models``, the method:

    1. computes a relative ``from``/``import`` pair from the current module name
       to the target's full name;
    2. asks ``scoped_model_resolver`` for a collision-free alias, and stores it
       on the data type when it differs from the short name;
    3. appends a matching ``Import``.

    The model's ``imports`` are appended again if they changed in the meantime.

    Args:
        models: models that are written to the current module.
        imports: collector that receives the imports.
        scoped_model_resolver: resolver producing unique aliases in this module.
        init: when true, targets whose full name does not start with the current
            module name get one more leading dot (presumably for ``__init__``
            modules -- confirm).
        internal_modules: target modules for which the class is imported from
            the module itself instead of importing the module.
        model_path_to_module_name: module-name overrides per model path. Without
            an entry, ``model.module_name`` (for the current model) or the module
            part of ``data_type.full_name`` (for targets) is used.
    """
    paths_in_module = {model.path for model in models}
    internal_modules = internal_modules or set()
    model_path_to_module_name = model_path_to_module_name or {}

    # Reserve the names of locally defined models before any alias is handed out.
    for model in models:
        scoped_model_resolver.add([model.path], model.class_name)

    for model in models:
        imports_before = model.imports
        imports.append(imports_before)
        module_name = model_path_to_module_name.get(model.path, model.module_name)

        for data_type in model.all_data_types:
            # Plain types and references to models of this same module need no import.
            if not data_type.reference or data_type.reference.path in paths_in_module:
                continue

            ref_module_name = model_path_to_module_name.get(
                data_type.reference.path,
                data_type.full_name.rpartition(".")[0],
            )
            short_name = data_type.reference.short_name
            target_full_name = f"{ref_module_name}.{short_name}" if ref_module_name else short_name

            if isinstance(data_type, BaseClassDataType):
                # Base classes are imported by class name from their module.
                left, right = relative(module_name, target_full_name)
                if is_ancestor_package_reference(module_name, target_full_name):
                    from_ = left
                elif left.endswith("."):
                    from_ = f"{left}{right}"
                else:
                    from_ = f"{left}.{right}"
                import_ = short_name
                full_path = from_, import_
            else:
                from_, import_ = full_path = relative(module_name, target_full_name)
                if imports.use_exact:
                    from_, import_ = exact_import(from_, import_, short_name)
                import_ = import_.replace("-", "_")

                module_parts = tuple(module_name.split(".")) if module_name else ()
                # NOTE(review): the parts come from ``split(".")`` and therefore
                # never contain a dot, so this branch cannot run as written --
                # confirm which path components were intended here.
                if (  # pragma: no cover
                    len(module_parts) > 1
                    and module_parts[-1].count(".") > 0
                    and not self.treat_dot_as_module
                ):
                    from_ = from_[module_parts[-1].count(".") :]

                target_package = tuple(target_full_name.split(".")[:-1])
                module_named_like_class = (
                    target_package and import_ == short_name and target_package[-1] == import_
                )
                if from_ and (target_package in internal_modules or module_named_like_class):
                    # Import the class from its module rather than the module itself.
                    from_ = f"{from_}{import_}" if from_.endswith(".") else f"{from_}.{import_}"
                    import_ = short_name
                    full_path = from_, import_

            alias = scoped_model_resolver.add(full_path, import_).name
            if from_ and import_ and alias != short_name:
                data_type.alias = alias if short_name == import_ else f"{alias}.{short_name}"

            if init and not target_full_name.startswith(module_name + "."):
                from_ = "." + from_
            imports.append(
                Import(
                    from_=from_,
                    import_=import_,
                    alias=alias,
                    reference_path=data_type.reference.path,
                ),
            )

        imports_after = model.imports
        if imports_before != imports_after:
            imports.append(imports_after)
|
|
1150
|
+
|
|
1151
|
+
@classmethod
def __extract_inherited_enum(cls, models: list[DataModel]) -> None:
    """Collapse field-less models that inherit from enums into one merged enum.

    A model with no fields of its own whose base classes reference ``Enum``
    models is swapped (via ``cls._replace_model_in_list``) for a new instance
    of the first base enum's class. That instance carries the concatenated
    fields of every base enum plus the original model's description and
    reference.

    Args:
        models: Models of the current module; modified in place.
    """
    # Walk a snapshot because matching entries are replaced inside ``models``.
    for candidate in models.copy():
        if candidate.fields:
            continue
        base_enums: list[Enum] = []
        for base in candidate.base_classes:
            source = base.reference.source if base.reference else None
            if isinstance(source, Enum):
                base_enums.append(source)
        if not base_enums:
            continue
        merged_fields = [member for enum_model in base_enums for member in enum_model.fields]
        replacement = base_enums[0].__class__(
            fields=merged_fields,
            description=candidate.description,
            reference=candidate.reference,
        )
        cls._replace_model_in_list(models, candidate, replacement)
|
|
1170
|
+
|
|
1171
|
+
def _create_discriminator_data_type(
    self,
    enum_source: Enum | None,
    type_names: list[str],
    discriminator_model: DataModel,
    imports: Imports,
) -> DataType:
    """Build the data type used for a discriminator field.

    Without ``enum_source`` the result is a plain literal type over
    ``type_names``.

    With an enum source, each value becomes an ``(enum class short name,
    member name)`` pair. A value that ``enum_source.find_member`` cannot
    resolve, or that resolves to a member without a field name, is used as-is.
    When the enum's ``module_path`` differs from ``discriminator_model``'s,
    an import for the enum is added to ``imports``.
    """
    if not enum_source:
        return self.data_type(literals=type_names)

    class_name = enum_source.reference.short_name
    pairs: list[tuple[str, str]] = []
    for raw_value in type_names:
        found = enum_source.find_member(raw_value)
        member_name = found.field.name if found and found.field.name else raw_value
        pairs.append((class_name, member_name))
    result = self.data_type(enum_member_literals=pairs)
    if enum_source.module_path != discriminator_model.module_path:  # pragma: no cover
        imports.append(Import.from_full_path(enum_source.name))
    return result
|
|
438
1194
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
1195
|
+
def __apply_discriminator_type(  # noqa: PLR0912, PLR0914, PLR0915
    self,
    models: list[DataModel],
    imports: Imports,
) -> None:
    """Give every discriminated-union member a concrete discriminator field.

    Selection: fields whose ``extras["discriminator"]`` is a dict with a
    ``propertyName``. That name is normalized through ``self.model_resolver``
    and written back into the dict. Every referenced model in the field's
    union is then visited.

    Resolving the discriminator value(s) for a model, in order:

    1. a truthy ``const`` extra on the model's discriminator field;
    2. with a ``mapping``: the entries whose path matches the model, or one of
       its base classes (see ``check_paths``);
    3. without a mapping: a single literal on the field, else a
       single-member enum referenced by the field, else the last segment of
       the model's path.

    Applying the value(s):

    - A field that already holds exactly one matching literal is kept.
    - msgspec ``Struct`` models instead get ``tag_field``/``tag`` base-class
      kwargs and ``extras["is_classvar"] = True`` on the field.
    - Otherwise the field is retyped through
      ``self._create_discriminator_data_type`` (enum-member literals when
      ``self.use_enum_values_in_discriminator`` finds an enum) and made
      required.
    - A model lacking the field gets a new required one appended.

    Args:
        models: Models of the current module; modified in place.
        imports: Import collection extended with what the new types need.

    Raises:
        RuntimeError: If no discriminator value can be determined for a model.
    """
    for model in models:  # noqa: PLR1702
        for field in model.fields:
            discriminator = field.extras.get("discriminator")
            if not discriminator or not isinstance(discriminator, dict):
                continue
            property_name = discriminator.get("propertyName")
            if not property_name:  # pragma: no cover
                continue
            field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
                field_name=property_name, model_type=self.field_name_model_type
            )
            # Store the normalized (valid Python) name back into the discriminator.
            discriminator["propertyName"] = field_name
            mapping = discriminator.get("mapping", {})
            for data_type in field.data_type.data_types:
                if not data_type.reference:  # pragma: no cover
                    continue
                discriminator_model = data_type.reference.source

                if not isinstance(  # pragma: no cover
                    discriminator_model,
                    (
                        pydantic_model.BaseModel,
                        pydantic_model_v2.BaseModel,
                        dataclass_model.DataClass,
                        msgspec_model.Struct,
                    ),
                ):
                    continue  # pragma: no cover

                type_names: list[str] = []

                def check_paths(
                    model: pydantic_model.BaseModel | pydantic_model_v2.BaseModel | Reference,
                    mapping: dict[str, str],
                    # Bound at definition time to the list created just above;
                    # rebinding ``type_names`` later does not change this default.
                    type_names: list[str] = type_names,
                ) -> None:
                    """Validate discriminator mapping paths for a model."""
                    for name, path in mapping.items():
                        # NOTE(review): ``model.path[:-1]`` drops the last character of
                        # the model path before comparing; intent unclear — confirm.
                        if (model.path.split("#/")[-1] != path.split("#/")[-1]) and (
                            path.startswith("#/") or model.path[:-1] != path.split("/")[-1]
                        ):
                            t_path = path[str(path).find("/") + 1 :]
                            # lstrip treats "../" as a character set: it strips every
                            # leading "." and "/" (hence the B005 suppression).
                            t_disc = model.path[: str(model.path).find("#")].lstrip("../")  # noqa: B005
                            t_disc_2 = "/".join(t_disc.split("/")[1:])
                            if t_path not in {t_disc, t_disc_2}:
                                continue
                        type_names.append(name)

                # First try to get the discriminator value from the const field
                for discriminator_field in discriminator_model.fields:
                    if field_name not in {discriminator_field.original_name, discriminator_field.name}:
                        continue
                    if discriminator_field.extras.get("const"):
                        type_names = [discriminator_field.extras["const"]]
                        break

                # If no const value found, try to get it from the mapping
                if not type_names:
                    # Check the main discriminator model path
                    if mapping:
                        check_paths(discriminator_model, mapping)  # pyright: ignore[reportArgumentType]

                        # Check the base_classes if they exist
                        if len(type_names) == 0:
                            for base_class in discriminator_model.base_classes:
                                check_paths(base_class.reference, mapping)  # pyright: ignore[reportArgumentType]
                    else:
                        # No mapping: derive the value from the field's own type.
                        for discriminator_field in discriminator_model.fields:
                            if field_name not in {discriminator_field.original_name, discriminator_field.name}:
                                continue

                            literals = discriminator_field.data_type.literals
                            if literals and len(literals) == 1:  # pragma: no cover
                                type_names = [str(v) for v in literals]
                                break

                            enum_source = discriminator_field.data_type.find_source(Enum)
                            if enum_source and len(enum_source.fields) == 1:
                                first_field = enum_source.fields[0]
                                raw_default = first_field.default
                                if isinstance(raw_default, str):
                                    # String defaults may still carry their quote characters.
                                    type_names = [raw_default.strip("'\"")]
                                else:  # pragma: no cover
                                    type_names = [str(raw_default)]
                                break

                        # Last resort: the final segment of the model's path.
                        if not type_names:
                            type_names = [discriminator_model.path.split("/")[-1]]

                if not type_names:  # pragma: no cover
                    msg = f"Discriminator type is not found. {data_type.reference.path}"
                    raise RuntimeError(msg)

                # Optionally find an enum declared for this field on a base class.
                enum_from_base: Enum | None = None
                if self.use_enum_values_in_discriminator:
                    for base_class in discriminator_model.base_classes:
                        if not base_class.reference or not base_class.reference.source:  # pragma: no cover
                            continue
                        base_model = base_class.reference.source
                        if not isinstance(  # pragma: no cover
                            base_model,
                            (
                                pydantic_model.BaseModel,
                                pydantic_model_v2.BaseModel,
                                dataclass_model.DataClass,
                                msgspec_model.Struct,
                            ),
                        ):
                            continue
                        for base_field in base_model.fields:  # pragma: no branch
                            if field_name not in {base_field.original_name, base_field.name}:  # pragma: no cover
                                continue
                            enum_from_base = base_field.data_type.find_source(Enum)
                            if enum_from_base:  # pragma: no branch
                                break
                        if enum_from_base:  # pragma: no branch
                            break

                # Rewrite (or confirm) the discriminator field on this model.
                has_one_literal = False
                for discriminator_field in discriminator_model.fields:
                    if field_name not in {discriminator_field.original_name, discriminator_field.name}:
                        continue
                    literals = discriminator_field.data_type.literals
                    const_value = discriminator_field.extras.get("const")
                    expected_value = type_names[0] if type_names else None

                    # Check if literals match (existing behavior)
                    literals_match = len(literals) == 1 and literals[0] == expected_value
                    # Check if const value matches (for msgspec with type: string + const)
                    const_match = const_value is not None and const_value == expected_value

                    if literals_match:
                        has_one_literal = True
                        if isinstance(discriminator_model, msgspec_model.Struct):  # pragma: no cover
                            discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
                            discriminator_model.add_base_class_kwarg("tag", discriminator_field.represented_default)
                            discriminator_field.extras["is_classvar"] = True
                        # Found the discriminator field, no need to keep looking
                        break

                    # For msgspec with const value but no literal (type: string + const case)
                    if const_match and isinstance(discriminator_model, msgspec_model.Struct):  # pragma: no cover
                        has_one_literal = True
                        discriminator_model.add_base_class_kwarg("tag_field", f"'{field_name}'")
                        discriminator_model.add_base_class_kwarg("tag", repr(const_value))
                        discriminator_field.extras["is_classvar"] = True
                        break

                    enum_source: Enum | None = None
                    if self.use_enum_values_in_discriminator:
                        enum_source = (  # pragma: no cover
                            discriminator_field.data_type.find_source(Enum) or enum_from_base
                        )

                    # Detach old references before replacing the field's type.
                    for field_data_type in discriminator_field.data_type.all_data_types:
                        if field_data_type.reference:  # pragma: no cover
                            field_data_type.remove_reference()

                    discriminator_field.data_type = self._create_discriminator_data_type(
                        enum_source, type_names, discriminator_model, imports
                    )
                    discriminator_field.data_type.parent = discriminator_field
                    discriminator_field.required = True
                    imports.append(discriminator_field.imports)
                    has_one_literal = True
                # The model lacks the discriminator field entirely: add one.
                if not has_one_literal:
                    new_data_type = self._create_discriminator_data_type(
                        enum_from_base, type_names, discriminator_model, imports
                    )
                    discriminator_model.fields.append(
                        self.data_model_field_type(
                            name=field_name,
                            data_type=new_data_type,
                            required=True,
                            alias=alias,
                        )
                    )
        # NOTE(review): IMPORT_LITERAL is appended only when iterating ``imports``
        # already yields an item equal to it, so this looks like a no-op — confirm
        # what iterating ``Imports`` yields and whether this is still needed.
        has_imported_literal = any(import_ == IMPORT_LITERAL for import_ in imports)
        if has_imported_literal:  # pragma: no cover
            imports.append(IMPORT_LITERAL)
|
|
442
1380
|
|
|
443
|
-
|
|
1381
|
+
@classmethod
def _create_set_from_list(cls, data_type: DataType) -> DataType | None:
    """Return a set-typed counterpart of ``data_type``, or ``None``.

    * A list type is copied. The copy is flagged ``is_set`` instead of
      ``is_list``, and the copy's child types are re-parented to it.
    * A composite type (one with ``data_types``) is rewritten in place. Each
      child that converts is swapped with its set form, and ``data_type``
      itself is returned.
    * Anything else yields ``None``.
    """
    if not data_type.is_list:
        if not data_type.data_types:
            return None  # pragma: no cover
        # Iterate over a snapshot so swapping children cannot disturb the loop.
        for child in list(data_type.data_types):
            converted = cls._create_set_from_list(child)
            if converted:  # pragma: no cover
                child.swap_with(converted)
        return data_type

    as_set = data_type.copy()
    as_set.is_list = False
    as_set.is_set = True
    for child in as_set.data_types:
        child.parent = as_set
    return as_set
|
|
1397
|
+
|
|
1398
|
+
def __replace_unique_list_to_set(self, models: list[DataModel]) -> None:
    """Convert ``uniqueItems`` list fields to set fields.

    Active only when ``self.use_unique_items_as_set`` is enabled. For every
    field whose constraints set ``unique_items``:

    * the data type is converted with ``self._create_set_from_list``;
    * a list default is converted to a ``set``.

    If a list default contains unhashable elements (``set()`` raises
    ``TypeError``, e.g. for dicts), the field is left completely unchanged:
    neither its type nor its default is touched.

    Args:
        models: Models whose fields are updated in place.
    """
    # The option does not depend on the field being visited; check it once.
    if not self.use_unique_items_as_set:
        return
    for model in models:
        for model_field in model.fields:
            if not (model_field.constraints and model_field.constraints.unique_items):
                continue
            has_list_default = isinstance(model_field.default, list)
            if has_list_default:
                # Validate the default BEFORE building the set type. For
                # composite types, _create_set_from_list swaps nested list
                # types in place. Bailing out after that call would leave a
                # set type paired with a list default.
                try:
                    converted_default = set(model_field.default)
                except TypeError:
                    # Elements are not hashable (e.g., contains dicts).
                    # Skip both type and default conversion to keep consistency.
                    continue
            set_data_type = self._create_set_from_list(model_field.data_type)
            if not set_data_type:  # pragma: no cover
                continue
            if has_list_default:
                model_field.default = converted_default
            model_field.replace_data_type(set_data_type)
|
|
444
1418
|
|
|
445
|
-
|
|
1419
|
+
@classmethod
def __set_reference_default_value_to_field(cls, models: list[DataModel]) -> None:
    """Copy a referenced model's default onto fields that have none.

    A field qualifies when all of these hold:

    * its data type has a reference;
    * it has no default of its own;
    * the referenced source is a ``DataModel``;
    * that model's ``default`` is not ``UNDEFINED``.

    A qualifying field receives that default.

    Args:
        models: Models whose fields are updated in place.
    """
    for model in models:
        for model_field in model.fields:
            reference = model_field.data_type.reference
            if not reference or model_field.has_default:
                continue
            source = reference.source
            if isinstance(source, DataModel) and source.default != UNDEFINED:  # pragma: no cover
                model_field.default = source.default
|
|
1431
|
+
|
|
1432
|
+
def __reuse_model(self, models: list[DataModel], require_update_action_models: list[str]) -> None:
    """Deduplicate structurally identical models within one module.

    Runs only when ``self.reuse_model`` is enabled and the reuse scope is not
    ``ReuseScope.Tree``. Models are keyed by ``get_dedup_key()``, and the first
    model seen for a key is kept. Later models with the same key:

    * Enums have their references redirected to the kept model and are then
      removed from ``models``.
    * Any other model is replaced by ``create_reuse_model(kept_reference)``.
      If the kept model's path is in ``require_update_action_models``, the
      replacement's path is added too.

    Args:
        models: Models of the module; modified in place.
        require_update_action_models: Model paths needing update actions;
            extended in place.
    """
    if not self.reuse_model or self.reuse_scope == ReuseScope.Tree:
        return
    first_seen: dict[tuple[HashableComparable, ...], Reference] = {}
    redundant_enums = []
    # Walk a snapshot because non-enum duplicates are replaced in ``models``.
    for model in models.copy():
        key = model.get_dedup_key()
        kept_reference = first_seen.get(key)
        if not kept_reference:
            first_seen[key] = model.reference
            continue
        if isinstance(model, Enum):
            model.replace_children_in_models(models, kept_reference)
            redundant_enums.append(model)
            continue
        replacement = model.create_reuse_model(kept_reference)
        if kept_reference.path in require_update_action_models:
            add_model_path_to_list(require_update_action_models, replacement)
        self._replace_model_in_list(models, model, replacement)
    for enum_model in redundant_enums:
        models.remove(enum_model)
|
|
1454
|
+
|
|
1455
|
+
def __find_duplicate_models_across_modules(  # noqa: PLR6301
    self,
    module_models: list[tuple[tuple[str, ...], list[DataModel]]],
) -> list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]]:
    """Find models that duplicate an earlier model anywhere in the module tree.

    Models are compared by ``get_dedup_key()`` while walking modules in order,
    and models in order within each module. The first model seen for a key is
    canonical. Every later model with the same key is reported.

    Args:
        module_models: ``(module path, models)`` pairs for every module.

    Returns:
        ``(duplicate module, duplicate model, canonical module, canonical
        model)`` tuples in discovery order.
    """
    canonical_by_key: dict[tuple[HashableComparable, ...], tuple[tuple[str, ...], DataModel]] = {}
    found: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]] = []
    for module, models in module_models:
        for model in models:
            key = model.get_dedup_key()
            if key in canonical_by_key:
                found.append((module, model, *canonical_by_key[key]))
            else:
                canonical_by_key[key] = (module, model)
    return found
|
|
1477
|
+
|
|
1478
|
+
def __validate_shared_module_name(
    self,
    module_models: list[tuple[tuple[str, ...], list[DataModel]]],
) -> None:
    """Ensure ``self.shared_module_name`` does not clash with a generated module.

    Only the first component (``module[0]``) of each module path is compared
    with the shared module name.

    Raises:
        Error: If some module's first path component equals the shared
            module name.
    """
    shared_module = self.shared_module_name
    top_level_names = {path[0] for path, _models in module_models}
    if shared_module not in top_level_names:
        return
    msg = (
        f"Schema file or directory '{shared_module}' conflicts with the shared module name. "
        f"Use --shared-module-name to specify a different name."
    )
    raise Error(msg)
|
|
1491
|
+
|
|
1492
|
+
def __create_shared_module_from_duplicates(  # noqa: PLR0912
    self,
    module_models: list[tuple[tuple[str, ...], list[DataModel]]],
    duplicates: list[tuple[tuple[str, ...], DataModel, tuple[str, ...], DataModel]],
    require_update_action_models: list[str],
) -> tuple[tuple[str, ...], list[DataModel]]:
    """Move canonical models into the shared module and rewire their duplicates.

    1. Each distinct canonical model, in first-appearance order within
       ``duplicates``, gets ``file_path`` set to ``<shared_module_name>.py``.
       It is collected as a shared model.
    2. Each duplicate is located in its own module's model list.
       - Enums, and every duplicate when ``self.data_model_type`` is not a
         pydantic v1/v2 ``BaseModel`` or ``DataClass`` subclass, have their
         references redirected to the canonical model and are removed.
       - Any other duplicate is replaced by
         ``create_reuse_model(shared_ref)``.
    3. Canonical models are removed from their original modules.

    Returns:
        ``((shared_module_name,), shared_models)``.

    Raises:
        RuntimeError: If a duplicate or canonical model is not found in
            ``module_models``.
    """
    shared_module = self.shared_module_name

    shared_models: list[DataModel] = []
    canonical_to_shared_ref: dict[DataModel, Reference] = {}
    canonical_models_seen: set[DataModel] = set()

    # Step 1: claim canonical models, in order of first appearance for stable output.
    for *_, canonical in duplicates:
        if canonical not in canonical_models_seen:
            canonical_models_seen.add(canonical)
            canonical.file_path = Path(f"{shared_module}.py")
            canonical_to_shared_ref[canonical] = canonical.reference
            shared_models.append(canonical)

    inheritable_types = (
        pydantic_model.BaseModel,
        pydantic_model_v2.BaseModel,
        dataclass_model.DataClass,
    )
    supports_inheritance = issubclass(self.data_model_type, inheritable_types)

    # Step 2: point every duplicate at its canonical model.
    for duplicate_module, duplicate_model, _, canonical_model in duplicates:
        shared_ref = canonical_to_shared_ref[canonical_model]
        owner_models = next(
            (
                models
                for module, models in module_models
                if module == duplicate_module and duplicate_model in models
            ),
            None,
        )
        if owner_models is None:  # pragma: no cover
            msg = f"Duplicate model {duplicate_model.name} not found in module {duplicate_module}"
            raise RuntimeError(msg)
        if isinstance(duplicate_model, Enum) or not supports_inheritance:
            duplicate_model.replace_children_in_models(owner_models, shared_ref)
            owner_models.remove(duplicate_model)
            continue
        inherited_model = duplicate_model.create_reuse_model(shared_ref)
        if shared_ref.path in require_update_action_models:
            add_model_path_to_list(require_update_action_models, inherited_model)
        self._replace_model_in_list(owner_models, duplicate_model, inherited_model)

    # Step 3: canonical models now live only in the shared module.
    for canonical in canonical_models_seen:
        home_models = next((models for _module, models in module_models if canonical in models), None)
        if home_models is None:  # pragma: no cover
            msg = f"Canonical model {canonical.name} not found in any module"
            raise RuntimeError(msg)
        home_models.remove(canonical)

    return (shared_module,), shared_models
|
|
1551
|
+
|
|
1552
|
+
def __reuse_model_tree_scope(
    self,
    module_models: list[tuple[tuple[str, ...], list[DataModel]]],
    require_update_action_models: list[str],
) -> tuple[tuple[str, ...], list[DataModel]] | None:
    """Deduplicate models across all modules into the shared module.

    Active only when ``self.reuse_model`` is enabled with
    ``ReuseScope.Tree``. Duplicates are found across every module. The shared
    module name (``self.shared_module_name``) is validated, and then the
    shared module is built.

    Returns:
        The ``(module path, models)`` entry of the shared module, or ``None``
        when tree-scoped reuse is off or there are no duplicates.
    """
    if not (self.reuse_model and self.reuse_scope == ReuseScope.Tree):
        return None
    duplicates = self.__find_duplicate_models_across_modules(module_models)
    if duplicates:
        self.__validate_shared_module_name(module_models)
        return self.__create_shared_module_from_duplicates(module_models, duplicates, require_update_action_models)
    return None
|
|
1567
|
+
|
|
1568
|
+
def __collapse_root_models(  # noqa: PLR0912
    self,
    models: list[DataModel],
    unused_models: list[DataModel],
    imports: Imports,
    scoped_model_resolver: ModelResolver,
) -> None:
    """Inline references to root models into the fields that use them.

    Active only when ``self.collapse_root_models`` is enabled. Every data type
    of every field that references a ``self.data_model_root_type`` model is
    replaced by a copy of that root model's first field's data type.

    Along the way:

    * extras and constraints are merged as needed;
    * imports are registered for references inside the copied type;
    * the root model's reference children are pruned to those that still
      have a parent;
    * root models left without children are appended to ``unused_models``.

    Left as references (skipped):

    * root models whose own field type is a reference;
    * with ``self.field_constraints``, constrained root types used where the
      field contains a dict, union or list type.

    Args:
        models: Models of the current module; modified in place.
        unused_models: Receives root models that end up with no children.
        imports: Import collection updated in place.
        scoped_model_resolver: Resolver used to allocate import aliases.
    """
    if not self.collapse_root_models:
        return

    for model in models:  # noqa: PLR1702
        for model_field in model.fields:
            # NOTE(review): parent ``data_types`` lists are mutated below while this
            # loop runs; confirm ``all_data_types`` is safe against that.
            for data_type in model_field.data_type.all_data_types:
                reference = data_type.reference
                if not reference or not isinstance(reference.source, self.data_model_root_type):
                    # If the data type is not a reference, we can't collapse it.
                    # If it references anything other than a root model, there is nothing to collapse.
                    continue

                # Use root-type as model_field type
                root_type_model = reference.source
                root_type_field = root_type_model.fields[0]

                # Constrained root types nested in dict/union/list containers are kept as references.
                if (
                    self.field_constraints
                    and isinstance(root_type_field.constraints, ConstraintsBase)
                    and root_type_field.constraints.has_constraints
                    and any(d for d in model_field.data_type.all_data_types if d.is_dict or d.is_union or d.is_list)
                ):
                    continue  # pragma: no cover

                if root_type_field.data_type.reference:
                    # If the root type field is a reference, we aren't able to collapse it yet.
                    continue

                # set copied data_type
                copied_data_type = root_type_field.data_type.copy()
                if isinstance(data_type.parent, self.data_model_field_type):
                    # for field
                    # override empty field by root-type field
                    # (the field's own extras win over the root field's)
                    model_field.extras = {
                        **root_type_field.extras,
                        **model_field.extras,
                    }
                    model_field.process_const()

                    if self.field_constraints:
                        model_field.constraints = ConstraintsBase.merge_constraints(
                            root_type_field.constraints, model_field.constraints
                        )

                    data_type.parent.data_type = copied_data_type

                elif isinstance(data_type.parent, DataType) and data_type.parent.is_list:
                    # for list items
                    if self.field_constraints:
                        model_field.constraints = ConstraintsBase.merge_constraints(
                            root_type_field.constraints, model_field.constraints
                        )
                    # Carry a root-level discriminator over to the field when the field has none.
                    if (  # pragma: no cover
                        isinstance(
                            root_type_field,
                            pydantic_model.DataModelField,
                        )
                        and not model_field.extras.get("discriminator")
                        and not any(t.is_list for t in model_field.data_type.data_types)
                    ):
                        discriminator = root_type_field.extras.get("discriminator")
                        if discriminator:
                            model_field.extras["discriminator"] = discriminator
                    assert isinstance(data_type.parent, DataType)
                    data_type.parent.data_types.remove(data_type)  # pragma: no cover
                    data_type.parent.data_types.append(copied_data_type)

                elif isinstance(data_type.parent, DataType):
                    # for data_type
                    # Filter by identity so only this exact DataType instance is replaced.
                    data_type_id = id(data_type)
                    data_type.parent.data_types = [
                        d for d in (*data_type.parent.data_types, copied_data_type) if id(d) != data_type_id
                    ]
                else:  # pragma: no cover
                    continue

                # Register imports for references contained in the copied type.
                for d in copied_data_type.all_data_types:
                    if d.reference is None:
                        continue
                    from_, import_ = full_path = relative(model.module_name, d.full_name)
                    if from_ and import_:
                        alias = scoped_model_resolver.add(full_path, import_)
                        d.alias = (
                            alias.name
                            if d.reference.short_name == import_
                            else f"{alias.name}.{d.reference.short_name}"
                        )
                        imports.append([
                            Import(
                                from_=from_,
                                import_=import_,
                                alias=alias.name,
                                reference_path=d.reference.path,
                            )
                        ])

                original_field = get_most_of_parent(data_type, DataModelFieldBase)
                if original_field:  # pragma: no cover
                    # TODO: Improve detection of reference type
                    # Use list instead of set because Import is not hashable
                    excluded_imports = [IMPORT_OPTIONAL, IMPORT_UNION]
                    field_imports = [i for i in original_field.imports if i not in excluded_imports]
                    imports.append(field_imports)

                data_type.remove_reference()

                # Prune children without a parent; a childless root model is unused.
                assert isinstance(root_type_model, DataModel)
                root_type_model.reference.children = [
                    c for c in root_type_model.reference.children if getattr(c, "parent", None)
                ]

                imports.remove_referenced_imports(root_type_model.path)
                if not root_type_model.reference.children:
                    unused_models.append(root_type_model)
|
|
1688
|
+
|
|
1689
|
+
def __set_default_enum_member(
    self,
    models: list[DataModel],
) -> None:
    """Replace raw enum-typed defaults with the matching enum members.

    Active only when ``self.set_default_enum_member`` is enabled. The visited
    pairs are field/data-type pairs whose data type references an ``Enum``
    model and whose default is truthy. The default is resolved with
    ``find_member``:

    * a list default becomes the list of members found (values without a
      matching member are dropped);
    * any other default becomes the single matching member.

    The default is left unchanged when nothing matches. If the data type has
    an alias, that alias is copied onto each chosen member.
    """
    if not self.set_default_enum_member:
        return
    for _, model_field, data_type in iter_models_field_data_types(models):
        default = model_field.default
        if not default:
            continue
        reference = data_type.reference
        if not reference or not isinstance(reference.source, Enum):  # pragma: no cover
            continue
        enum_model = reference.source
        if isinstance(default, list):
            resolved: list[Member] | (Member | None) = [
                member for member in map(enum_model.find_member, default) if member
            ]
        else:
            resolved = enum_model.find_member(default)
        if not resolved:
            continue
        model_field.default = resolved
        if not data_type.alias:
            continue
        for member in resolved if isinstance(resolved, list) else [resolved]:
            member.alias = data_type.alias
|
|
515
1714
|
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
1715
|
+
def __wrap_root_model_default_values(
    self,
    models: list[DataModel],
) -> None:
    """Wrap defaults of RootModel-typed fields with their type constructors.

    Active only when ``self.use_annotated`` is enabled. Enum and root models
    are skipped, as are defaults that are:

    * ``None``;
    * already a ``WrappedDefault`` or ``Member``;
    * a list.

    Each remaining default whose data type references a pydantic v2
    ``RootModel`` becomes a ``WrappedDefault`` naming that type.
    """
    if not self.use_annotated:
        return
    for model, model_field, data_type in iter_models_field_data_types(models):
        if isinstance(model, (Enum, self.data_model_root_type)):
            continue
        default = model_field.default
        if default is None or isinstance(default, (WrappedDefault, Member, list)):
            continue
        reference = data_type.reference
        if not (reference and isinstance(reference.source, pydantic_model_v2.RootModel)):
            continue
        # Prefer the alias: it reflects renames made to avoid import collisions.
        model_field.default = WrappedDefault(
            value=default,
            type_name=data_type.alias or reference.short_name,
        )
|
|
1738
|
+
|
|
1739
|
+
def __override_required_field(
    self,
    models: list[DataModel],
) -> None:
    """Replace bare "required" placeholder fields with copies of inherited fields.

    Placeholder: a field that has an ``original_name`` but whose data type has
    no ``data_types``, ``reference``, ``type``, ``literals`` or ``dict_key``.

    Replacement: the base-class field found by ``_find_field``, copied, with a
    fresh copy of its data type and ``required=True``. The copy takes the
    placeholder's position in ``model.fields``.

    Placeholders with no matching base-class field are removed. Enum and root
    models are skipped.

    Args:
        models: Models whose field lists are updated in place.
    """
    for model in models:
        if isinstance(model, (Enum, self.data_model_root_type)):
            continue
        # ``index`` refers to the snapshot below. Each removed placeholder
        # shifts later fields one slot left in the live list, so the insert
        # position must be corrected by the number of removals so far.
        removed = 0
        for index, model_field in enumerate(model.fields[:]):
            data_type = model_field.data_type
            if (
                not model_field.original_name  # noqa: PLR0916
                or data_type.data_types
                or data_type.reference
                or data_type.type
                or data_type.literals
                or data_type.dict_key
            ):
                continue

            original_field = _find_field(model_field.original_name, _find_base_classes(model))
            if not original_field:  # pragma: no cover
                model.fields.remove(model_field)
                removed += 1
                continue
            copied_original_field = original_field.copy()
            if original_field.data_type.reference:
                data_type = self.data_type_manager.data_type(
                    reference=original_field.data_type.reference,
                )
            elif original_field.data_type.data_types:
                data_type = original_field.data_type.copy()
                data_type.data_types = _copy_data_types(original_field.data_type.data_types)
                for data_type_ in data_type.data_types:
                    data_type_.parent = data_type
            else:
                data_type = original_field.data_type.copy()
            data_type.parent = copied_original_field
            copied_original_field.data_type = data_type
            copied_original_field.parent = model
            copied_original_field.required = True
            # Insert in front of the placeholder, then drop the placeholder.
            model.fields.insert(index - removed, copied_original_field)
            model.fields.remove(model_field)
|
|
1780
|
+
|
|
1781
|
+
def __sort_models(
    self,
    models: list[DataModel],
    imports: Imports,
    *,
    use_deferred_annotations: bool,
) -> None:
    """Reorder ``models`` in place when ``self.keep_model_order`` is enabled.

    The work is delegated to ``_reorder_models_keep_model_order``. When the
    option is off, nothing happens.
    """
    if self.keep_model_order:
        _reorder_models_keep_model_order(models, imports, use_deferred_annotations=use_deferred_annotations)
|
|
1792
|
+
|
|
1793
|
+
def __change_field_name(
    self,
    models: list[DataModel],
) -> None:
    """Rename pydantic v2 fields whose name clashes with a type they reference.

    Applies only when ``self.data_model_type`` is a pydantic v2 ``BaseModel``
    subclass. Models whose ``base_class`` contains ``"Enum"`` are skipped.

    For each field, a fresh ``ModelResolver`` excludes the short names of all
    types the field references, and the field name is re-resolved against
    it. When the resolved name differs, the field is renamed. The previous
    name becomes the alias, unless the field already has an alias: that
    alias is the external key and is preserved.
    """
    if not issubclass(self.data_model_type, pydantic_model_v2.BaseModel):
        return
    for model in models:
        if "Enum" in model.base_class:
            continue

        for field in model.fields:
            field_name = field.name
            # Fresh resolver per field: only this field's referenced types can clash.
            field_name_resolver = ModelResolver(snake_case_field=self.snake_case_field, remove_suffix_number=True)
            for data_type in field.data_type.all_data_types:
                if data_type.reference:
                    field_name_resolver.exclude_names.add(data_type.reference.short_name)
            new_field_name = field_name_resolver.add(["field"], cast("str", field_name)).name
            if field_name != new_field_name:
                # An existing alias already maps to the external key; overwriting
                # it with the Python attribute name would change the wire name.
                if not field.alias:
                    field.alias = field_name
                field.name = new_field_name
|
|
1813
|
+
|
|
1814
|
+
def __set_one_literal_on_default(self, models: list[DataModel]) -> None:
|
|
1815
|
+
if not self.use_one_literal_as_default:
|
|
1816
|
+
return
|
|
1817
|
+
for model in models:
|
|
1818
|
+
for model_field in model.fields:
|
|
1819
|
+
if not model_field.required or len(model_field.data_type.literals) != 1:
|
|
1820
|
+
continue
|
|
1821
|
+
model_field.default = model_field.data_type.literals[0]
|
|
1822
|
+
model_field.required = False
|
|
1823
|
+
if model_field.nullable is not True: # pragma: no cover
|
|
1824
|
+
model_field.nullable = False
|
|
1825
|
+
|
|
1826
|
+
def __fix_dataclass_field_ordering(self, models: list[DataModel]) -> None:
|
|
1827
|
+
"""Fix field ordering for dataclasses with inheritance after defaults are set."""
|
|
1828
|
+
for model in models:
|
|
1829
|
+
if (inherited := self.__get_dataclass_inherited_info(model)) is None:
|
|
1830
|
+
continue
|
|
1831
|
+
inherited_names, has_default = inherited
|
|
1832
|
+
if not has_default or not any(self.__is_new_required_field(f, inherited_names) for f in model.fields):
|
|
1833
|
+
continue
|
|
1834
|
+
|
|
1835
|
+
if self.target_python_version.has_kw_only_dataclass:
|
|
1836
|
+
for field in model.fields:
|
|
1837
|
+
if self.__is_new_required_field(field, inherited_names):
|
|
1838
|
+
field.extras["kw_only"] = True
|
|
527
1839
|
else:
|
|
528
|
-
|
|
1840
|
+
warn(
|
|
1841
|
+
f"Dataclass '{model.class_name}' has a field ordering conflict due to inheritance. "
|
|
1842
|
+
f"An inherited field has a default value, but new required fields are added. "
|
|
1843
|
+
f"This will cause a TypeError at runtime. Consider using --target-python-version 3.10 "
|
|
1844
|
+
f"or higher to enable automatic field(kw_only=True) fix.",
|
|
1845
|
+
category=UserWarning,
|
|
1846
|
+
stacklevel=2,
|
|
1847
|
+
)
|
|
1848
|
+
model.fields = sorted(model.fields, key=dataclass_model.has_field_assignment)
|
|
529
1849
|
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
1850
|
+
@classmethod
|
|
1851
|
+
def __get_dataclass_inherited_info(cls, model: DataModel) -> tuple[set[str], bool] | None:
|
|
1852
|
+
"""Get inherited field names and whether any has default. Returns None if not applicable."""
|
|
1853
|
+
if not isinstance(model, dataclass_model.DataClass):
|
|
1854
|
+
return None
|
|
1855
|
+
if not model.base_classes or model.dataclass_arguments.get("kw_only"):
|
|
1856
|
+
return None
|
|
1857
|
+
|
|
1858
|
+
inherited_names: set[str] = set()
|
|
1859
|
+
has_default = False
|
|
1860
|
+
for base in model.base_classes:
|
|
1861
|
+
if not base.reference or not isinstance(base.reference.source, DataModel):
|
|
1862
|
+
continue # pragma: no cover
|
|
1863
|
+
for f in base.reference.source.iter_all_fields():
|
|
1864
|
+
if not f.name or f.extras.get("init") is False:
|
|
1865
|
+
continue # pragma: no cover
|
|
1866
|
+
inherited_names.add(f.name)
|
|
1867
|
+
if dataclass_model.has_field_assignment(f):
|
|
1868
|
+
has_default = True
|
|
1869
|
+
|
|
1870
|
+
for f in model.fields:
|
|
1871
|
+
if f.name not in inherited_names or f.extras.get("init") is False:
|
|
1872
|
+
continue
|
|
1873
|
+
if dataclass_model.has_field_assignment(f): # pragma: no branch
|
|
1874
|
+
has_default = True
|
|
1875
|
+
return (inherited_names, has_default) if inherited_names else None
|
|
1876
|
+
|
|
1877
|
+
def __is_new_required_field(self, field: DataModelFieldBase, inherited: set[str]) -> bool: # noqa: PLR6301
|
|
1878
|
+
"""Check if field is a new required init field."""
|
|
1879
|
+
return (
|
|
1880
|
+
field.name not in inherited
|
|
1881
|
+
and field.extras.get("init") is not False
|
|
1882
|
+
and not dataclass_model.has_field_assignment(field)
|
|
1883
|
+
)
|
|
533
1884
|
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
1885
|
+
@classmethod
|
|
1886
|
+
def __update_type_aliases(cls, models: list[DataModel]) -> None:
|
|
1887
|
+
"""Update type aliases to properly handle forward references per PEP 484."""
|
|
1888
|
+
model_index: dict[str, int] = {m.class_name: i for i, m in enumerate(models)}
|
|
1889
|
+
|
|
1890
|
+
for i, model in enumerate(models):
|
|
1891
|
+
if not isinstance(model, TypeAliasBase):
|
|
1892
|
+
continue
|
|
1893
|
+
if isinstance(model, TypeStatement):
|
|
1894
|
+
continue
|
|
538
1895
|
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
1896
|
+
for field in model.fields:
|
|
1897
|
+
for data_type in field.data_type.all_data_types:
|
|
1898
|
+
if not data_type.reference:
|
|
1899
|
+
continue
|
|
1900
|
+
source = data_type.reference.source
|
|
1901
|
+
if not isinstance(source, DataModel):
|
|
1902
|
+
continue # pragma: no cover
|
|
1903
|
+
if isinstance(source, TypeStatement):
|
|
1904
|
+
continue # pragma: no cover
|
|
1905
|
+
if source.module_path != model.module_path:
|
|
542
1906
|
continue
|
|
1907
|
+
name = data_type.reference.short_name
|
|
1908
|
+
source_index = model_index.get(name)
|
|
1909
|
+
if source_index is not None and source_index >= i:
|
|
1910
|
+
data_type.alias = f'"{name}"'
|
|
543
1911
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
1912
|
+
@classmethod
|
|
1913
|
+
def __postprocess_result_modules(cls, results: dict[tuple[str, ...], Result]) -> dict[tuple[str, ...], Result]:
|
|
1914
|
+
def process(input_tuple: tuple[str, ...]) -> tuple[str, ...]:
|
|
1915
|
+
r = []
|
|
1916
|
+
for item in input_tuple:
|
|
1917
|
+
p = item.split(".")
|
|
1918
|
+
if len(p) > 1:
|
|
1919
|
+
r.extend(p[:-1])
|
|
1920
|
+
r.append(p[-1])
|
|
1921
|
+
else:
|
|
1922
|
+
r.append(item)
|
|
554
1923
|
|
|
555
|
-
|
|
1924
|
+
if len(r) >= 2: # noqa: PLR2004
|
|
1925
|
+
r = [*r[:-2], f"{r[-2]}.{r[-1]}"]
|
|
1926
|
+
return tuple(r)
|
|
556
1927
|
|
|
557
|
-
|
|
558
|
-
if from_ and import_ and alias != name:
|
|
559
|
-
data_type.alias = f'{alias}.{name}'
|
|
1928
|
+
results = {process(k): v for k, v in results.items()}
|
|
560
1929
|
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
1930
|
+
init_result = next(v for k, v in results.items() if k[-1] == "__init__.py")
|
|
1931
|
+
folders = {t[:-1] if t[-1].endswith(".py") else t for t in results}
|
|
1932
|
+
for folder in folders:
|
|
1933
|
+
for i in range(len(folder)):
|
|
1934
|
+
subfolder = folder[: i + 1]
|
|
1935
|
+
init_file = (*subfolder, "__init__.py")
|
|
1936
|
+
results.update({init_file: init_result})
|
|
1937
|
+
return results
|
|
564
1938
|
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
1939
|
+
def __change_imported_model_name( # noqa: PLR6301
|
|
1940
|
+
self,
|
|
1941
|
+
models: list[DataModel],
|
|
1942
|
+
imports: Imports,
|
|
1943
|
+
scoped_model_resolver: ModelResolver,
|
|
1944
|
+
) -> None:
|
|
1945
|
+
imported_names = {
|
|
1946
|
+
imports.alias[from_][i] if i in imports.alias[from_] and i != imports.alias[from_][i] else i
|
|
1947
|
+
for from_, import_ in imports.items()
|
|
1948
|
+
for i in import_
|
|
1949
|
+
}
|
|
1950
|
+
for model in models:
|
|
1951
|
+
if model.class_name not in imported_names: # pragma: no cover
|
|
1952
|
+
continue
|
|
1953
|
+
|
|
1954
|
+
model.reference.name = scoped_model_resolver.add( # pragma: no cover
|
|
1955
|
+
path=get_special_path("imported_name", model.path.split("/")),
|
|
1956
|
+
original_name=model.reference.name,
|
|
1957
|
+
unique=True,
|
|
1958
|
+
class_name=True,
|
|
1959
|
+
).name
|
|
1960
|
+
|
|
1961
|
+
def __alias_shadowed_imports( # noqa: PLR6301
|
|
1962
|
+
self,
|
|
1963
|
+
models: list[DataModel],
|
|
1964
|
+
all_model_field_names: set[str],
|
|
1965
|
+
) -> None:
|
|
1966
|
+
for _, model_field, data_type in iter_models_field_data_types(models):
|
|
1967
|
+
if data_type and data_type.type in all_model_field_names and data_type.type == model_field.name:
|
|
1968
|
+
alias = data_type.type + "_aliased"
|
|
1969
|
+
data_type.type = alias
|
|
1970
|
+
if data_type.import_: # pragma: no cover
|
|
1971
|
+
data_type.import_ = Import(
|
|
1972
|
+
from_=data_type.import_.from_,
|
|
1973
|
+
import_=data_type.import_.import_,
|
|
1974
|
+
alias=alias,
|
|
1975
|
+
reference_path=data_type.import_.reference_path,
|
|
576
1976
|
)
|
|
577
|
-
cached_model_reference = model_cache.get(model_key)
|
|
578
|
-
if cached_model_reference:
|
|
579
|
-
if isinstance(model, Enum):
|
|
580
|
-
for child in model.reference.children[:]:
|
|
581
|
-
# child is resolved data_type by reference
|
|
582
|
-
data_model = get_most_of_parent(child)
|
|
583
|
-
# TODO: replace reference in all modules
|
|
584
|
-
if data_model in models: # pragma: no cover
|
|
585
|
-
child.replace_reference(cached_model_reference)
|
|
586
|
-
duplicates.append(model)
|
|
587
|
-
else:
|
|
588
|
-
index = models.index(model)
|
|
589
|
-
inherited_model = model.__class__(
|
|
590
|
-
fields=[],
|
|
591
|
-
base_classes=[cached_model_reference],
|
|
592
|
-
description=model.description,
|
|
593
|
-
reference=Reference(
|
|
594
|
-
name=model.name,
|
|
595
|
-
path=model.reference.path + '/reuse',
|
|
596
|
-
),
|
|
597
|
-
)
|
|
598
|
-
if (
|
|
599
|
-
cached_model_reference.path
|
|
600
|
-
in require_update_action_models
|
|
601
|
-
):
|
|
602
|
-
require_update_action_models.append(
|
|
603
|
-
inherited_model.path
|
|
604
|
-
)
|
|
605
|
-
models.insert(index, inherited_model)
|
|
606
|
-
models.remove(model)
|
|
607
1977
|
|
|
608
|
-
|
|
609
|
-
|
|
1978
|
+
@classmethod
|
|
1979
|
+
def _collect_exports_for_init(
|
|
1980
|
+
cls,
|
|
1981
|
+
module: tuple[str, ...],
|
|
1982
|
+
processed_models: Sequence[
|
|
1983
|
+
tuple[tuple[str, ...], tuple[str, ...], Sequence[DataModel], bool, Imports, ModelResolver]
|
|
1984
|
+
],
|
|
1985
|
+
scope: AllExportsScope,
|
|
1986
|
+
) -> list[tuple[str, tuple[str, ...], str]]:
|
|
1987
|
+
"""Collect exports for __init__.py based on scope."""
|
|
1988
|
+
exports: list[tuple[str, tuple[str, ...], str]] = []
|
|
1989
|
+
base = module[:-1] if module[-1] == "__init__.py" else module
|
|
1990
|
+
base_len = len(base)
|
|
1991
|
+
|
|
1992
|
+
for proc_module, _, proc_models, _, _, _ in processed_models:
|
|
1993
|
+
if not proc_models or proc_module == module:
|
|
1994
|
+
continue
|
|
1995
|
+
last = proc_module[-1]
|
|
1996
|
+
prefix = proc_module[:-1] if last == "__init__.py" else (*proc_module[:-1], last[:-3])
|
|
1997
|
+
if prefix[:base_len] != base or (depth := len(prefix) - base_len) < 1:
|
|
1998
|
+
continue
|
|
1999
|
+
if scope == AllExportsScope.Children and depth != 1:
|
|
2000
|
+
continue
|
|
2001
|
+
rel = prefix[base_len:]
|
|
2002
|
+
exports.extend(
|
|
2003
|
+
(ref.short_name, rel, ".".join(rel))
|
|
2004
|
+
for m in proc_models
|
|
2005
|
+
if (ref := m.reference) and not ref.short_name.startswith("_")
|
|
2006
|
+
)
|
|
2007
|
+
return exports
|
|
2008
|
+
|
|
2009
|
+
@classmethod
|
|
2010
|
+
def _resolve_export_collisions(
|
|
2011
|
+
cls,
|
|
2012
|
+
exports: list[tuple[str, tuple[str, ...], str]],
|
|
2013
|
+
strategy: AllExportsCollisionStrategy | None,
|
|
2014
|
+
reserved: set[str] | None = None,
|
|
2015
|
+
) -> dict[str, list[tuple[str, tuple[str, ...], str]]]:
|
|
2016
|
+
"""Resolve name collisions in exports based on strategy."""
|
|
2017
|
+
reserved = reserved or set()
|
|
2018
|
+
by_name: dict[str, list[tuple[str, tuple[str, ...], str]]] = {}
|
|
2019
|
+
for item in exports:
|
|
2020
|
+
by_name.setdefault(item[0], []).append(item)
|
|
2021
|
+
|
|
2022
|
+
if not (colliding := {n for n, items in by_name.items() if len(items) > 1 or n in reserved}):
|
|
2023
|
+
return dict(by_name)
|
|
2024
|
+
if (effective := strategy or AllExportsCollisionStrategy.Error) == AllExportsCollisionStrategy.Error:
|
|
2025
|
+
cls._raise_collision_error(by_name, colliding)
|
|
2026
|
+
|
|
2027
|
+
used: set[str] = {n for n in by_name if n not in colliding} | reserved
|
|
2028
|
+
result = {n: items for n, items in by_name.items() if n not in colliding}
|
|
2029
|
+
|
|
2030
|
+
for name in sorted(colliding):
|
|
2031
|
+
for item in sorted(by_name[name], key=lambda x: len(x[1])):
|
|
2032
|
+
new_name = cls._make_prefixed_name(
|
|
2033
|
+
item[0], item[1], used, minimal=effective == AllExportsCollisionStrategy.MinimalPrefix
|
|
2034
|
+
)
|
|
2035
|
+
if new_name in reserved:
|
|
2036
|
+
msg = (
|
|
2037
|
+
f"Cannot resolve collision: '{new_name}' conflicts with __init__.py model. "
|
|
2038
|
+
"Please rename one of the models."
|
|
2039
|
+
)
|
|
2040
|
+
raise Error(msg)
|
|
2041
|
+
result[new_name] = [item]
|
|
2042
|
+
used.add(new_name)
|
|
2043
|
+
return result
|
|
2044
|
+
|
|
2045
|
+
@classmethod
|
|
2046
|
+
def _raise_collision_error(
|
|
2047
|
+
cls,
|
|
2048
|
+
by_name: dict[str, list[tuple[str, tuple[str, ...], str]]],
|
|
2049
|
+
colliding: set[str],
|
|
2050
|
+
) -> None:
|
|
2051
|
+
"""Raise an error with collision details."""
|
|
2052
|
+
details = []
|
|
2053
|
+
for n in colliding:
|
|
2054
|
+
if len(items := by_name[n]) > 1:
|
|
2055
|
+
details.append(f" '{n}' is defined in: {', '.join(f'.{s}' for _, _, s in items)}")
|
|
2056
|
+
else:
|
|
2057
|
+
details.append(f" '{n}' conflicts with a model in __init__.py")
|
|
2058
|
+
raise Error(
|
|
2059
|
+
"Name collision detected with --all-exports-scope:\n"
|
|
2060
|
+
+ "\n".join(details)
|
|
2061
|
+
+ "\n\nUse --all-exports-collision-strategy to specify how to handle collisions."
|
|
2062
|
+
)
|
|
2063
|
+
|
|
2064
|
+
@staticmethod
|
|
2065
|
+
def _make_prefixed_name(name: str, path: tuple[str, ...], used: set[str], *, minimal: bool) -> str:
|
|
2066
|
+
"""Generate a prefixed name, using minimal or full prefix."""
|
|
2067
|
+
if minimal:
|
|
2068
|
+
for depth in range(1, len(path) + 1):
|
|
2069
|
+
if (candidate := "".join(p.title().replace("_", "") for p in path[-depth:]) + name) not in used:
|
|
2070
|
+
return candidate
|
|
2071
|
+
return "".join(p.title().replace("_", "") for p in path) + name
|
|
2072
|
+
|
|
2073
|
+
@classmethod
|
|
2074
|
+
def _build_all_exports_code(
|
|
2075
|
+
cls,
|
|
2076
|
+
resolved: dict[str, list[tuple[str, tuple[str, ...], str]]],
|
|
2077
|
+
) -> Imports:
|
|
2078
|
+
"""Build import statements from resolved exports."""
|
|
2079
|
+
export_imports = Imports()
|
|
2080
|
+
for export_name, items in resolved.items():
|
|
2081
|
+
for orig, _, short in items:
|
|
2082
|
+
export_imports.append(
|
|
2083
|
+
Import(from_=f".{short}", import_=orig, alias=export_name if export_name != orig else None)
|
|
2084
|
+
)
|
|
2085
|
+
return export_imports
|
|
2086
|
+
|
|
2087
|
+
@classmethod
|
|
2088
|
+
def _collect_used_names_from_models(cls, models: list[DataModel]) -> set[str]:
|
|
2089
|
+
"""Collect identifiers referenced by models before rendering."""
|
|
2090
|
+
names: set[str] = set()
|
|
2091
|
+
|
|
2092
|
+
def add(name: str | None) -> None:
|
|
2093
|
+
if not name:
|
|
2094
|
+
return
|
|
2095
|
+
# first segment is sufficient to match import target or alias
|
|
2096
|
+
names.add(name.split(".")[0])
|
|
2097
|
+
|
|
2098
|
+
def walk_data_type(data_type: DataType) -> None:
|
|
2099
|
+
add(data_type.alias or data_type.type)
|
|
2100
|
+
if data_type.reference:
|
|
2101
|
+
add(data_type.reference.short_name)
|
|
2102
|
+
for child in data_type.data_types:
|
|
2103
|
+
walk_data_type(child)
|
|
2104
|
+
if data_type.dict_key:
|
|
2105
|
+
walk_data_type(data_type.dict_key)
|
|
2106
|
+
|
|
2107
|
+
for model in models:
|
|
2108
|
+
add(model.class_name)
|
|
2109
|
+
add(model.duplicate_class_name)
|
|
2110
|
+
for base in model.base_classes:
|
|
2111
|
+
add(base.type_hint)
|
|
2112
|
+
for import_ in model.imports:
|
|
2113
|
+
add(import_.alias or import_.import_.split(".")[-1])
|
|
2114
|
+
for field in model.fields:
|
|
2115
|
+
if field.extras.get("is_classvar"):
|
|
2116
|
+
continue
|
|
2117
|
+
add(field.name)
|
|
2118
|
+
add(field.alias)
|
|
2119
|
+
walk_data_type(field.data_type)
|
|
2120
|
+
return names
|
|
2121
|
+
|
|
2122
|
+
def __generate_forwarder_content( # noqa: PLR6301
|
|
2123
|
+
self,
|
|
2124
|
+
original_module: tuple[str, ...],
|
|
2125
|
+
internal_module: tuple[str, ...],
|
|
2126
|
+
class_mappings: list[tuple[str, str]],
|
|
2127
|
+
*,
|
|
2128
|
+
is_init: bool = False,
|
|
2129
|
+
) -> str:
|
|
2130
|
+
"""Generate forwarder module content that re-exports classes from _internal.
|
|
2131
|
+
|
|
2132
|
+
Args:
|
|
2133
|
+
original_module: The original module tuple (e.g., ("issuing",) or ())
|
|
2134
|
+
internal_module: The _internal module tuple (e.g., ("_internal",))
|
|
2135
|
+
class_mappings: List of (original_name, new_name) tuples, sorted by original_name
|
|
2136
|
+
is_init: True if this is a package __init__.py, False for regular .py files
|
|
2137
|
+
|
|
2138
|
+
Returns:
|
|
2139
|
+
The forwarder module content as a string
|
|
2140
|
+
"""
|
|
2141
|
+
original_str = ".".join(original_module)
|
|
2142
|
+
internal_str = ".".join(internal_module)
|
|
2143
|
+
from_dots, module_name = relative(original_str, internal_str, reference_is_module=True, current_is_init=is_init)
|
|
2144
|
+
relative_import = f"{from_dots}{module_name}"
|
|
2145
|
+
|
|
2146
|
+
imports = Imports()
|
|
2147
|
+
for original_name, new_name in class_mappings:
|
|
2148
|
+
if original_name == new_name:
|
|
2149
|
+
imports.append(Import(from_=relative_import, import_=new_name))
|
|
2150
|
+
else:
|
|
2151
|
+
imports.append(Import(from_=relative_import, import_=new_name, alias=original_name))
|
|
2152
|
+
|
|
2153
|
+
return f"{imports.dump()}\n\n{imports.dump_all()}\n"
|
|
2154
|
+
|
|
2155
|
+
def __compute_internal_module_path( # noqa: PLR6301
|
|
2156
|
+
self,
|
|
2157
|
+
scc_modules: set[tuple[str, ...]],
|
|
2158
|
+
existing_modules: set[tuple[str, ...]],
|
|
2159
|
+
*,
|
|
2160
|
+
base_name: str = "_internal",
|
|
2161
|
+
) -> tuple[str, ...]:
|
|
2162
|
+
"""Compute the internal module path for an SCC."""
|
|
2163
|
+
directories = [get_module_directory(m) for m in sorted(scc_modules)]
|
|
2164
|
+
|
|
2165
|
+
if not directories or any(not d for d in directories):
|
|
2166
|
+
prefix: tuple[str, ...] = ()
|
|
2167
|
+
else:
|
|
2168
|
+
path_strings = ["/".join(d) for d in directories]
|
|
2169
|
+
common = os.path.commonpath(path_strings)
|
|
2170
|
+
prefix = tuple(common.split("/")) if common else ()
|
|
2171
|
+
|
|
2172
|
+
base_module = (base_name,) if not prefix else (*prefix, base_name)
|
|
2173
|
+
|
|
2174
|
+
if base_module in existing_modules:
|
|
2175
|
+
counter = 1
|
|
2176
|
+
while True:
|
|
2177
|
+
candidate = (*prefix, f"{base_name}_{counter}") if prefix else (f"{base_name}_{counter}",)
|
|
2178
|
+
if candidate not in existing_modules:
|
|
2179
|
+
return candidate
|
|
2180
|
+
counter += 1
|
|
610
2181
|
|
|
611
|
-
|
|
612
|
-
models.remove(duplicate)
|
|
2182
|
+
return base_module
|
|
613
2183
|
|
|
614
|
-
|
|
2184
|
+
def __collect_scc_models( # noqa: PLR6301
|
|
2185
|
+
self,
|
|
2186
|
+
scc: set[tuple[str, ...]],
|
|
2187
|
+
result_modules: dict[tuple[str, ...], list[DataModel]],
|
|
2188
|
+
) -> tuple[list[DataModel], dict[int, tuple[str, ...]]]:
|
|
2189
|
+
"""Collect all models from SCC modules.
|
|
2190
|
+
|
|
2191
|
+
Returns:
|
|
2192
|
+
- List of all models in the SCC
|
|
2193
|
+
- Mapping from model id to its original module
|
|
2194
|
+
"""
|
|
2195
|
+
all_models: list[DataModel] = []
|
|
2196
|
+
model_to_module: dict[int, tuple[str, ...]] = {}
|
|
2197
|
+
for scc_module in sorted(scc):
|
|
2198
|
+
for model in result_modules[scc_module]:
|
|
2199
|
+
all_models.append(model)
|
|
2200
|
+
model_to_module[id(model)] = scc_module
|
|
2201
|
+
return all_models, model_to_module
|
|
2202
|
+
|
|
2203
|
+
def __rename_and_relocate_scc_models( # noqa: PLR6301
|
|
2204
|
+
self,
|
|
2205
|
+
all_scc_models: list[DataModel],
|
|
2206
|
+
model_to_original_module: dict[int, tuple[str, ...]],
|
|
2207
|
+
internal_module: tuple[str, ...],
|
|
2208
|
+
internal_path: Path,
|
|
2209
|
+
) -> tuple[defaultdict[tuple[str, ...], list[tuple[str, str]]], dict[str, str]]:
|
|
2210
|
+
"""Rename duplicate classes and relocate models to internal module.
|
|
2211
|
+
|
|
2212
|
+
Returns:
|
|
2213
|
+
Tuple of:
|
|
2214
|
+
- Mapping from original module to list of (original_name, new_name) tuples.
|
|
2215
|
+
- Mapping from old reference paths to new reference paths.
|
|
2216
|
+
"""
|
|
2217
|
+
class_name_counts = Counter(model.class_name for model in all_scc_models)
|
|
2218
|
+
class_name_seen: dict[str, int] = {}
|
|
2219
|
+
internal_module_str = ".".join(internal_module)
|
|
2220
|
+
module_class_mappings: defaultdict[tuple[str, ...], list[tuple[str, str]]] = defaultdict(list)
|
|
2221
|
+
path_mapping: dict[str, str] = {}
|
|
2222
|
+
|
|
2223
|
+
for model in all_scc_models:
|
|
2224
|
+
original_class_name = model.class_name
|
|
2225
|
+
original_module = model_to_original_module[id(model)]
|
|
2226
|
+
old_path = model.path # Save old path before updating
|
|
2227
|
+
|
|
2228
|
+
if class_name_counts[original_class_name] > 1:
|
|
2229
|
+
seen_count = class_name_seen.get(original_class_name, 0)
|
|
2230
|
+
new_class_name = f"{original_class_name}_{seen_count}" if seen_count > 0 else original_class_name
|
|
2231
|
+
class_name_seen[original_class_name] = seen_count + 1
|
|
2232
|
+
else:
|
|
2233
|
+
new_class_name = original_class_name
|
|
2234
|
+
|
|
2235
|
+
model.reference.name = new_class_name
|
|
2236
|
+
new_path = f"{internal_module_str}.{new_class_name}"
|
|
2237
|
+
model.set_reference_path(new_path)
|
|
2238
|
+
model.file_path = internal_path
|
|
2239
|
+
|
|
2240
|
+
module_class_mappings[original_module].append((original_class_name, new_class_name))
|
|
2241
|
+
path_mapping[old_path] = new_path
|
|
2242
|
+
|
|
2243
|
+
return module_class_mappings, path_mapping
|
|
2244
|
+
|
|
2245
|
+
def __build_module_dependency_graph( # noqa: PLR6301
|
|
2246
|
+
self,
|
|
2247
|
+
module_models_list: list[tuple[tuple[str, ...], list[DataModel]]],
|
|
2248
|
+
) -> dict[tuple[str, ...], set[tuple[str, ...]]]:
|
|
2249
|
+
"""Build a directed graph of module dependencies."""
|
|
2250
|
+
path_to_module: dict[str, tuple[str, ...]] = {}
|
|
2251
|
+
for module, models in module_models_list:
|
|
2252
|
+
for model in models:
|
|
2253
|
+
path_to_module[model.path] = module
|
|
2254
|
+
|
|
2255
|
+
graph: dict[tuple[str, ...], set[tuple[str, ...]]] = {}
|
|
2256
|
+
|
|
2257
|
+
def add_cross_module_edge(ref_path: str, source_module: tuple[str, ...]) -> None:
|
|
2258
|
+
"""Add edge if ref_path points to a different module."""
|
|
2259
|
+
if ref_path in path_to_module:
|
|
2260
|
+
target_module = path_to_module[ref_path]
|
|
2261
|
+
if target_module != source_module:
|
|
2262
|
+
graph[source_module].add(target_module)
|
|
2263
|
+
|
|
2264
|
+
for module, models in module_models_list:
|
|
2265
|
+
graph[module] = set()
|
|
2266
|
+
|
|
2267
|
+
for model in models:
|
|
2268
|
+
for data_type in model.all_data_types:
|
|
2269
|
+
if data_type.reference and data_type.reference.source:
|
|
2270
|
+
add_cross_module_edge(data_type.reference.path, module)
|
|
2271
|
+
|
|
2272
|
+
for base_class in model.base_classes:
|
|
2273
|
+
if base_class.reference and base_class.reference.source:
|
|
2274
|
+
add_cross_module_edge(base_class.reference.path, module)
|
|
2275
|
+
|
|
2276
|
+
return graph
|
|
2277
|
+
|
|
2278
|
+
def __resolve_circular_imports( # noqa: PLR0914
|
|
2279
|
+
self,
|
|
2280
|
+
module_models_list: list[tuple[tuple[str, ...], list[DataModel]]],
|
|
2281
|
+
) -> tuple[
|
|
2282
|
+
list[tuple[tuple[str, ...], list[DataModel]]],
|
|
2283
|
+
set[tuple[str, ...]],
|
|
2284
|
+
dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]],
|
|
2285
|
+
dict[str, str],
|
|
2286
|
+
]:
|
|
2287
|
+
"""Resolve circular imports by merging all SCCs into _internal.py modules.
|
|
2288
|
+
|
|
2289
|
+
Uses Tarjan's algorithm to find strongly connected components (SCCs) in the
|
|
2290
|
+
module dependency graph. All modules in each SCC are merged into a single
|
|
2291
|
+
_internal.py module to break import cycles. Original modules become thin
|
|
2292
|
+
forwarders that re-export their classes from _internal.
|
|
2293
|
+
|
|
2294
|
+
Returns:
|
|
2295
|
+
- Updated module_models_list with models moved to _internal modules
|
|
2296
|
+
- Set of _internal modules created
|
|
2297
|
+
- Forwarder map: original_module -> (internal_module, [(original_name, new_name)])
|
|
2298
|
+
- Path mapping: old_reference_path -> new_reference_path
|
|
2299
|
+
"""
|
|
2300
|
+
graph = self.__build_module_dependency_graph(module_models_list)
|
|
2301
|
+
|
|
2302
|
+
circular_sccs = find_circular_sccs(graph)
|
|
2303
|
+
|
|
2304
|
+
forwarder_map: dict[tuple[str, ...], tuple[tuple[str, ...], list[tuple[str, str]]]] = {}
|
|
2305
|
+
all_path_mappings: dict[str, str] = {}
|
|
2306
|
+
|
|
2307
|
+
if not circular_sccs:
|
|
2308
|
+
return module_models_list, set(), forwarder_map, all_path_mappings
|
|
2309
|
+
|
|
2310
|
+
# All circular SCCs are problematic and should be merged into _internal.py
|
|
2311
|
+
# to break the import cycles.
|
|
2312
|
+
problematic_sccs = circular_sccs
|
|
2313
|
+
|
|
2314
|
+
existing_modules = {module for module, _ in module_models_list}
|
|
2315
|
+
internal_modules_created: set[tuple[str, ...]] = set()
|
|
2316
|
+
|
|
2317
|
+
result_modules: dict[tuple[str, ...], list[DataModel]] = {
|
|
2318
|
+
module: list(models) for module, models in module_models_list
|
|
2319
|
+
}
|
|
2320
|
+
|
|
2321
|
+
for scc in problematic_sccs:
|
|
2322
|
+
internal_module = self.__compute_internal_module_path(scc, existing_modules | internal_modules_created)
|
|
2323
|
+
internal_modules_created.add(internal_module)
|
|
2324
|
+
internal_path = Path("/".join(internal_module))
|
|
2325
|
+
|
|
2326
|
+
all_scc_models, model_to_original_module = self.__collect_scc_models(scc, result_modules)
|
|
2327
|
+
module_class_mappings, path_mapping = self.__rename_and_relocate_scc_models(
|
|
2328
|
+
all_scc_models, model_to_original_module, internal_module, internal_path
|
|
2329
|
+
)
|
|
2330
|
+
all_path_mappings.update(path_mapping)
|
|
2331
|
+
|
|
2332
|
+
for scc_module in scc:
|
|
2333
|
+
if scc_module in result_modules: # pragma: no branch
|
|
2334
|
+
result_modules[scc_module] = []
|
|
2335
|
+
if scc_module in module_class_mappings: # pragma: no branch
|
|
2336
|
+
sorted_mappings = sorted(module_class_mappings[scc_module], key=operator.itemgetter(0))
|
|
2337
|
+
forwarder_map[scc_module] = (internal_module, sorted_mappings)
|
|
2338
|
+
result_modules[internal_module] = all_scc_models
|
|
2339
|
+
|
|
2340
|
+
new_module_models: list[tuple[tuple[str, ...], list[DataModel]]] = [
|
|
2341
|
+
(internal_module, result_modules[internal_module])
|
|
2342
|
+
for internal_module in sorted(internal_modules_created)
|
|
2343
|
+
if internal_module in result_modules # pragma: no branch
|
|
2344
|
+
]
|
|
2345
|
+
|
|
2346
|
+
for module, _ in module_models_list:
|
|
2347
|
+
if module not in internal_modules_created: # pragma: no branch
|
|
2348
|
+
new_module_models.append((module, result_modules.get(module, [])))
|
|
2349
|
+
|
|
2350
|
+
return new_module_models, internal_modules_created, forwarder_map, all_path_mappings
|
|
2351
|
+
|
|
2352
|
+
def __get_resolve_reference_action_parts(
|
|
2353
|
+
self,
|
|
2354
|
+
models: list[DataModel],
|
|
2355
|
+
require_update_action_models: list[str],
|
|
2356
|
+
*,
|
|
2357
|
+
use_deferred_annotations: bool,
|
|
2358
|
+
) -> list[str]:
|
|
2359
|
+
"""Return the trailing rebuild/update calls for the given module's models."""
|
|
2360
|
+
if self.dump_resolve_reference_action is None:
|
|
2361
|
+
return []
|
|
2362
|
+
|
|
2363
|
+
require_update_action_model_paths = set(require_update_action_models)
|
|
2364
|
+
required_paths_in_module = {m.path for m in models if m.path in require_update_action_model_paths}
|
|
2365
|
+
|
|
2366
|
+
if (
|
|
2367
|
+
use_deferred_annotations
|
|
2368
|
+
and required_paths_in_module
|
|
2369
|
+
and self.dump_resolve_reference_action is pydantic_model_v2.dump_resolve_reference_action
|
|
2370
|
+
):
|
|
2371
|
+
module_positions = {m.reference.short_name: i for i, m in enumerate(models) if m.reference}
|
|
2372
|
+
module_model_names = set(module_positions)
|
|
2373
|
+
|
|
2374
|
+
forward_needed: set[str] = set()
|
|
2375
|
+
for model in models:
|
|
2376
|
+
if model.path not in required_paths_in_module or not model.reference:
|
|
2377
|
+
continue
|
|
2378
|
+
name = model.reference.short_name
|
|
2379
|
+
pos = module_positions[name]
|
|
2380
|
+
refs = {
|
|
2381
|
+
t.reference.short_name
|
|
2382
|
+
for f in model.fields
|
|
2383
|
+
for t in f.data_type.all_data_types
|
|
2384
|
+
if t.reference and t.reference.short_name in module_model_names
|
|
2385
|
+
}
|
|
2386
|
+
if name in refs or any(module_positions.get(r, -1) > pos for r in refs):
|
|
2387
|
+
forward_needed.add(model.path)
|
|
2388
|
+
|
|
2389
|
+
# Propagate requirement through inheritance.
|
|
2390
|
+
changed = True
|
|
2391
|
+
required_filtered = set(forward_needed)
|
|
2392
|
+
while changed:
|
|
2393
|
+
changed = False
|
|
615
2394
|
for model in models:
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
2395
|
+
if not model.reference or model.path in required_filtered:
|
|
2396
|
+
continue
|
|
2397
|
+
base_paths = {b.reference.path for b in model.base_classes if b.reference}
|
|
2398
|
+
if base_paths & required_filtered:
|
|
2399
|
+
required_filtered.add(model.path)
|
|
2400
|
+
changed = True
|
|
2401
|
+
|
|
2402
|
+
required_paths_in_module = required_filtered
|
|
2403
|
+
|
|
2404
|
+
return [
|
|
2405
|
+
"\n",
|
|
2406
|
+
self.dump_resolve_reference_action(
|
|
2407
|
+
m.reference.short_name for m in models if m.reference and m.path in required_paths_in_module
|
|
2408
|
+
),
|
|
2409
|
+
]
|
|
2410
|
+
|
|
2411
|
+
def parse( # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
|
|
2412
|
+
self,
|
|
2413
|
+
with_import: bool | None = True, # noqa: FBT001, FBT002
|
|
2414
|
+
format_: bool | None = True, # noqa: FBT001, FBT002
|
|
2415
|
+
settings_path: Path | None = None,
|
|
2416
|
+
disable_future_imports: bool = False, # noqa: FBT001, FBT002
|
|
2417
|
+
all_exports_scope: AllExportsScope | None = None,
|
|
2418
|
+
all_exports_collision_strategy: AllExportsCollisionStrategy | None = None,
|
|
2419
|
+
module_split_mode: ModuleSplitMode | None = None,
|
|
2420
|
+
) -> str | dict[tuple[str, ...], Result]:
|
|
2421
|
+
"""Parse schema and generate code, returning single file or module dict."""
|
|
2422
|
+
self.parse_raw()
|
|
2423
|
+
|
|
2424
|
+
use_deferred_annotations = bool(
|
|
2425
|
+
self.target_python_version.has_native_deferred_annotations or (with_import and not disable_future_imports)
|
|
2426
|
+
)
|
|
2427
|
+
|
|
2428
|
+
if (
|
|
2429
|
+
with_import
|
|
2430
|
+
and not disable_future_imports
|
|
2431
|
+
and not self.target_python_version.has_native_deferred_annotations
|
|
2432
|
+
):
|
|
2433
|
+
self.imports.append(IMPORT_ANNOTATIONS)
|
|
2434
|
+
|
|
2435
|
+
if format_:
|
|
2436
|
+
code_formatter: CodeFormatter | None = CodeFormatter(
|
|
2437
|
+
self.target_python_version,
|
|
2438
|
+
settings_path,
|
|
2439
|
+
self.wrap_string_literal,
|
|
2440
|
+
skip_string_normalization=not self.use_double_quotes,
|
|
2441
|
+
known_third_party=self.known_third_party,
|
|
2442
|
+
custom_formatters=self.custom_formatter,
|
|
2443
|
+
custom_formatters_kwargs=self.custom_formatters_kwargs,
|
|
2444
|
+
encoding=self.encoding,
|
|
2445
|
+
formatters=self.formatters,
|
|
2446
|
+
)
|
|
2447
|
+
else:
|
|
2448
|
+
code_formatter = None
|
|
2449
|
+
|
|
2450
|
+
_, sorted_data_models, require_update_action_models = sort_data_models(self.results)
|
|
2451
|
+
|
|
2452
|
+
results: dict[tuple[str, ...], Result] = {}
|
|
2453
|
+
|
|
2454
|
+
def module_key(data_model: DataModel) -> tuple[str, ...]:
|
|
2455
|
+
if module_split_mode == ModuleSplitMode.Single:
|
|
2456
|
+
file_name = camel_to_snake(data_model.class_name)
|
|
2457
|
+
return (*data_model.module_path, file_name)
|
|
2458
|
+
return tuple(data_model.module_path)
|
|
2459
|
+
|
|
2460
|
+
def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
|
|
2461
|
+
key = module_key(data_model)
|
|
2462
|
+
return (len(key), key)
|
|
2463
|
+
|
|
2464
|
+
# process in reverse order to correctly establish module levels
|
|
2465
|
+
grouped_models = groupby(
|
|
2466
|
+
sorted(sorted_data_models.values(), key=sort_key, reverse=True),
|
|
2467
|
+
key=module_key,
|
|
2468
|
+
)
|
|
2469
|
+
|
|
2470
|
+
module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
|
|
2471
|
+
unused_models: list[DataModel] = []
|
|
2472
|
+
model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {}
|
|
2473
|
+
module_to_import: dict[tuple[str, ...], Imports] = {}
|
|
2474
|
+
model_path_to_module_name: dict[str, str] = {}
|
|
2475
|
+
|
|
2476
|
+
previous_module: tuple[str, ...] = ()
|
|
2477
|
+
for module, models in ((k, [*v]) for k, v in grouped_models):
|
|
2478
|
+
for model in models:
|
|
2479
|
+
model_to_module_models[model] = module, models
|
|
2480
|
+
if module_split_mode == ModuleSplitMode.Single:
|
|
2481
|
+
model_path_to_module_name[model.path] = ".".join(module)
|
|
2482
|
+
self.__delete_duplicate_models(models)
|
|
2483
|
+
self.__replace_duplicate_name_in_module(models)
|
|
2484
|
+
if len(previous_module) - len(module) > 1:
|
|
2485
|
+
module_models.extend(
|
|
2486
|
+
(
|
|
2487
|
+
previous_module[:parts],
|
|
2488
|
+
[],
|
|
2489
|
+
)
|
|
2490
|
+
for parts in range(len(previous_module) - 1, len(module), -1)
|
|
2491
|
+
)
|
|
2492
|
+
module_models.append((
|
|
2493
|
+
module,
|
|
2494
|
+
models,
|
|
2495
|
+
))
|
|
2496
|
+
previous_module = module
|
|
2497
|
+
|
|
2498
|
+
shared_module_entry = self.__reuse_model_tree_scope(module_models, require_update_action_models)
|
|
2499
|
+
if shared_module_entry:
|
|
2500
|
+
module_models.insert(0, shared_module_entry)
|
|
2501
|
+
|
|
2502
|
+
# Resolve circular imports by moving models to _internal.py modules
|
|
2503
|
+
module_models, internal_modules, forwarder_map, path_mapping = self.__resolve_circular_imports(module_models)
|
|
2504
|
+
|
|
2505
|
+
# Update require_update_action_models with new paths for relocated models
|
|
2506
|
+
if path_mapping:
|
|
2507
|
+
require_update_action_models[:] = [path_mapping.get(path, path) for path in require_update_action_models]
|
|
2508
|
+
|
|
2509
|
+
class Processed(NamedTuple):
|
|
2510
|
+
module: tuple[str, ...]
|
|
2511
|
+
module_key: tuple[str, ...] # Original module tuple (without file extension)
|
|
2512
|
+
models: list[DataModel]
|
|
2513
|
+
init: bool
|
|
2514
|
+
imports: Imports
|
|
2515
|
+
scoped_model_resolver: ModelResolver
|
|
2516
|
+
|
|
2517
|
+
processed_models: list[Processed] = []
|
|
2518
|
+
|
|
2519
|
+
for module_, models in module_models:
|
|
2520
|
+
imports = module_to_import[module_] = Imports(self.use_exact_imports)
|
|
2521
|
+
init = False
|
|
2522
|
+
if module_:
|
|
2523
|
+
if len(module_) == 1:
|
|
2524
|
+
parent = ("__init__.py",)
|
|
2525
|
+
if parent not in results:
|
|
2526
|
+
results[parent] = Result(body="")
|
|
2527
|
+
else:
|
|
2528
|
+
for i in range(1, len(module_)):
|
|
2529
|
+
parent = (*module_[:i], "__init__.py")
|
|
2530
|
+
if parent not in results:
|
|
2531
|
+
results[parent] = Result(body="")
|
|
2532
|
+
if (*module_, "__init__.py") in results:
|
|
2533
|
+
module = (*module_, "__init__.py")
|
|
2534
|
+
init = True
|
|
2535
|
+
else:
|
|
2536
|
+
module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
|
|
2537
|
+
else:
|
|
2538
|
+
module = ("__init__.py",)
|
|
2539
|
+
|
|
2540
|
+
all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
|
|
2541
|
+
scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
|
|
2542
|
+
|
|
2543
|
+
self.__alias_shadowed_imports(models, all_module_fields)
|
|
2544
|
+
self.__override_required_field(models)
|
|
2545
|
+
self.__replace_unique_list_to_set(models)
|
|
2546
|
+
self.__change_from_import(
|
|
2547
|
+
models,
|
|
2548
|
+
imports,
|
|
2549
|
+
scoped_model_resolver,
|
|
2550
|
+
init=init,
|
|
2551
|
+
internal_modules=internal_modules,
|
|
2552
|
+
model_path_to_module_name=model_path_to_module_name,
|
|
2553
|
+
)
|
|
2554
|
+
self.__extract_inherited_enum(models)
|
|
2555
|
+
self.__set_reference_default_value_to_field(models)
|
|
2556
|
+
self.__reuse_model(models, require_update_action_models)
|
|
2557
|
+
self.__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
|
|
2558
|
+
self.__set_default_enum_member(models)
|
|
2559
|
+
self.__wrap_root_model_default_values(models)
|
|
2560
|
+
self.__sort_models(
|
|
2561
|
+
models,
|
|
2562
|
+
imports,
|
|
2563
|
+
use_deferred_annotations=bool(
|
|
2564
|
+
self.target_python_version.has_native_deferred_annotations
|
|
2565
|
+
or (with_import and not disable_future_imports)
|
|
2566
|
+
),
|
|
2567
|
+
)
|
|
2568
|
+
self.__change_field_name(models)
|
|
2569
|
+
self.__apply_discriminator_type(models, imports)
|
|
2570
|
+
self.__set_one_literal_on_default(models)
|
|
2571
|
+
self.__fix_dataclass_field_ordering(models)
|
|
2572
|
+
|
|
2573
|
+
processed_models.append(Processed(module, module_, models, init, imports, scoped_model_resolver))
|
|
2574
|
+
|
|
2575
|
+
for processed_model in processed_models:
|
|
2576
|
+
for model in processed_model.models:
|
|
2577
|
+
processed_model.imports.append(model.imports)
|
|
2578
|
+
|
|
2579
|
+
for unused_model in unused_models:
|
|
2580
|
+
module, models = model_to_module_models[unused_model]
|
|
2581
|
+
if unused_model in models: # pragma: no cover
|
|
2582
|
+
imports = module_to_import[module]
|
|
2583
|
+
imports.remove(unused_model.imports)
|
|
2584
|
+
models.remove(unused_model)
|
|
2585
|
+
|
|
2586
|
+
for processed_model in processed_models:
|
|
2587
|
+
# postprocess imports to remove unused imports.
|
|
2588
|
+
used_names = self._collect_used_names_from_models(processed_model.models)
|
|
2589
|
+
unused_imports = [
|
|
2590
|
+
(from_, import_)
|
|
2591
|
+
for from_, imports_ in processed_model.imports.items()
|
|
2592
|
+
for import_ in imports_
|
|
2593
|
+
if not {processed_model.imports.alias.get(from_, {}).get(import_, import_), import_}.intersection(
|
|
2594
|
+
used_names
|
|
2595
|
+
)
|
|
2596
|
+
]
|
|
2597
|
+
for from_, import_ in unused_imports:
|
|
2598
|
+
import_obj = Import(from_=from_, import_=import_)
|
|
2599
|
+
while processed_model.imports.counter.get((from_, import_), 0) > 0:
|
|
2600
|
+
processed_model.imports.remove(import_obj)
|
|
2601
|
+
|
|
2602
|
+
for module, mod_key, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
|
|
2603
|
+
# process after removing unused models
|
|
2604
|
+
self.__change_imported_model_name(models, imports, scoped_model_resolver)
|
|
2605
|
+
|
|
2606
|
+
future_imports = self.imports.extract_future()
|
|
2607
|
+
future_imports_str = str(future_imports)
|
|
2608
|
+
|
|
2609
|
+
for module, mod_key, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
|
|
2610
|
+
result: list[str] = []
|
|
2611
|
+
export_imports: Imports | None = None
|
|
2612
|
+
|
|
2613
|
+
if all_exports_scope is not None and module[-1] == "__init__.py":
|
|
2614
|
+
child_exports = self._collect_exports_for_init(module, processed_models, all_exports_scope)
|
|
2615
|
+
if child_exports:
|
|
2616
|
+
local_model_names = {
|
|
638
2617
|
m.reference.short_name
|
|
639
2618
|
for m in models
|
|
640
|
-
if m.
|
|
641
|
-
|
|
642
|
-
|
|
2619
|
+
if m.reference and not m.reference.short_name.startswith("_")
|
|
2620
|
+
}
|
|
2621
|
+
resolved_exports = self._resolve_export_collisions(
|
|
2622
|
+
child_exports, all_exports_collision_strategy, local_model_names
|
|
2623
|
+
)
|
|
2624
|
+
export_imports = self._build_all_exports_code(resolved_exports)
|
|
2625
|
+
|
|
2626
|
+
if models:
|
|
2627
|
+
if with_import:
|
|
2628
|
+
import_parts = [s for s in [future_imports_str, str(self.imports), str(imports)] if s]
|
|
2629
|
+
result += [*import_parts, "\n"]
|
|
643
2630
|
|
|
644
|
-
|
|
2631
|
+
if export_imports:
|
|
2632
|
+
result += [str(export_imports), ""]
|
|
2633
|
+
for m in models:
|
|
2634
|
+
if m.reference and not m.reference.short_name.startswith("_"): # pragma: no branch
|
|
2635
|
+
export_imports.add_export(m.reference.short_name)
|
|
2636
|
+
result += [export_imports.dump_all(multiline=True) + "\n"]
|
|
2637
|
+
|
|
2638
|
+
self.__update_type_aliases(models)
|
|
2639
|
+
code = dump_templates(models)
|
|
2640
|
+
result += [code]
|
|
2641
|
+
|
|
2642
|
+
result += self.__get_resolve_reference_action_parts(
|
|
2643
|
+
models,
|
|
2644
|
+
require_update_action_models,
|
|
2645
|
+
use_deferred_annotations=use_deferred_annotations,
|
|
2646
|
+
)
|
|
2647
|
+
|
|
2648
|
+
# Generate forwarder content for modules that had models moved to _internal
|
|
2649
|
+
if not result and mod_key in forwarder_map:
|
|
2650
|
+
internal_module, class_mappings = forwarder_map[mod_key]
|
|
2651
|
+
forwarder_content = self.__generate_forwarder_content(
|
|
2652
|
+
mod_key, internal_module, class_mappings, is_init=init
|
|
2653
|
+
)
|
|
2654
|
+
result = [forwarder_content]
|
|
2655
|
+
|
|
2656
|
+
if not result and not init:
|
|
2657
|
+
continue
|
|
2658
|
+
body = "\n".join(result)
|
|
645
2659
|
if code_formatter:
|
|
646
2660
|
body = code_formatter.format_code(body)
|
|
647
2661
|
|
|
648
|
-
results[module] = Result(
|
|
2662
|
+
results[module] = Result(
|
|
2663
|
+
body=body,
|
|
2664
|
+
future_imports=future_imports_str,
|
|
2665
|
+
source=models[0].file_path if models else None,
|
|
2666
|
+
)
|
|
2667
|
+
|
|
2668
|
+
if all_exports_scope is not None:
|
|
2669
|
+
processed_init_modules = {m for m, _, _, _, _, _ in processed_models if m[-1] == "__init__.py"}
|
|
2670
|
+
for init_module, init_result in list(results.items()):
|
|
2671
|
+
if init_module[-1] != "__init__.py" or init_module in processed_init_modules or init_result.body:
|
|
2672
|
+
continue
|
|
2673
|
+
if child_exports := self._collect_exports_for_init(
|
|
2674
|
+
init_module, processed_models, all_exports_scope
|
|
2675
|
+
): # pragma: no branch
|
|
2676
|
+
resolved = self._resolve_export_collisions(child_exports, all_exports_collision_strategy, set())
|
|
2677
|
+
export_imports = self._build_all_exports_code(resolved)
|
|
2678
|
+
import_parts = [s for s in [future_imports_str, str(self.imports)] if s] if with_import else []
|
|
2679
|
+
parts = import_parts + (["\n"] if import_parts else [])
|
|
2680
|
+
parts += [str(export_imports), "", export_imports.dump_all(multiline=True)]
|
|
2681
|
+
body = "\n".join(parts)
|
|
2682
|
+
results[init_module] = Result(
|
|
2683
|
+
body=code_formatter.format_code(body) if code_formatter else body,
|
|
2684
|
+
future_imports=future_imports_str,
|
|
2685
|
+
)
|
|
649
2686
|
|
|
650
2687
|
# retain existing behaviour
|
|
651
|
-
if [*results] == [(
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
2688
|
+
if [*results] == [("__init__.py",)]:
|
|
2689
|
+
single_result = results["__init__.py",]
|
|
2690
|
+
return single_result.body
|
|
2691
|
+
|
|
2692
|
+
results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()}
|
|
2693
|
+
return (
|
|
2694
|
+
self.__postprocess_result_modules(results)
|
|
2695
|
+
if self.treat_dot_as_module
|
|
2696
|
+
else {
|
|
2697
|
+
tuple((part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]) for part in k): v
|
|
2698
|
+
for k, v in results.items()
|
|
2699
|
+
}
|
|
2700
|
+
)
|