datamodel-code-generator 0.11.12__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamodel_code_generator/__init__.py +654 -185
- datamodel_code_generator/__main__.py +872 -388
- datamodel_code_generator/arguments.py +798 -0
- datamodel_code_generator/cli_options.py +295 -0
- datamodel_code_generator/format.py +292 -54
- datamodel_code_generator/http.py +85 -10
- datamodel_code_generator/imports.py +152 -43
- datamodel_code_generator/model/__init__.py +138 -1
- datamodel_code_generator/model/base.py +531 -120
- datamodel_code_generator/model/dataclass.py +211 -0
- datamodel_code_generator/model/enum.py +133 -12
- datamodel_code_generator/model/imports.py +22 -0
- datamodel_code_generator/model/msgspec.py +462 -0
- datamodel_code_generator/model/pydantic/__init__.py +30 -25
- datamodel_code_generator/model/pydantic/base_model.py +304 -100
- datamodel_code_generator/model/pydantic/custom_root_type.py +11 -2
- datamodel_code_generator/model/pydantic/dataclass.py +15 -4
- datamodel_code_generator/model/pydantic/imports.py +40 -27
- datamodel_code_generator/model/pydantic/types.py +188 -96
- datamodel_code_generator/model/pydantic_v2/__init__.py +51 -0
- datamodel_code_generator/model/pydantic_v2/base_model.py +268 -0
- datamodel_code_generator/model/pydantic_v2/imports.py +15 -0
- datamodel_code_generator/model/pydantic_v2/root_model.py +35 -0
- datamodel_code_generator/model/pydantic_v2/types.py +143 -0
- datamodel_code_generator/model/scalar.py +124 -0
- datamodel_code_generator/model/template/Enum.jinja2 +15 -2
- datamodel_code_generator/model/template/ScalarTypeAliasAnnotation.jinja2 +6 -0
- datamodel_code_generator/model/template/ScalarTypeAliasType.jinja2 +6 -0
- datamodel_code_generator/model/template/ScalarTypeStatement.jinja2 +6 -0
- datamodel_code_generator/model/template/TypeAliasAnnotation.jinja2 +20 -0
- datamodel_code_generator/model/template/TypeAliasType.jinja2 +20 -0
- datamodel_code_generator/model/template/TypeStatement.jinja2 +20 -0
- datamodel_code_generator/model/template/TypedDict.jinja2 +5 -0
- datamodel_code_generator/model/template/TypedDictClass.jinja2 +25 -0
- datamodel_code_generator/model/template/TypedDictFunction.jinja2 +24 -0
- datamodel_code_generator/model/template/UnionTypeAliasAnnotation.jinja2 +10 -0
- datamodel_code_generator/model/template/UnionTypeAliasType.jinja2 +10 -0
- datamodel_code_generator/model/template/UnionTypeStatement.jinja2 +10 -0
- datamodel_code_generator/model/template/dataclass.jinja2 +50 -0
- datamodel_code_generator/model/template/msgspec.jinja2 +55 -0
- datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 +17 -4
- datamodel_code_generator/model/template/pydantic/BaseModel_root.jinja2 +12 -4
- datamodel_code_generator/model/template/pydantic/Config.jinja2 +1 -1
- datamodel_code_generator/model/template/pydantic/dataclass.jinja2 +15 -2
- datamodel_code_generator/model/template/pydantic_v2/BaseModel.jinja2 +57 -0
- datamodel_code_generator/model/template/pydantic_v2/ConfigDict.jinja2 +5 -0
- datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2 +48 -0
- datamodel_code_generator/model/type_alias.py +70 -0
- datamodel_code_generator/model/typed_dict.py +161 -0
- datamodel_code_generator/model/types.py +106 -0
- datamodel_code_generator/model/union.py +105 -0
- datamodel_code_generator/parser/__init__.py +30 -12
- datamodel_code_generator/parser/_graph.py +67 -0
- datamodel_code_generator/parser/_scc.py +171 -0
- datamodel_code_generator/parser/base.py +2426 -380
- datamodel_code_generator/parser/graphql.py +652 -0
- datamodel_code_generator/parser/jsonschema.py +2518 -647
- datamodel_code_generator/parser/openapi.py +631 -222
- datamodel_code_generator/py.typed +0 -0
- datamodel_code_generator/pydantic_patch.py +28 -0
- datamodel_code_generator/reference.py +672 -290
- datamodel_code_generator/types.py +521 -145
- datamodel_code_generator/util.py +155 -0
- datamodel_code_generator/watch.py +65 -0
- datamodel_code_generator-0.45.0.dist-info/METADATA +301 -0
- datamodel_code_generator-0.45.0.dist-info/RECORD +69 -0
- {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info}/WHEEL +1 -1
- datamodel_code_generator-0.45.0.dist-info/entry_points.txt +2 -0
- datamodel_code_generator/version.py +0 -1
- datamodel_code_generator-0.11.12.dist-info/METADATA +0 -440
- datamodel_code_generator-0.11.12.dist-info/RECORD +0 -31
- datamodel_code_generator-0.11.12.dist-info/entry_points.txt +0 -3
- {datamodel_code_generator-0.11.12.dist-info → datamodel_code_generator-0.45.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Base type manager for model modules.
|
|
2
|
+
|
|
3
|
+
Provides DataTypeManager implementation with type mapping factory.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from datamodel_code_generator import DatetimeClassType, PythonVersion, PythonVersionMin
|
|
11
|
+
from datamodel_code_generator.imports import (
|
|
12
|
+
IMPORT_ANY,
|
|
13
|
+
IMPORT_DECIMAL,
|
|
14
|
+
IMPORT_TIMEDELTA,
|
|
15
|
+
)
|
|
16
|
+
from datamodel_code_generator.types import DataType, StrictTypes, Types
|
|
17
|
+
from datamodel_code_generator.types import DataTypeManager as _DataTypeManager
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def type_map_factory(data_type: type[DataType]) -> dict[Types, DataType]:
    """Build the default lookup table from schema types to Python types.

    Args:
        data_type: The ``DataType`` class used to construct every entry.

    Returns:
        A dict keyed by ``Types`` member. A single ``int``, ``float`` and
        ``str`` instance is created up front and shared by every schema type
        that maps to that Python type.
    """
    int_type = data_type(type="int")
    float_type = data_type(type="float")
    str_type = data_type(type="str")

    # TODO: Should we support a special type such as UUID?
    mapping: dict[Types, DataType] = {}
    mapping.update(dict.fromkeys((Types.integer, Types.int32, Types.int64), int_type))
    mapping.update(dict.fromkeys((Types.number, Types.float, Types.double), float_type))
    mapping[Types.decimal] = data_type.from_import(IMPORT_DECIMAL)
    # ``byte`` is a base64 encoded string, hence ``str``.
    mapping.update(dict.fromkeys((Types.time, Types.string, Types.byte), str_type))
    mapping[Types.binary] = data_type(type="bytes")
    mapping.update(dict.fromkeys((Types.date, Types.date_time), str_type))
    mapping[Types.timedelta] = data_type.from_import(IMPORT_TIMEDELTA)
    mapping.update(
        dict.fromkeys(
            (
                Types.password,
                Types.email,
                Types.uuid,
                Types.uuid1,
                Types.uuid2,
                Types.uuid3,
                Types.uuid4,
                Types.uuid5,
                Types.uri,
                Types.hostname,
                Types.ipv4,
                Types.ipv6,
                Types.ipv4_network,
                Types.ipv6_network,
            ),
            str_type,
        )
    )
    mapping[Types.boolean] = data_type(type="bool")
    mapping[Types.object] = data_type.from_import(IMPORT_ANY, is_dict=True)
    mapping[Types.null] = data_type(type="None")
    mapping[Types.array] = data_type.from_import(IMPORT_ANY, is_list=True)
    mapping[Types.any] = data_type.from_import(IMPORT_ANY)
    return mapping
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class DataTypeManager(_DataTypeManager):
    """Type manager that resolves schema types through a plain lookup table."""

    def __init__(  # noqa: PLR0913, PLR0917
        self,
        python_version: PythonVersion = PythonVersionMin,
        use_standard_collections: bool = False,  # noqa: FBT001, FBT002
        use_generic_container_types: bool = False,  # noqa: FBT001, FBT002
        strict_types: Sequence[StrictTypes] | None = None,
        use_non_positive_negative_number_constrained_types: bool = False,  # noqa: FBT001, FBT002
        use_decimal_for_multiple_of: bool = False,  # noqa: FBT001, FBT002
        use_union_operator: bool = False,  # noqa: FBT001, FBT002
        use_pendulum: bool = False,  # noqa: FBT001, FBT002
        target_datetime_class: DatetimeClassType | None = None,
        treat_dot_as_module: bool = False,  # noqa: FBT001, FBT002
        use_serialize_as_any: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Forward every option to the base manager, then build ``type_map``."""
        # The base initializer is called positionally, in declaration order.
        base_args = (
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_decimal_for_multiple_of,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
            treat_dot_as_module,
            use_serialize_as_any,
        )
        super().__init__(*base_args)

        # Built after the base initializer because it reads ``self.data_type``.
        self.type_map: dict[Types, DataType] = type_map_factory(self.data_type)

    def get_data_type(
        self,
        types: Types,
        **_: Any,
    ) -> DataType:
        """Return the ``type_map`` entry for ``types``.

        Extra keyword arguments are accepted and ignored. A schema type that is
        missing from ``type_map`` raises ``KeyError``.
        """
        return self.type_map[types]
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Union type model generators.
|
|
2
|
+
|
|
3
|
+
Provides classes for generating union type aliases for GraphQL union types.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
9
|
+
|
|
10
|
+
from datamodel_code_generator.imports import (
|
|
11
|
+
IMPORT_TYPE_ALIAS,
|
|
12
|
+
IMPORT_TYPE_ALIAS_BACKPORT,
|
|
13
|
+
IMPORT_TYPE_ALIAS_TYPE,
|
|
14
|
+
IMPORT_UNION,
|
|
15
|
+
Import,
|
|
16
|
+
)
|
|
17
|
+
from datamodel_code_generator.model import DataModel, DataModelFieldBase
|
|
18
|
+
from datamodel_code_generator.model.base import UNDEFINED
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from collections import defaultdict
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
from datamodel_code_generator.reference import Reference
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _DataTypeUnionBase(DataModel):
    """Common initializer shared by the GraphQL union model variants."""

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool = False,
    ) -> None:
        """Pass every keyword argument through to ``DataModel`` unchanged."""
        # Keyword-only pass-through; nothing is computed or stored here.
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class DataTypeUnion(_DataTypeUnionBase):
    """GraphQL union rendered as ``Name: TypeAlias = Union[...]`` (Python 3.10+)."""

    # Template used to render the alias.
    TEMPLATE_FILE_PATH: ClassVar[str] = "UnionTypeAliasAnnotation.jinja2"
    # A type alias has no base class.
    BASE_CLASS: ClassVar[str] = ""
    # Imports attached to every model of this kind.
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPE_ALIAS, IMPORT_UNION)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class DataTypeUnionBackport(_DataTypeUnionBase):
    """GraphQL union rendered as ``Name: TypeAlias = Union[...]`` for Python 3.9."""

    # Same template as the non-backport variant; only the TypeAlias import differs.
    TEMPLATE_FILE_PATH: ClassVar[str] = "UnionTypeAliasAnnotation.jinja2"
    # A type alias has no base class.
    BASE_CLASS: ClassVar[str] = ""
    # Imports attached to every model of this kind.
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPE_ALIAS_BACKPORT, IMPORT_UNION)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class DataTypeUnionTypeBackport(_DataTypeUnionBase):
    """GraphQL union rendered as ``Name = TypeAliasType("Name", Union[...])`` (Python 3.9-3.11)."""

    # Template used to render the ``TypeAliasType`` call.
    TEMPLATE_FILE_PATH: ClassVar[str] = "UnionTypeAliasType.jinja2"
    # A type alias has no base class.
    BASE_CLASS: ClassVar[str] = ""
    # Imports attached to every model of this kind.
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_TYPE_ALIAS_TYPE, IMPORT_UNION)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class DataTypeUnionTypeStatement(_DataTypeUnionBase):
    """GraphQL union rendered with a ``type`` statement, ``type Name = Union[...]`` (Python 3.12+)."""

    # Template used to render the ``type`` statement.
    TEMPLATE_FILE_PATH: ClassVar[str] = "UnionTypeStatement.jinja2"
    # A type alias has no base class.
    BASE_CLASS: ClassVar[str] = ""
    # Only ``Union`` is imported for this variant.
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_UNION,)
|
|
@@ -1,31 +1,49 @@
|
|
|
1
|
+
"""Parser utilities and base types for schema parsing.
|
|
2
|
+
|
|
3
|
+
Provides LiteralType enum for literal parsing options and DefaultPutDict
|
|
4
|
+
for caching remote schema content.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections import UserDict
|
|
1
10
|
from enum import Enum
|
|
2
|
-
from typing import Callable,
|
|
11
|
+
from typing import Callable, TypeVar
|
|
3
12
|
|
|
4
|
-
TK = TypeVar(
|
|
5
|
-
TV = TypeVar(
|
|
13
|
+
TK = TypeVar("TK")
|
|
14
|
+
TV = TypeVar("TV")
|
|
6
15
|
|
|
7
16
|
|
|
8
17
|
class LiteralType(Enum):
    """Choices for how enum fields are handled as literals."""

    # Member values are the lowercase strings "all" and "one".
    All = "all"
    One = "one"
|
|
22
|
+
|
|
11
23
|
|
|
24
|
+
class DefaultPutDict(UserDict[TK, TV]):
    """Dict that can lazily compute and cache missing values."""

    def get_or_put(
        self,
        key: TK,
        default: TV | None = None,
        default_factory: Callable[[TK], TV] | None = None,
    ) -> TV:
        """Return the value stored for ``key``, storing one first if it is absent.

        Args:
            key: Key to look up.
            default: Value to store and return when ``key`` is missing.
                ``None`` means "not supplied".
            default_factory: Called with ``key`` to produce the value when
                ``key`` is missing and no ``default`` was supplied.

        Returns:
            The existing value, or the newly stored one.

        Raises:
            ValueError: If ``key`` is missing and neither ``default`` nor
                ``default_factory`` was supplied.
        """
        if key in self:
            return self[key]
        # ``None`` is the "not supplied" sentinel. Test identity rather than
        # truthiness so that falsy defaults (0, "", [], {}) are stored instead
        # of being silently skipped.
        if default is not None:  # pragma: no cover
            value = self[key] = default
            return value
        if default_factory is not None:
            value = self[key] = default_factory(key)
            return value
        msg = "Not found default and default_factory"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
|
|
29
44
|
|
|
30
45
|
|
|
31
|
-
__all__ = [
|
|
46
|
+
__all__ = [
|
|
47
|
+
"DefaultPutDict",
|
|
48
|
+
"LiteralType",
|
|
49
|
+
]
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Graph utilities used by parsers.
|
|
2
|
+
|
|
3
|
+
This module intentionally contains only generic graph algorithms (no DataModel
|
|
4
|
+
or schema-specific logic), so it can be reused across parsers without creating
|
|
5
|
+
dependency cycles.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable, Hashable
|
|
11
|
+
from heapq import heappop, heappush
|
|
12
|
+
from typing import TypeVar
|
|
13
|
+
|
|
14
|
+
TNode = TypeVar("TNode", bound=Hashable)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def stable_toposort(
    nodes: list[TNode],
    edges: dict[TNode, set[TNode]],
    *,
    key: Callable[[TNode], int],
) -> list[TNode]:
    """Stable topological sort; breaks ties by `key`.

    The `edges` mapping is an adjacency list where `edges[u]` contains all `v`
    such that `u -> v` (i.e., `u` must come before `v`). Sources and
    destinations that are not in `nodes` are ignored.

    Among nodes that are ready at the same time, the one with the smallest
    `(key(node), position in nodes)` pair is emitted first.

    If a cycle is detected, any remaining nodes are appended in `key` order for
    determinism.
    """
    node_set = set(nodes)
    order_index = {node: index for index, node in enumerate(nodes)}

    def priority(node: TNode) -> tuple[int, int]:
        # Primary order is the caller's key; input position breaks ties.
        return key(node), order_index[node]

    indegree: dict[TNode, int] = dict.fromkeys(nodes, 0)
    outgoing: dict[TNode, set[TNode]] = {n: set() for n in nodes}

    for source in node_set & edges.keys():
        # Each source is visited once, so its outgoing set is still empty here;
        # only edges whose destination is a known node are kept.
        destinations = edges[source] & node_set
        outgoing[source].update(destinations)
        for destination in destinations:
            indegree[destination] += 1

    outgoing_sorted = {node: sorted(neighbors, key=priority) for node, neighbors in outgoing.items()}

    ready: list[tuple[int, int, TNode]] = []
    for node in nodes:
        if indegree[node] == 0:
            heappush(ready, (*priority(node), node))

    result: list[TNode] = []
    while ready:
        _, _, node = heappop(ready)
        result.append(node)
        for neighbor in outgoing_sorted[node]:
            indegree[neighbor] -= 1
            if indegree[neighbor] == 0:
                heappush(ready, (*priority(neighbor), neighbor))

    # Nodes never emitted are part of, or downstream of, a cycle. Membership is
    # tested against a set so this fallback is not quadratic in the node count.
    emitted = set(result)
    remaining = sorted(
        [node for node in nodes if node not in emitted],
        key=priority,
    )
    result.extend(remaining)
    return result
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Strongly Connected Components detection using Tarjan's algorithm.
|
|
2
|
+
|
|
3
|
+
Provides SCC detection for module dependency graphs to identify
|
|
4
|
+
circular import patterns in generated code.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from enum import IntEnum
|
|
10
|
+
from typing import NamedTuple
|
|
11
|
+
|
|
12
|
+
from typing_extensions import TypeAlias
|
|
13
|
+
|
|
14
|
+
ModulePath: TypeAlias = tuple[str, ...]
|
|
15
|
+
ModuleGraph: TypeAlias = dict[ModulePath, set[ModulePath]]
|
|
16
|
+
SCC: TypeAlias = set[ModulePath]
|
|
17
|
+
SCCList: TypeAlias = list[SCC]
|
|
18
|
+
|
|
19
|
+
_EMPTY_SET: frozenset[ModulePath] = frozenset()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _Phase(IntEnum):
    """Stage a stack frame is in during the iterative Tarjan traversal."""

    # Frame is being processed from its stored neighbor index.
    VISIT = 0
    # Frame is being resumed after one of its children finished.
    POSTVISIT = 1
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _Frame(NamedTuple):
    """One entry of the explicit call stack used by the iterative DFS."""

    # Node this frame belongs to.
    node: ModulePath
    # Index into the node's sorted neighbor list at which to resume.
    neighbor_idx: int
    # Whether the frame is a first visit or a resume after a child.
    phase: _Phase
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _TarjanState:
    """Bookkeeping shared across one run of Tarjan's SCC algorithm."""

    __slots__ = ("graph", "index", "index_counter", "lowlinks", "on_stack", "result", "sorted_cache", "stack")

    def __init__(self, graph: ModuleGraph) -> None:
        self.graph = graph
        # Next discovery index to hand out.
        self.index_counter: int = 0
        # Nodes discovered but not yet assigned to a component.
        self.stack: list[ModulePath] = []
        self.lowlinks: dict[ModulePath, int] = {}
        self.index: dict[ModulePath, int] = {}
        # Set mirror of ``stack`` for constant-time membership tests.
        self.on_stack: set[ModulePath] = set()
        self.result: SCCList = []
        self.sorted_cache: dict[ModulePath, list[ModulePath]] = {}

    def get_sorted_neighbors(self, node: ModulePath) -> list[ModulePath]:
        """Return the node's neighbors in sorted order, sorting only once per node."""
        ordered: list[ModulePath] | None = self.sorted_cache.get(node)
        if ordered is not None:
            return ordered
        # A node with no entry in the graph has no neighbors.
        ordered = sorted(self.graph[node]) if node in self.graph else []
        self.sorted_cache[node] = ordered
        return ordered

    def extract_scc(self, root: ModulePath) -> None:
        """Pop the stack down to and including ``root`` and record those nodes as one SCC."""
        component: SCC = set()
        while True:
            member: ModulePath = self.stack.pop()
            self.on_stack.remove(member)
            component.add(member)
            if member == root:  # pragma: no branch
                break
        self.result.append(component)

    def initialize_node(self, node: ModulePath) -> None:
        """Assign ``node`` its discovery index and push it onto the stack."""
        discovery = self.index_counter
        self.index[node] = discovery
        # The lowlink starts equal to the node's own discovery index.
        self.lowlinks[node] = discovery
        self.index_counter = discovery + 1
        self.stack.append(node)
        self.on_stack.add(node)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _strongconnect(state: _TarjanState, start: ModulePath) -> None:
    """Run Tarjan's strongconnect from ``start`` using an explicit stack instead of recursion."""
    state.initialize_node(start)
    pending: list[_Frame] = [_Frame(start, 0, _Phase.VISIT)]

    while pending:
        current, cursor, phase = pending.pop()
        successors: list[ModulePath] = state.get_sorted_neighbors(current)

        if phase == _Phase.POSTVISIT:
            # Resuming after a child finished: fold its lowlink into ours, then
            # move on to the neighbor after it.
            finished_child: ModulePath = successors[cursor]
            state.lowlinks[current] = min(state.lowlinks[current], state.lowlinks[finished_child])
            cursor += 1

        descended = False
        for position in range(cursor, len(successors)):
            successor: ModulePath = successors[position]

            if successor not in state.index:
                # Unvisited neighbor: remember where to resume, then descend.
                pending.append(_Frame(current, position, _Phase.POSTVISIT))
                state.initialize_node(successor)
                pending.append(_Frame(successor, 0, _Phase.VISIT))
                descended = True
                break

            if successor in state.on_stack:
                # Edge to a node still on the stack, i.e. not yet assigned to an SCC.
                state.lowlinks[current] = min(state.lowlinks[current], state.index[successor])

        # Every neighbor handled without descending: check if this node is an SCC root.
        if not descended and state.lowlinks[current] == state.index[current]:
            state.extract_scc(current)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def strongly_connected_components(graph: ModuleGraph) -> SCCList:
    """Compute every strongly connected component with Tarjan's algorithm.

    The traversal is iterative, so large graphs do not hit Python's recursion
    limit. Start nodes and neighbors are visited in sorted order, which makes
    the output deterministic.

    Args:
        graph: Adjacency list mapping module tuple to set of dependency module tuples.
            Each node is a tuple like ("pkg", "__init__.py") or ("pkg", "module.py").

    Returns:
        List of all SCCs, each being a set of module tuples.
        SCCs are returned in reverse topological order (leaves first).
        Includes all SCCs, including singleton nodes without self-loops.
    """
    # Nodes that only ever appear as a dependency must be visited as well.
    every_node: set[ModulePath] = set(graph).union(*graph.values())

    state = _TarjanState(graph)

    for candidate in sorted(every_node):
        if candidate in state.index:
            # Already reached from an earlier start node.
            continue
        _strongconnect(state, candidate)

    return state.result
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def find_circular_sccs(graph: ModuleGraph) -> SCCList:
    """Return only the SCCs that describe a circular dependency.

    An SCC counts as circular when it has:
    - More than one node, OR
    - Exactly one node with a self-loop (edge to itself)

    Args:
        graph: Module dependency graph

    Returns:
        List of circular SCCs, sorted by their minimum element for determinism
    """

    def is_circular(component: SCC) -> bool:
        if len(component) != 1:
            return len(component) > 1
        # Singleton: circular only if the node lists itself as a dependency.
        (only,) = component
        return only in graph.get(only, ())

    return sorted(filter(is_circular, strongly_connected_components(graph)), key=min)
|