atlas-init 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_init/__init__.py +1 -1
- atlas_init/atlas_init.yaml +1 -0
- atlas_init/cli_tf/example_update.py +20 -8
- atlas_init/cli_tf/hcl/modifier.py +22 -8
- atlas_init/settings/env_vars.py +12 -2
- atlas_init/tf_ext/api_call.py +9 -9
- atlas_init/tf_ext/args.py +16 -1
- atlas_init/tf_ext/gen_examples.py +141 -0
- atlas_init/tf_ext/gen_module_readme.py +131 -0
- atlas_init/tf_ext/gen_resource_main.py +195 -0
- atlas_init/tf_ext/gen_resource_output.py +71 -0
- atlas_init/tf_ext/gen_resource_variables.py +159 -0
- atlas_init/tf_ext/gen_versions.py +10 -0
- atlas_init/tf_ext/models_module.py +454 -0
- atlas_init/tf_ext/newres.py +90 -0
- atlas_init/tf_ext/plan_diffs.py +140 -0
- atlas_init/tf_ext/provider_schema.py +199 -0
- atlas_init/tf_ext/py_gen.py +294 -0
- atlas_init/tf_ext/schema_to_dataclass.py +522 -0
- atlas_init/tf_ext/settings.py +151 -2
- atlas_init/tf_ext/tf_dep.py +5 -5
- atlas_init/tf_ext/tf_desc_gen.py +53 -0
- atlas_init/tf_ext/tf_desc_update.py +0 -0
- atlas_init/tf_ext/tf_mod_gen.py +263 -0
- atlas_init/tf_ext/tf_mod_gen_provider.py +124 -0
- atlas_init/tf_ext/tf_modules.py +5 -4
- atlas_init/tf_ext/tf_vars.py +13 -28
- atlas_init/tf_ext/typer_app.py +6 -2
- {atlas_init-0.7.0.dist-info → atlas_init-0.8.0.dist-info}/METADATA +4 -3
- {atlas_init-0.7.0.dist-info → atlas_init-0.8.0.dist-info}/RECORD +33 -17
- {atlas_init-0.7.0.dist-info → atlas_init-0.8.0.dist-info}/WHEEL +0 -0
- {atlas_init-0.7.0.dist-info → atlas_init-0.8.0.dist-info}/entry_points.txt +0 -0
- {atlas_init-0.7.0.dist-info → atlas_init-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,199 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from pathlib import Path
|
5
|
+
from tempfile import TemporaryDirectory
|
6
|
+
|
7
|
+
from ask_shell import run_and_wait
|
8
|
+
from model_lib import Entity, dump, parse_dict
|
9
|
+
from pydantic import BaseModel
|
10
|
+
from zero_3rdparty.file_utils import ensure_parents_write_text
|
11
|
+
|
12
|
+
from atlas_init.tf_ext.args import TF_CLI_CONFIG_FILE_ENV_NAME
|
13
|
+
from atlas_init.tf_ext.constants import ATLAS_PROVIDER_NAME
|
14
|
+
from atlas_init.tf_ext.models_module import ProviderGenConfig
|
15
|
+
from atlas_init.tf_ext.settings import TfExtSettings
|
16
|
+
|
17
|
+
|
18
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
def parse_provider_resource_schema(schema: dict, provider_name: str) -> dict:
    """Return the ``resource_schemas`` mapping of one provider from a providers-schema dict.

    ``schema`` is the decoded output of ``terraform providers schema -json``; its
    ``provider_schemas`` keys are provider addresses such as
    ``registry.terraform.io/mongodb/mongodbatlas``. A key matches when it equals
    ``provider_name`` or ends with ``"/" + provider_name``. Matching on a path
    boundary avoids picking up a provider whose name merely ends with the same
    characters (e.g. ``…/notmongodbatlas``).

    Args:
        schema: decoded providers-schema JSON.
        provider_name: provider name or trailing part of its address
            (e.g. ``mongodbatlas`` or ``mongodb/mongodbatlas``).

    Returns:
        The provider's ``resource_schemas`` dict (``{}`` when the key is absent).

    Raises:
        ValueError: if no provider address matches ``provider_name``.
    """
    schemas = schema.get("provider_schemas", {})
    suffix = f"/{provider_name}"
    for provider_url, provider_schema in schemas.items():
        if provider_url == provider_name or provider_url.endswith(suffix):
            return provider_schema.get("resource_schemas", {})
    raise ValueError(f"Provider '{provider_name}' not found in schema.")
|
27
|
+
|
28
|
+
|
29
|
+
# providers.tf variant that additionally requires the hashicorp/external provider.
# get_providers_tf() returns it when skip_python=False.
_providers_tf_with_external = """
terraform {
  required_providers {
    mongodbatlas = {
      source  = "mongodb/mongodbatlas"
      version = "~> 1.26"
    }
    external = {
      source  = "hashicorp/external"
      version = "~>2.0"
    }
  }
  required_version = ">= 1.8"
}
"""

# Default providers.tf: only the mongodbatlas provider. It is written into the
# scratch directory by parse_atlas_schema() before `terraform init`, and it is the
# default value of AtlasSchemaInfo.providers_tf.
_providers_tf = """
terraform {
  required_providers {
    mongodbatlas = {
      source  = "mongodb/mongodbatlas"
      version = "~> 1.26"
    }
  }
  required_version = ">= 1.8"
}
"""
# Bare terraform block (no required_providers); returned by get_providers_tf(minimal=True).
_providers_tf_minimal = """
terraform {
  required_version = ">= 1.8"
}
"""
|
61
|
+
|
62
|
+
|
63
|
+
def get_providers_tf(skip_python: bool = True, minimal: bool = False) -> str:
    """Select which ``providers.tf`` template text to use.

    Args:
        skip_python: when True (the default) return the mongodbatlas-only
            template; when False return the one that also requires
            ``hashicorp/external``.
        minimal: when True return the template that only pins
            ``required_version``; ``skip_python`` is then ignored.
    """
    if not minimal:
        if skip_python:
            return _providers_tf
        return _providers_tf_with_external
    return _providers_tf_minimal
|
67
|
+
|
68
|
+
|
69
|
+
class AtlasSchemaInfo(Entity):
    """Summary of the Atlas provider's resource schemas.

    Instances in this module are built by ``_parse_dict_schema`` from the output of
    ``terraform providers schema -json``.
    """

    resource_types: list[str]  # all resource type names (sorted by _parse_dict_schema)
    deprecated_resource_types: list[str]  # names whose schema block has a truthy "deprecated" (sorted)
    raw_resource_schema: dict[str, dict]  # resource type -> raw (undecoded) schema dict
    providers_tf: str = _providers_tf  # providers.tf template text; defaults to the mongodbatlas-only one
|
74
|
+
|
75
|
+
|
76
|
+
class SchemaAttribute(BaseModel):
    """One attribute of a schema block; every field is optional.

    NOTE(review): the first fields look like keys of a
    ``terraform providers schema -json`` attribute entry, while others (enum,
    allowed_values, force_new, conflicts_with, validators, …) look like extra
    metadata from another source — confirm where they are populated.
    """

    type: str | list | dict | None = None  # attribute type; may be a plain string or a structured (list/dict) type
    description: str | None = None
    description_kind: str | None = None
    optional: bool | None = None
    required: bool | None = None
    computed: bool | None = None
    deprecated: bool | None = None
    sensitive: bool | None = None
    nested_type: SchemaBlock | None = None  # forward reference, resolved by the model_rebuild() calls after the models
    default: object | None = None
    enum: list[object] | None = None
    allowed_values: list[object] | None = None
    force_new: bool | None = None
    conflicts_with: list[str] | None = None
    exactly_one_of: list[str] | None = None
    at_least_one_of: list[str] | None = None
    required_with: list[str] | None = None
    deprecated_message: str | None = None
    validators: list[dict] | None = None
    element_type: str | dict | None = None
|
96
|
+
element_type: str | dict | None = None
|
97
|
+
|
98
|
+
|
99
|
+
class SchemaBlockType(BaseModel):
    """A nested block type: the nested ``block`` plus how it is nested (``nesting_mode``)."""

    block: SchemaBlock  # forward reference, resolved by the model_rebuild() calls after the models
    nesting_mode: str  # the only required field of this model
    min_items: int | None = None
    max_items: int | None = None
    required: bool | None = None
    optional: bool | None = None
    description_kind: str | None = None
    deprecated: bool | None = None
    description: str | None = None
    default: object | None = None
    validators: list[dict] | None = None

    @property
    def block_with_nesting_mode(self) -> SchemaBlock:
        """Return a copy of ``block`` whose ``nesting_mode`` is set to this block type's ``nesting_mode``.

        ``model_copy`` leaves ``self.block`` itself unmodified.
        """
        return self.block.model_copy(update={"nesting_mode": self.nesting_mode})
|
115
|
+
|
116
|
+
|
117
|
+
class SchemaBlock(BaseModel):
    """A schema block: its attributes and nested block types."""

    attributes: dict[str, SchemaAttribute] | None = None  # attribute name -> attribute
    block_types: dict[str, SchemaBlockType] | None = None  # nested block name -> block type
    description_kind: str | None = None
    description: str | None = None
    deprecated: bool | None = None
    # Set by SchemaBlockType.block_with_nesting_mode on the copied block;
    # presumably None when the block was parsed directly — verify against callers.
    nesting_mode: str | None = None
|
124
|
+
|
125
|
+
|
126
|
+
class ResourceSchema(BaseModel):
    """Schema of a single resource: its root ``block`` plus schema metadata."""

    block: SchemaBlock
    version: int | None = None  # schema version, when present
    description_kind: str | None = None

    def required_attributes(self) -> dict[str, SchemaAttribute]:
        """Return the root block's attributes whose ``required`` flag is truthy.

        Only ``block.attributes`` is inspected; attributes inside nested
        ``block_types`` are not included.
        """
        return {name: attr for name, attr in (self.block.attributes or {}).items() if attr.required}
|
133
|
+
|
134
|
+
|
135
|
+
# Resolve the forward references between the mutually recursive schema models.
# With `from __future__ import annotations` all annotations are strings, and
# SchemaAttribute/SchemaBlockType refer to SchemaBlock, which is defined after them.
# NOTE(review): ResourceSchema is not rebuilt here; pydantic presumably completes
# it on first use — confirm.
SchemaAttribute.model_rebuild()
SchemaBlockType.model_rebuild()
SchemaBlock.model_rebuild()
|
138
|
+
|
139
|
+
|
140
|
+
def parse_atlas_schema_from_settings(settings: TfExtSettings, provider_config: ProviderGenConfig) -> AtlasSchemaInfo:
    """Return the Atlas provider schema for the current commit of the local provider repo.

    The commit is read with ``git rev-parse HEAD`` in
    ``settings.repo_path_atlas_provider``.
    - If it equals ``provider_config.last_gen_sha``, the schema comes from the
      per-sha cache in ``settings.provider_cache_dir(...)``, which is regenerated
      on a miss.
    - Otherwise the schema is regenerated and stored in that cache.
      ``provider_config.last_gen_sha`` is then updated and written back as YAML
      to the provider settings file.

    Both paths use ``settings.tf_cli_config_file``. Previously the regeneration
    path ignored it and skipped the cache.

    Raises:
        AssertionError: if ``repo_path_atlas_provider`` is not configured.
    """
    repo_path = settings.repo_path_atlas_provider
    assert repo_path, "repo_path_atlas_provider is required"
    current_sha = run_and_wait("git rev-parse HEAD", cwd=repo_path).stdout_one_line
    cache_dir = settings.provider_cache_dir(provider_config.provider_name)
    if provider_config.last_gen_sha == current_sha:
        return read_cached_atlas_schema(cache_dir, current_sha, settings.tf_cli_config_file)
    # New commit: always regenerate. Store the result under the same
    # `<cache_dir>/<sha>.json` name read_cached_atlas_schema uses, so the next
    # call (when last_gen_sha == current_sha) hits the cache.
    schema = parse_atlas_schema(
        store_path=cache_dir / f"{current_sha}.json",
        tf_cli_config_file=settings.tf_cli_config_file,
    )
    provider_config.last_gen_sha = current_sha
    provider_yaml = dump(provider_config.config_dump(), "yaml")
    settings.repo_out.provider_settings_path(provider_config.provider_name).write_text(provider_yaml)
    return schema
|
152
|
+
|
153
|
+
|
154
|
+
def read_cached_atlas_schema(cache_dir: Path, sha: str, tf_cli_config_file: Path | None = None) -> AtlasSchemaInfo:
    """Load the Atlas provider schema cached for commit *sha*, generating it on a cache miss.

    The cache entry is ``<cache_dir>/<sha>.json``. On a miss the schema is produced
    by ``parse_atlas_schema``, which is given that path as ``store_path`` so it
    writes the entry.
    """
    cached_json = cache_dir / f"{sha}.json"
    if cached_json.exists():
        return _parse_dict_schema(parse_dict(cached_json))
    logger.info(f"Cache miss for sha = {sha}, parsing atlas schema")
    return parse_atlas_schema(store_path=cached_json, tf_cli_config_file=tf_cli_config_file)
|
161
|
+
|
162
|
+
|
163
|
+
def parse_atlas_schema(store_path: Path | None = None, tf_cli_config_file: Path | None = None) -> AtlasSchemaInfo:
    """Run ``terraform providers schema -json`` in a scratch directory and parse the result.

    Steps:
    - Write ``_providers_tf`` into a temporary directory.
    - Run ``terraform init`` there.
    - Run ``terraform providers schema -json`` with ``TF_CLI_CONFIG_FILE`` set
      and the advanced-cluster v2 preview env var enabled.

    Args:
        store_path: when given, the command's output is also written to this path
            (parent directories are created).
        tf_cli_config_file: CLI config file to use; falls back to the
            ``TF_CLI_CONFIG_FILE`` environment variable.

    Raises:
        AssertionError: if neither a config file argument nor the env var is set.
    """
    if tf_cli_config_file:
        tf_cli_config_file_str = str(tf_cli_config_file)
    else:
        tf_cli_config_file_str = os.environ.get(TF_CLI_CONFIG_FILE_ENV_NAME)
    assert tf_cli_config_file_str, f"{TF_CLI_CONFIG_FILE_ENV_NAME} is required"
    schema_env = {
        TF_CLI_CONFIG_FILE_ENV_NAME: tf_cli_config_file_str,
        "MONGODB_ATLAS_PREVIEW_PROVIDER_V2_ADVANCED_CLUSTER": "true",
    }
    with TemporaryDirectory() as example_dir:
        (Path(example_dir) / "providers.tf").write_text(_providers_tf)
        run_and_wait("terraform init", cwd=example_dir)
        schema_run = run_and_wait(
            "terraform providers schema -json",
            cwd=example_dir,
            ansi_content=False,
            env=schema_env,
        )
        decoded = schema_run.parse_output(dict, output_format="json")
        if store_path:
            ensure_parents_write_text(store_path, schema_run.stdout_one_line)
        return _parse_dict_schema(decoded)
|
186
|
+
|
187
|
+
|
188
|
+
def _parse_dict_schema(parsed: dict) -> AtlasSchemaInfo:
    """Build an ``AtlasSchemaInfo`` from decoded ``terraform providers schema -json`` output.

    A resource counts as deprecated when its ``block`` has a truthy
    ``deprecated`` key. Both name lists are sorted.
    """
    resource_schema = parse_provider_resource_schema(parsed, ATLAS_PROVIDER_NAME)
    deprecated_names = sorted(
        name for name, details in resource_schema.items() if details["block"].get("deprecated", False)
    )
    return AtlasSchemaInfo(
        resource_types=sorted(resource_schema),
        deprecated_resource_types=deprecated_names,
        raw_resource_schema=resource_schema,
    )
|
@@ -0,0 +1,294 @@
|
|
1
|
+
import importlib.util
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
import re
|
5
|
+
import sys
|
6
|
+
from dataclasses import Field, fields, is_dataclass
|
7
|
+
from pathlib import Path
|
8
|
+
from tempfile import TemporaryDirectory
|
9
|
+
from types import ModuleType
|
10
|
+
from typing import Any, Dict, Generic, Iterable, List, Literal, NamedTuple, Set, TypeVar, Union, get_args, get_origin
|
11
|
+
|
12
|
+
from zero_3rdparty import humps
|
13
|
+
from zero_3rdparty.file_utils import copy
|
14
|
+
|
15
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def import_module_by_using_parents(file_path: Path) -> ModuleType:
    """Import *file_path* as a submodule of a throw-away package built from its parent directory.

    The entire parent directory of ``file_path`` is copied into a temporary
    ``tmp_module`` package (an empty ``__init__.py`` is added when missing), so
    imports between sibling files resolve. ``tmp_module.<file_path.stem>`` is then
    imported and returned.

    Two clean-ups make repeated calls work:
    - Any ``tmp_module`` package or submodules cached in ``sys.modules`` by an
      earlier call are dropped first. Otherwise the cached package would still
      point at the earlier, already deleted, temporary directory, and the import
      would fail or silently return the stale module.
    - The temporary ``sys.path`` entry is removed again once the import is done.

    Raises:
        ImportError: if the module cannot be imported.
    """
    package_name = "tmp_module"
    with TemporaryDirectory() as tmp_dir:
        package_path = Path(tmp_dir) / package_name
        copy(file_path.parent, package_path)
        init_py = package_path / "__init__.py"
        if not init_py.exists():
            init_py.write_text("")
        logger.info("files in tmp_module: " + ", ".join(py.name for py in package_path.glob("*.py")))
        _drop_cached_package(package_name)
        # The files were created after the interpreter started; make sure the
        # import system's finders don't rely on stale directory listings.
        importlib.invalidate_caches()
        sys.path.insert(0, tmp_dir)
        try:
            module = importlib.import_module(f"{package_name}.{file_path.stem}")
        finally:
            # Undo the sys.path change even if the import fails; the directory
            # is deleted when the `with` block exits anyway.
            try:
                sys.path.remove(tmp_dir)
            except ValueError:
                pass
    if inspect.ismodule(module):
        return module
    raise ImportError(f"Could not import module {file_path.stem} from {file_path}")


def _drop_cached_package(package_name: str) -> None:
    """Remove *package_name* and all of its submodules from ``sys.modules``."""
    prefix = f"{package_name}."
    stale = [name for name in sys.modules if name == package_name or name.startswith(prefix)]
    for name in stale:
        del sys.modules[name]
|
56
|
+
|
57
|
+
|
58
|
+
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
    """Execute the Python file *file_path* as a new module called *module_name* and return it.

    The module is created from a file-location spec and executed directly; it is
    not added to ``sys.modules``.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    assert module_spec
    loader = module_spec.loader
    assert loader
    new_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(new_module)
    return new_module
|
65
|
+
|
66
|
+
|
67
|
+
# Types that need no conversion: when a field's type (or its list/set element or
# dict value type) is one of these, _unwrap_type raises PrimitiveTypeError and
# make_post_init_line_from_field generates no __post_init__ line for the field.
primitive_types = (str, float, bool, int)
|
68
|
+
|
69
|
+
|
70
|
+
def as_set(values: list[str]) -> str:
    """Render *values* as Python source for a set literal (``set()`` when empty or falsy)."""
    if not values:
        return "set()"
    return "{" + ", ".join(map(repr, values)) + "}"
|
72
|
+
|
73
|
+
|
74
|
+
def make_post_init_line_optional(field_name: str, elem_type: str, is_map: bool = False, is_list: bool = False) -> str:
    """Generate ``__post_init__`` body code converting an *optional* field into ``elem_type`` values.

    The code is indented for a method body (8 spaces) and leaves the field alone
    when it is ``None``.
    - ``is_map``: each dict value that is not already an ``elem_type`` is built
      with ``elem_type(**value)``. This takes precedence over ``is_list``.
    - ``is_list``: the same conversion is applied to each list item.
    - Otherwise: the field itself is converted, after asserting it is a dict.
    """
    attr = f"self.{field_name}"
    if is_map or is_list:
        if is_map:
            converted = f"{{k:v if isinstance(v, {elem_type}) else {elem_type}(**v) for k, v in {attr}.items()}}"
        else:
            converted = f"[x if isinstance(x, {elem_type}) else {elem_type}(**x) for x in {attr}]"
        return f"        if {attr} is not None:\n            {attr} = {converted}"
    return (
        f"        if {attr} is not None and not isinstance({attr}, {elem_type}):\n"
        f'            assert isinstance({attr}, dict), f"Expected {field_name} to be a {elem_type} or a dict, got {{type({attr})}}"\n'
        f"            {attr} = {elem_type}(**{attr})"
    )
|
91
|
+
|
92
|
+
|
93
|
+
def make_post_init_line(field_name: str, elem_type: str, is_map: bool = False, is_list: bool = False) -> str:
    """Generate ``__post_init__`` body code converting a *required* field into ``elem_type`` values.

    The code is indented for a method body (8 spaces).
    - ``is_map``: asserts the field is a dict, then converts each value that is
      not already an ``elem_type`` with ``elem_type(**value)``. Takes precedence
      over ``is_list``.
    - ``is_list``: asserts a list, then converts each item the same way.
    - Otherwise: if the field is not an ``elem_type``, asserts it is a dict and
      converts it.
    """
    attr = f"self.{field_name}"
    if is_map:
        check = f'        assert isinstance({attr}, dict), f"Expected {field_name} to be a dict, got {{type({attr})}}"\n'
        assign = f"        {attr} = {{k:v if isinstance(v, {elem_type}) else {elem_type}(**v) for k, v in {attr}.items()}}"
        return check + assign
    if is_list:
        check = f'        assert isinstance({attr}, list), f"Expected {field_name} to be a list, got {{type({attr})}}"\n'
        assign = f"        {attr} = [x if isinstance(x, {elem_type}) else {elem_type}(**x) for x in {attr}]"
        return check + assign
    guard = f"        if not isinstance({attr}, {elem_type}):\n"
    check = f'            assert isinstance({attr}, dict), f"Expected {field_name} to be a {elem_type} or a dict, got {{type({attr})}}"\n'
    assign = f"            {attr} = {elem_type}(**{attr})"
    return guard + check + assign
|
110
|
+
|
111
|
+
|
112
|
+
class PrimitiveTypeError(Exception):
    """Raised by the type unwrapping helpers when a type is a primitive needing no conversion.

    Attributes:
        type_: the primitive type that was encountered.
    """

    def __init__(self, type_: type):
        # Forward to Exception so `args` and `str(exc)` carry the offending type
        # instead of being empty.
        super().__init__(type_)
        self.type_ = type_
|
115
|
+
|
116
|
+
|
117
|
+
def make_post_init_line_from_field(field: Field) -> str:
    """Return the ``__post_init__`` conversion snippet for one dataclass *field*, or ``""``.

    No snippet is generated when:
    - the field type (or its element/value type) is a primitive;
    - it is ``typing.Any``;
    - the field is a set.
    Otherwise the optional or required snippet generator is chosen based on
    whether the annotation is optional.
    """
    try:
        container_type = unwrap_type(field)
    except PrimitiveTypeError:
        return ""
    if container_type.is_any:
        # The generated `isinstance(x, Any)` raises TypeError and `Any(**x)` is
        # meaningless, so values typed as Any are left untouched.
        return ""
    if container_type.is_set:
        # Set members must be hashable, so they can never be dicts awaiting
        # conversion. The generic fallback would instead assert that the whole
        # set is a dict, which always fails.
        return ""
    make_func = make_post_init_line_optional if container_type.is_optional else make_post_init_line
    return make_func(
        field.name, container_type.type.__name__, is_map=container_type.is_dict, is_list=container_type.is_list
    )
|
126
|
+
|
127
|
+
|
128
|
+
T = TypeVar("T")
|
129
|
+
|
130
|
+
|
131
|
+
class ContainerType(NamedTuple, Generic[T]):
|
132
|
+
type: type[T]
|
133
|
+
container_type: Literal[
|
134
|
+
"list", "set", "dict", "optional", "optional_list", "optional_set", "optional_dict", "cls_direct"
|
135
|
+
]
|
136
|
+
|
137
|
+
@property
|
138
|
+
def is_cls_direct(self) -> bool:
|
139
|
+
return self.container_type == "cls_direct"
|
140
|
+
|
141
|
+
@property
|
142
|
+
def is_list(self) -> bool:
|
143
|
+
return self.container_type in {"list", "optional_list"}
|
144
|
+
|
145
|
+
@property
|
146
|
+
def is_set(self) -> bool:
|
147
|
+
return self.container_type in {"set", "optional_set"}
|
148
|
+
|
149
|
+
@property
|
150
|
+
def is_dict(self) -> bool:
|
151
|
+
return self.container_type in {"dict", "optional_dict"}
|
152
|
+
|
153
|
+
@property
|
154
|
+
def is_optional(self) -> bool:
|
155
|
+
return self.container_type in {"optional", "optional_list", "optional_set", "optional_dict"}
|
156
|
+
|
157
|
+
@property
|
158
|
+
def is_any(self) -> bool:
|
159
|
+
return self.type is Any
|
160
|
+
|
161
|
+
|
162
|
+
def unwrap_type(field: Field) -> ContainerType:
    """Classify a dataclass *field*'s annotation by delegating to ``_unwrap_type``.

    Raises whatever ``_unwrap_type`` raises (notably ``PrimitiveTypeError``).
    """
    annotation = field.type
    return _unwrap_type(annotation, get_origin(annotation), get_args(annotation))  # type: ignore
|
167
|
+
|
168
|
+
|
169
|
+
def _unwrap_type(field_type: type, origin: type, args: list[type]) -> ContainerType:
    """Classify *field_type*, given its pre-computed ``get_origin``/``get_args``.

    Recognised shapes:
    - ``Optional[X]`` / ``X | None`` wrapping any of the shapes below;
    - ``list[X]`` / ``List[X]``;
    - ``set[X]`` / ``Set[X]``;
    - ``dict[K, V]`` / ``Dict[K, V]``, classified by the value type ``V``;
    - a bare class.

    Raises:
        PrimitiveTypeError: if the element, value or bare type is in ``primitive_types``.
        ValueError: (defensive) if an optional's inner type unwraps to an unexpected kind.
        AssertionError: for a union with several non-None members, or a string
            (postponed) annotation.
    """
    # Depending on the Python version, PEP 604 unions (`X | None`) may report
    # types.UnionType rather than typing.Union as their origin; treat them alike.
    from types import UnionType

    if (origin is Union or origin is UnionType) and type(None) in args:
        non_none_args = [arg for arg in args if arg is not type(None)]
        assert len(non_none_args) == 1, f"Expected one non-None type in Union, got {non_none_args}"
        inner_type = non_none_args[0]
        response = _unwrap_type(inner_type, get_origin(inner_type), get_args(inner_type))  # type: ignore
        if response.is_cls_direct:
            return ContainerType(response.type, "optional")
        if response.is_list:
            return ContainerType(response.type, "optional_list")
        if response.is_set:
            return ContainerType(response.type, "optional_set")
        if response.is_dict:
            return ContainerType(response.type, "optional_dict")
        raise ValueError(f"Unsupported optional type: {inner_type}")
    if origin in (list, List) and args:
        item_type = args[0]
        if item_type in primitive_types:
            raise PrimitiveTypeError(item_type)
        return ContainerType(item_type, "list")
    if origin in (set, Set) and args:
        item_type = args[0]
        if item_type in primitive_types:
            raise PrimitiveTypeError(item_type)
        return ContainerType(item_type, "set")
    if origin in (dict, Dict) and args:
        _, value_type = args
        if value_type in primitive_types:
            raise PrimitiveTypeError(value_type)
        return ContainerType(value_type, "dict")
    if field_type in primitive_types:
        raise PrimitiveTypeError(field_type)
    assert not isinstance(field_type, str), f"Expected type, got {field_type!r}"
    return ContainerType(field_type, "cls_direct")
|
203
|
+
|
204
|
+
|
205
|
+
def longest_common_substring_among_all(strings: list[str]) -> str:
    """Return, in PascalCase, a substring common to all *strings* (compared case-insensitively).

    The lowercased strings are folded left to right:
    - take the longest common substring of the first two;
    - intersect that result with the third string, and so on.
    The result is therefore common to every input, but not necessarily the
    globally longest such substring. Ties keep the earliest match in the
    left-hand string. Leading and trailing underscores are stripped before
    pascalizing.
    """
    from functools import reduce

    def longest_pair_substring(left: str, right: str) -> str:
        # Dynamic programming over suffix-match lengths, keeping only the previous row.
        best_len = best_end = 0
        prev_row = [0] * (len(right) + 1)
        for i, left_char in enumerate(left, start=1):
            row = [0] * (len(right) + 1)
            for j, right_char in enumerate(right, start=1):
                if left_char == right_char:
                    row[j] = prev_row[j - 1] + 1
                    if row[j] > best_len:
                        best_len, best_end = row[j], i
            prev_row = row
        return left[best_end - best_len : best_end]

    lowered = [s.lower() for s in strings]
    return humps.pascalize(reduce(longest_pair_substring, lowered).strip("_"))
|
225
|
+
|
226
|
+
|
227
|
+
_main_call = """
if __name__ == "__main__":
    main()
"""


def move_main_call_to_end(file_path: Path) -> None:
    """Move the ``if __name__ == "__main__": main()`` stanza to the end of *file_path*.

    Every exact copy of ``_main_call`` is removed from the file text, and a single
    copy is appended. The file is always rewritten.
    """
    without_stanza = file_path.read_text().replace(_main_call, "")
    file_path.write_text(f"{without_stanza}{_main_call}")
|
237
|
+
|
238
|
+
|
239
|
+
class DataclassMatch(NamedTuple):
    """Location of one ``@dataclass`` class definition inside generated source code."""

    cls_name: str
    match_context: str  # text around the match start (up to 20 chars each side); presumably for diagnostics
    index_start: int  # offset of the "@dataclass" decorator
    index_end: int  # exclusive end offset; includes the "\n\n\n" separator when one follows


def dataclass_matches(code: str, cls_name: str) -> Iterable[DataclassMatch]:
    """Yield every ``@dataclass`` definition of *cls_name* found in *code*.

    A definition ends at the first ``"\\n\\n\\n"`` after its decorator (two blank
    lines, the usual separator between top-level definitions); that separator is
    included in the span. A definition at the very end of *code* with no such
    separator extends to the end of *code*.
    """
    for match in dataclass_pattern(cls_name).finditer(code):
        start = match.start()
        # Search in place instead of slicing `code[start:]` for every match.
        separator_at = code.find("\n\n\n", start)
        end = len(code) if separator_at == -1 else separator_at + 3
        # Clamp at 0: a negative slice start would wrap around to the end of `code`.
        context = code[max(start - 20, 0) : start + 20]
        yield DataclassMatch(cls_name, context, start, end)


def dataclass_pattern(cls_name: str) -> re.Pattern:
    """Regex for ``@dataclass`` immediately followed by ``class <cls_name>:`` or ``class <cls_name>(Base):``."""
    return re.compile(rf"@dataclass\nclass {cls_name}(?P<base>\(\w+\))?:")


def dataclass_indexes(code: str, cls_name: str) -> tuple[int, int]:
    """Return ``(start, end)`` offsets of the single ``@dataclass`` definition of *cls_name*.

    Raises:
        AssertionError: unless exactly one definition is found.
    """
    matches = list(dataclass_matches(code, cls_name))
    assert len(matches) == 1, f"expected exactly one dataclass match for {cls_name}, got {len(matches)}"
    return matches[0].index_start, matches[0].index_end
|
262
|
+
|
263
|
+
|
264
|
+
def make_post_init(lines: list[str]) -> str:
    """Wrap already-indented body *lines* in a ``__post_init__`` method header (4-space method indent)."""
    header = "    def __post_init__(self):\n"
    body = "\n".join(lines)
    return f"{header}{body}"
|
266
|
+
|
267
|
+
|
268
|
+
def module_dataclasses(module: ModuleType) -> dict[str, type]:
    """Return the dataclass *classes* (not instances) in *module*'s namespace, keyed by attribute name."""
    found: dict[str, type] = {}
    for attr_name, value in vars(module).items():
        if inspect.isclass(value) and is_dataclass(value):
            found[attr_name] = value
    return found
|
274
|
+
|
275
|
+
|
276
|
+
def ensure_dataclass_use_conversion(dataclasses: dict[str, type], file_path: Path, skip_filter: set[str]) -> None:
    """Insert generated ``__post_init__`` conversions into the dataclasses defined in *file_path*.

    A class is skipped when:
    - its name is in *skip_filter*;
    - none of its fields needs conversion;
    - its source already contains ``def __post_init__``.
    For every other class, the generated method is inserted just before the
    first blank line of the class's source text. The file is written once at
    the end, even when nothing changed.
    """
    source = file_path.read_text()
    for cls_name, dc_cls in dataclasses.items():
        if cls_name in skip_filter:
            continue
        body_lines = []
        for dc_field in fields(dc_cls):
            line = make_post_init_line_from_field(dc_field)
            if line:
                body_lines.append(line)
        if not body_lines:
            continue
        start, end = dataclass_indexes(source, cls_name)
        current = source[start:end]
        if "def __post_init__" in current:
            continue  # keep an existing __post_init__ untouched
        split_at = current.find("\n\n")
        assert split_at > 0
        updated = f"{current[:split_at]}\n\n{make_post_init(body_lines)}{current[split_at:]}"
        source = source.replace(current, updated)
    file_path.write_text(source)
|