atlas-init 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. atlas_init/__init__.py +1 -1
  2. atlas_init/atlas_init.yaml +1 -0
  3. atlas_init/cli_tf/example_update.py +20 -8
  4. atlas_init/cli_tf/hcl/modifier.py +22 -8
  5. atlas_init/settings/env_vars.py +12 -2
  6. atlas_init/tf_ext/api_call.py +9 -9
  7. atlas_init/tf_ext/args.py +16 -1
  8. atlas_init/tf_ext/gen_examples.py +141 -0
  9. atlas_init/tf_ext/gen_module_readme.py +131 -0
  10. atlas_init/tf_ext/gen_resource_main.py +195 -0
  11. atlas_init/tf_ext/gen_resource_output.py +71 -0
  12. atlas_init/tf_ext/gen_resource_variables.py +162 -0
  13. atlas_init/tf_ext/gen_versions.py +10 -0
  14. atlas_init/tf_ext/models_module.py +455 -0
  15. atlas_init/tf_ext/newres.py +90 -0
  16. atlas_init/tf_ext/plan_diffs.py +140 -0
  17. atlas_init/tf_ext/provider_schema.py +199 -0
  18. atlas_init/tf_ext/py_gen.py +294 -0
  19. atlas_init/tf_ext/schema_to_dataclass.py +522 -0
  20. atlas_init/tf_ext/settings.py +151 -2
  21. atlas_init/tf_ext/tf_dep.py +5 -5
  22. atlas_init/tf_ext/tf_desc_gen.py +53 -0
  23. atlas_init/tf_ext/tf_desc_update.py +0 -0
  24. atlas_init/tf_ext/tf_mod_gen.py +263 -0
  25. atlas_init/tf_ext/tf_mod_gen_provider.py +124 -0
  26. atlas_init/tf_ext/tf_modules.py +5 -4
  27. atlas_init/tf_ext/tf_vars.py +13 -28
  28. atlas_init/tf_ext/typer_app.py +6 -2
  29. {atlas_init-0.7.0.dist-info → atlas_init-0.8.1.dist-info}/METADATA +4 -3
  30. {atlas_init-0.7.0.dist-info → atlas_init-0.8.1.dist-info}/RECORD +33 -17
  31. {atlas_init-0.7.0.dist-info → atlas_init-0.8.1.dist-info}/WHEEL +0 -0
  32. {atlas_init-0.7.0.dist-info → atlas_init-0.8.1.dist-info}/entry_points.txt +0 -0
  33. {atlas_init-0.7.0.dist-info → atlas_init-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,199 @@
1
+ from __future__ import annotations
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from tempfile import TemporaryDirectory
6
+
7
+ from ask_shell import run_and_wait
8
+ from model_lib import Entity, dump, parse_dict
9
+ from pydantic import BaseModel
10
+ from zero_3rdparty.file_utils import ensure_parents_write_text
11
+
12
+ from atlas_init.tf_ext.args import TF_CLI_CONFIG_FILE_ENV_NAME
13
+ from atlas_init.tf_ext.constants import ATLAS_PROVIDER_NAME
14
+ from atlas_init.tf_ext.models_module import ProviderGenConfig
15
+ from atlas_init.tf_ext.settings import TfExtSettings
16
+
17
+
18
logger = logging.getLogger(name=__name__)  # module-scoped logger named after this module
19
+
20
+
21
def parse_provider_resource_schema(schema: dict, provider_name: str) -> dict:
    """Return the `resource_schemas` mapping of the provider whose address ends with `provider_name`.

    Args:
        schema: a parsed provider-schema document (in this module, the JSON printed by
            `terraform providers schema -json`); providers are read from its
            "provider_schemas" mapping, keyed by provider source address.
        provider_name: suffix used to select the provider entry; the first matching key wins.

    Returns:
        The provider's "resource_schemas" dict, or an empty dict when the entry has none.

    Raises:
        ValueError: if no provider address ends with `provider_name`.
    """
    not_found = object()
    provider_schema = next(
        (
            details
            for address, details in schema.get("provider_schemas", {}).items()
            if address.endswith(provider_name)
        ),
        not_found,
    )
    if provider_schema is not_found:
        raise ValueError(f"Provider '{provider_name}' not found in schema.")
    return provider_schema.get("resource_schemas", {})
27
+
28
+
29
# providers.tf content returned by get_providers_tf(skip_python=False, minimal=False):
# pins the mongodbatlas provider and additionally requires hashicorp/external.
_providers_tf_with_external = """
terraform {
  required_providers {
    mongodbatlas = {
      source = "mongodb/mongodbatlas"
      version = "~> 1.26"
    }
    external = {
      source = "hashicorp/external"
      version = "~>2.0"
    }
  }
  required_version = ">= 1.8"
}
"""

# providers.tf pinning only the mongodbatlas provider. It is the default returned by
# get_providers_tf, the default of AtlasSchemaInfo.providers_tf, and the file written by
# parse_atlas_schema before running `terraform providers schema -json`.
_providers_tf = """
terraform {
  required_providers {
    mongodbatlas = {
      source = "mongodb/mongodbatlas"
      version = "~> 1.26"
    }
  }
  required_version = ">= 1.8"
}
"""
# terraform block with only the CLI version constraint and no provider requirements
# (returned by get_providers_tf(minimal=True)).
_providers_tf_minimal = """
terraform {
  required_version = ">= 1.8"
}
"""
61
+
62
+
63
def get_providers_tf(skip_python: bool = True, minimal: bool = False) -> str:
    """Pick the providers.tf template to write next to generated Terraform code.

    Args:
        skip_python: when False, use the template that also requires hashicorp/external.
        minimal: when True, return the version-only template; this takes precedence over
            `skip_python`.
    """
    if minimal:
        return _providers_tf_minimal
    if not skip_python:
        return _providers_tf_with_external
    return _providers_tf
67
+
68
+
69
class AtlasSchemaInfo(Entity):
    """Summary of the Atlas provider's resource schemas, as built by `_parse_dict_schema`."""

    # Sorted names of every resource type in the provider's "resource_schemas".
    resource_types: list[str]
    # Sorted subset of `resource_types` whose schema "block" has "deprecated" set truthy.
    deprecated_resource_types: list[str]
    # Resource type name -> raw (unparsed) schema dict for that resource.
    raw_resource_schema: dict[str, dict]
    # providers.tf content to pair with this schema; defaults to the Atlas-only template.
    providers_tf: str = _providers_tf
74
+
75
+
76
class SchemaAttribute(BaseModel):
    """One attribute of a schema block, stored under its name in `SchemaBlock.attributes`.

    Every field is optional and defaults to None. NOTE(review): the first group of fields
    appears to mirror attribute entries of `terraform providers schema -json`. The later
    ones (`default`, `enum`, `allowed_values`, `force_new`, `conflicts_with`, ...) are
    presumably populated by other tooling or left unset; confirm the source of that data.
    """

    # Attribute type expression; presumably a primitive name or a nested list/dict form.
    type: str | list | dict | None = None
    description: str | None = None
    description_kind: str | None = None
    optional: bool | None = None
    required: bool | None = None
    computed: bool | None = None
    deprecated: bool | None = None
    sensitive: bool | None = None
    # Nested object schema; a forward reference resolved by the model_rebuild() calls below.
    nested_type: SchemaBlock | None = None
    # Extra constraint/metadata fields (see the class docstring note).
    default: object | None = None
    enum: list[object] | None = None
    allowed_values: list[object] | None = None
    force_new: bool | None = None
    conflicts_with: list[str] | None = None
    exactly_one_of: list[str] | None = None
    at_least_one_of: list[str] | None = None
    required_with: list[str] | None = None
    deprecated_message: str | None = None
    validators: list[dict] | None = None
    element_type: str | dict | None = None
97
+
98
+
99
class SchemaBlockType(BaseModel):
    """A nested block type: the nested `block` plus how it nests inside its parent block.

    Stored under its name in `SchemaBlock.block_types`.
    """

    block: SchemaBlock
    # Nesting mode of the block type (kept as a free-form string).
    nesting_mode: str
    min_items: int | None = None
    max_items: int | None = None
    required: bool | None = None
    optional: bool | None = None
    description_kind: str | None = None
    deprecated: bool | None = None
    description: str | None = None
    default: object | None = None
    validators: list[dict] | None = None

    @property
    def block_with_nesting_mode(self) -> SchemaBlock:
        """Return a copy of `block` whose `nesting_mode` is set to this block type's mode.

        `self.block` itself is not modified (pydantic `model_copy` with `update`).
        """
        return self.block.model_copy(update={"nesting_mode": self.nesting_mode})
115
+
116
+
117
class SchemaBlock(BaseModel):
    """A schema block: its attributes and nested block types, each keyed by name.

    Used as `ResourceSchema.block` and, recursively, as `SchemaBlockType.block` and
    `SchemaAttribute.nested_type`.
    """

    # attribute name -> attribute schema (None when the block declares no attributes)
    attributes: dict[str, SchemaAttribute] | None = None
    # nested block name -> block type (None when the block has no nested blocks)
    block_types: dict[str, SchemaBlockType] | None = None
    description_kind: str | None = None
    description: str | None = None
    deprecated: bool | None = None
    # Set on the copy returned by SchemaBlockType.block_with_nesting_mode; presumably None
    # for blocks parsed directly — confirm.
    nesting_mode: str | None = None
124
+
125
+
126
class ResourceSchema(BaseModel):
    """Schema of a single resource type: its root block plus schema metadata."""

    block: SchemaBlock
    version: int | None = None
    description_kind: str | None = None

    def required_attributes(self) -> dict[str, SchemaAttribute]:
        """Return the root block's attributes whose `required` flag is truthy, keyed by name."""
        required: dict[str, SchemaAttribute] = {}
        for attr_name, attribute in (self.block.attributes or {}).items():
            if attribute.required:
                required[attr_name] = attribute
        return required
133
+
134
+
135
+ SchemaAttribute.model_rebuild()
136
+ SchemaBlockType.model_rebuild()
137
+ SchemaBlock.model_rebuild()
138
+
139
+
140
def parse_atlas_schema_from_settings(settings: TfExtSettings, provider_config: ProviderGenConfig) -> AtlasSchemaInfo:
    """Return the Atlas provider schema for the provider repo's current git commit.

    Schemas are cached as `<sha>.json` in `settings.provider_cache_dir(<provider name>)`.

    * If `provider_config.last_gen_sha` equals the repo's `HEAD` sha, the cached schema is
      read. `read_cached_atlas_schema` falls back to running terraform on a cache miss.
    * Otherwise the schema is regenerated via terraform and stored in the cache. Then
      `provider_config.last_gen_sha` is updated and the config is written back as YAML.

    Raises:
        AssertionError: if `settings.repo_path_atlas_provider` is not configured.
    """
    repo_path = settings.repo_path_atlas_provider
    assert repo_path, "repo_path_atlas_provider is required"
    current_sha = run_and_wait("git rev-parse HEAD", cwd=repo_path).stdout_one_line
    cache_dir = settings.provider_cache_dir(provider_config.provider_name)
    if provider_config.last_gen_sha == current_sha:
        return read_cached_atlas_schema(cache_dir, current_sha, settings.tf_cli_config_file)
    # Store the fresh schema where read_cached_atlas_schema looks it up, so the next call for this
    # sha is a cache hit instead of another terraform run. Also honor the configured CLI config
    # file, just like the cached path does, instead of relying solely on the environment variable.
    schema = parse_atlas_schema(
        store_path=cache_dir / f"{current_sha}.json",
        tf_cli_config_file=settings.tf_cli_config_file,
    )
    provider_config.last_gen_sha = current_sha
    provider_yaml = dump(provider_config.config_dump(), "yaml")
    settings.repo_out.provider_settings_path(provider_config.provider_name).write_text(provider_yaml)
    return schema
152
+
153
+
154
def read_cached_atlas_schema(cache_dir: Path, sha: str, tf_cli_config_file: Path | None = None) -> AtlasSchemaInfo:
    """Load the schema cached as `<cache_dir>/<sha>.json`, generating and caching it when absent.

    Args:
        cache_dir: directory holding one JSON schema file per provider-repo sha.
        sha: commit sha identifying the cache entry.
        tf_cli_config_file: forwarded to `parse_atlas_schema` on a cache miss.
    """
    cached_json = cache_dir / f"{sha}.json"
    if cached_json.exists():
        return _parse_dict_schema(parse_dict(cached_json))
    logger.info(f"Cache miss for sha = {sha}, parsing atlas schema")
    return parse_atlas_schema(store_path=cached_json, tf_cli_config_file=tf_cli_config_file)
161
+
162
+
163
def parse_atlas_schema(store_path: Path | None = None, tf_cli_config_file: Path | None = None) -> AtlasSchemaInfo:
    """Run terraform in a scratch directory and parse the Atlas provider's schema.

    Steps:
      1. Write the Atlas-only providers.tf into a temporary directory.
      2. Run `terraform init`.
      3. Run `terraform providers schema -json`, with `TF_CLI_CONFIG_FILE` and
         `MONGODB_ATLAS_PREVIEW_PROVIDER_V2_ADVANCED_CLUSTER=true` passed via `env`.

    Args:
        store_path: if given, the raw JSON output is also written there (parents created).
        tf_cli_config_file: terraform CLI config file; defaults to the
            `TF_CLI_CONFIG_FILE_ENV_NAME` environment variable.

    Raises:
        AssertionError: if no CLI config file is given or set in the environment.
    """
    if tf_cli_config_file:
        tf_cli_config_file_str = str(tf_cli_config_file)
    else:
        tf_cli_config_file_str = os.environ.get(TF_CLI_CONFIG_FILE_ENV_NAME)
    assert tf_cli_config_file_str, f"{TF_CLI_CONFIG_FILE_ENV_NAME} is required"
    schema_env = {
        TF_CLI_CONFIG_FILE_ENV_NAME: tf_cli_config_file_str,
        "MONGODB_ATLAS_PREVIEW_PROVIDER_V2_ADVANCED_CLUSTER": "true",
    }
    with TemporaryDirectory() as work_dir:
        (Path(work_dir) / "providers.tf").write_text(_providers_tf)
        run_and_wait("terraform init", cwd=work_dir)
        schema_run = run_and_wait(
            "terraform providers schema -json",
            cwd=work_dir,
            ansi_content=False,
            env=schema_env,
        )
        parsed_dict = schema_run.parse_output(dict, output_format="json")
        if store_path:
            ensure_parents_write_text(store_path, schema_run.stdout_one_line)
        return _parse_dict_schema(parsed_dict)
186
+
187
+
188
def _parse_dict_schema(parsed: dict) -> AtlasSchemaInfo:
    """Build an `AtlasSchemaInfo` from a parsed `terraform providers schema -json` document.

    A resource counts as deprecated when its "block" has a truthy "deprecated" value. A
    resource entry without a "block" key raises KeyError.
    """
    resource_schema = parse_provider_resource_schema(parsed, ATLAS_PROVIDER_NAME)
    deprecated_names = sorted(
        name for name, details in resource_schema.items() if details["block"].get("deprecated", False)
    )
    return AtlasSchemaInfo(
        resource_types=sorted(resource_schema),
        deprecated_resource_types=deprecated_names,
        raw_resource_schema=resource_schema,
    )
@@ -0,0 +1,294 @@
1
+ import importlib.util
2
+ import inspect
3
+ import logging
4
+ import re
5
+ import sys
6
+ from dataclasses import Field, fields, is_dataclass
7
+ from pathlib import Path
8
+ from tempfile import TemporaryDirectory
9
+ from types import ModuleType
10
+ from typing import Any, Dict, Generic, Iterable, List, Literal, NamedTuple, Set, TypeVar, Union, get_args, get_origin
11
+
12
+ from zero_3rdparty import humps
13
+ from zero_3rdparty.file_utils import copy
14
+
15
logger = logging.getLogger(name=__name__)  # module-scoped logger named after this module
16
+
17
+
18
def import_module_by_using_parents(file_path: Path) -> ModuleType:
    """Import `file_path` as a submodule of a throw-away package built from its parent directory.

    The whole parent directory of `file_path` is copied into a temporary `tmp_module`
    package (an empty `__init__.py` is added when missing), so imports between sibling
    files resolve. The file is then imported as `tmp_module.<file stem>`.

    Args:
        file_path: Python source file to import.

    Returns:
        The imported module object.

    Raises:
        ImportError: if the import does not yield a module object.
    """
    package_name = "tmp_module"
    with TemporaryDirectory() as tmp_dir:
        module_path = Path(tmp_dir) / package_name
        copy(file_path.parent, module_path)
        init_py = module_path / "__init__.py"
        if not init_py.exists():
            init_py.write_text("")
        logger.info("files in tmp_module: " + ", ".join(py.name for py in module_path.glob("*.py")))
        # A previous call leaves `tmp_module` and its submodules cached in sys.modules, bound to a
        # temporary directory that no longer exists. Drop them so this call imports the fresh copy
        # instead of returning a stale module or failing to locate the new file.
        stale = [name for name in sys.modules if name == package_name or name.startswith(f"{package_name}.")]
        for name in stale:
            del sys.modules[name]
        importlib.invalidate_caches()
        sys.path.insert(0, tmp_dir)
        try:
            module = importlib.import_module(f"{package_name}.{file_path.stem}")
        finally:
            # Restore sys.path so repeated calls don't accumulate entries for deleted directories.
            if tmp_dir in sys.path:
                sys.path.remove(tmp_dir)
        if inspect.ismodule(module):
            return module
        raise ImportError(f"Could not import module {file_path.stem} from {file_path}")
56
+
57
+
58
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
    """Execute the source file at `file_path` and return it as a module named `module_name`.

    The module is built from a file-location spec and executed. It is not added to
    `sys.modules`.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    assert module_spec
    loader = module_spec.loader
    assert loader
    new_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(new_module)
    return new_module
65
+
66
+
67
# Scalar annotation types that need no dataclass conversion. _unwrap_type raises
# PrimitiveTypeError when a field (or a list/set/dict element) has one of these types.
primitive_types: tuple[type, ...] = (str, float, bool, int)
68
+
69
+
70
def as_set(values: list[str]) -> str:
    """Render `values` as Python set-literal source code, using `set()` when empty."""
    if not values:
        return "set()"
    inner = ", ".join(map(repr, values))
    return "{" + inner + "}"
72
+
73
+
74
def make_post_init_line_optional(field_name: str, elem_type: str, is_map: bool = False, is_list: bool = False) -> str:
    """Generate `__post_init__` body source converting an optional field's values to `elem_type`.

    The generated code leaves None untouched. Dict inputs become `elem_type(**value)`;
    instances that are already `elem_type` are kept. Lines are indented for a method body
    (8 spaces) and the text has no trailing newline.

    Args:
        field_name: dataclass attribute to convert.
        elem_type: class name used in the generated `isinstance` check and constructor call.
        is_map: convert each value of a dict field (checked before `is_list`).
        is_list: convert each item of a list field.
    """
    attr = f"self.{field_name}"
    if is_map:
        return (
            f"        if {attr} is not None:\n"
            f"            {attr} = {{k:v if isinstance(v, {elem_type}) else {elem_type}(**v) for k, v in {attr}.items()}}"
        )
    if is_list:
        return (
            f"        if {attr} is not None:\n"
            f"            {attr} = [x if isinstance(x, {elem_type}) else {elem_type}(**x) for x in {attr}]"
        )
    return (
        f"        if {attr} is not None and not isinstance({attr}, {elem_type}):\n"
        f'            assert isinstance({attr}, dict), f"Expected {field_name} to be a {elem_type} or a dict, got {{type({attr})}}"\n'
        f"            {attr} = {elem_type}(**{attr})"
    )
91
+
92
+
93
def make_post_init_line(field_name: str, elem_type: str, is_map: bool = False, is_list: bool = False) -> str:
    """Generate `__post_init__` body source converting a required field's values to `elem_type`.

    Map and list fields are first asserted to be a dict / list. Their dict entries are then
    converted with `elem_type(**value)`. A plain field is converted only when it is not
    already an `elem_type`, and must then be a dict. Lines are indented for a method body
    (8 spaces) and the text has no trailing newline.

    Args:
        field_name: dataclass attribute to convert.
        elem_type: class name used in the generated `isinstance` check and constructor call.
        is_map: convert each value of a dict field (checked before `is_list`).
        is_list: convert each item of a list field.
    """
    attr = f"self.{field_name}"
    if is_map:
        generated = [
            f'        assert isinstance({attr}, dict), f"Expected {field_name} to be a dict, got {{type({attr})}}"',
            f"        {attr} = {{k:v if isinstance(v, {elem_type}) else {elem_type}(**v) for k, v in {attr}.items()}}",
        ]
    elif is_list:
        generated = [
            f'        assert isinstance({attr}, list), f"Expected {field_name} to be a list, got {{type({attr})}}"',
            f"        {attr} = [x if isinstance(x, {elem_type}) else {elem_type}(**x) for x in {attr}]",
        ]
    else:
        generated = [
            f"        if not isinstance({attr}, {elem_type}):",
            f'            assert isinstance({attr}, dict), f"Expected {field_name} to be a {elem_type} or a dict, got {{type({attr})}}"',
            f"            {attr} = {elem_type}(**{attr})",
        ]
    return "\n".join(generated)
110
+
111
+
112
class PrimitiveTypeError(Exception):
    """Signals that an annotation resolved to a primitive type, so no conversion code is needed.

    Raised by `_unwrap_type`; `make_post_init_line_from_field` catches it.
    """

    def __init__(self, type_: type):
        super().__init__(type_)
        self.type_ = type_
115
+
116
+
117
def make_post_init_line_from_field(field: Field) -> str:
    """Return `__post_init__` source that converts `field`'s value(s) into its dataclass type.

    An empty string means no conversion is generated. That is the case for:
      * primitive fields and containers of primitives;
      * `Any`-typed elements;
      * set fields.

    Args:
        field: a dataclass field, as returned by `dataclasses.fields`.
    """
    try:
        container_type = unwrap_type(field)
    except PrimitiveTypeError:
        return ""
    if container_type.is_any:
        # `isinstance(x, Any)` raises TypeError at runtime and `Any(**x)` is meaningless, so values
        # annotated as `Any` (e.g. `dict[str, Any]`) are left untouched.
        return ""
    if container_type.is_set:
        # Set elements must be hashable, so they can never be raw dicts needing conversion. Also,
        # the generators only know dict/list/single-object shapes, and the single-object branch
        # would always fail its `isinstance(..., dict)` assertion for a set value.
        return ""
    make_func = make_post_init_line_optional if container_type.is_optional else make_post_init_line
    return make_func(
        field.name, container_type.type.__name__, is_map=container_type.is_dict, is_list=container_type.is_list
    )
126
+
127
+
128
+ T = TypeVar("T")
129
+
130
+
131
+ class ContainerType(NamedTuple, Generic[T]):
132
+ type: type[T]
133
+ container_type: Literal[
134
+ "list", "set", "dict", "optional", "optional_list", "optional_set", "optional_dict", "cls_direct"
135
+ ]
136
+
137
+ @property
138
+ def is_cls_direct(self) -> bool:
139
+ return self.container_type == "cls_direct"
140
+
141
+ @property
142
+ def is_list(self) -> bool:
143
+ return self.container_type in {"list", "optional_list"}
144
+
145
+ @property
146
+ def is_set(self) -> bool:
147
+ return self.container_type in {"set", "optional_set"}
148
+
149
+ @property
150
+ def is_dict(self) -> bool:
151
+ return self.container_type in {"dict", "optional_dict"}
152
+
153
+ @property
154
+ def is_optional(self) -> bool:
155
+ return self.container_type in {"optional", "optional_list", "optional_set", "optional_dict"}
156
+
157
+ @property
158
+ def is_any(self) -> bool:
159
+ return self.type is Any
160
+
161
+
162
def unwrap_type(field: Field) -> ContainerType:
    """Classify a dataclass field's annotation into its element type and container kind.

    Raises:
        PrimitiveTypeError: when the annotation (or its element type) is a primitive.
    """
    annotation = field.type
    return _unwrap_type(annotation, get_origin(annotation), get_args(annotation))  # type: ignore
167
+
168
+
169
def _unwrap_type(field_type: type, origin: type, args: list[type]) -> ContainerType:
    """Classify an annotation into its element type and container kind.

    Supported shapes:
      * `Optional[X]` / `Union[X, None]` and PEP 604 `X | None`: the inner result is
        re-tagged as "optional", "optional_list", "optional_set" or "optional_dict";
      * `list[X]` / `List[X]`, `set[X]` / `Set[X]`, `dict[K, V]` / `Dict[K, V]`
        (the element type of a dict is V);
      * a bare class ("cls_direct").

    Raises:
        PrimitiveTypeError: if the annotation, or the element of a list/set/dict, is a primitive.
        AssertionError: for optionals with more than one non-None member, or string annotations.
        ValueError: if an optional wraps an unsupported kind.
    """
    import types  # only `ModuleType` is imported from `types` at module level

    # PEP 604 unions (`X | None`) report `types.UnionType` as their origin rather than
    # `typing.Union` (before Python 3.14), so accept both. Where `types.UnionType` does not
    # exist, only `Union` is checked.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if origin in union_origins and type(None) in args:
        non_none_args = [arg for arg in args if arg is not type(None)]
        assert len(non_none_args) == 1, f"Expected one non-None type in Union, got {non_none_args}"
        inner_type = non_none_args[0]
        response = _unwrap_type(inner_type, get_origin(inner_type), get_args(inner_type))  # type: ignore
        if response.is_cls_direct:
            return ContainerType(response.type, "optional")
        if response.is_list:
            return ContainerType(response.type, "optional_list")
        if response.is_set:
            return ContainerType(response.type, "optional_set")
        if response.is_dict:
            return ContainerType(response.type, "optional_dict")
        raise ValueError(f"Unsupported optional type: {inner_type}")
    if origin in (list, List) and args:
        item_type = args[0]
        if item_type in primitive_types:
            raise PrimitiveTypeError(item_type)
        return ContainerType(item_type, "list")
    if origin in (set, Set) and args:
        item_type = args[0]
        if item_type in primitive_types:
            raise PrimitiveTypeError(item_type)
        return ContainerType(item_type, "set")
    if origin in (dict, Dict) and args:
        _, value_type = args
        if value_type in primitive_types:
            raise PrimitiveTypeError(value_type)
        return ContainerType(value_type, "dict")
    if field_type in primitive_types:
        raise PrimitiveTypeError(field_type)
    assert not isinstance(field_type, str), f"Expected type, got {field_type!r}"
    return ContainerType(field_type, "cls_direct")
203
+
204
+
205
def longest_common_substring_among_all(strings: list[str]) -> str:
    """Return the PascalCase form of the longest substring shared by every string in `strings`.

    Matching is case-insensitive, because all strings are lower-cased first. Leading and
    trailing underscores are stripped from the match before it is pascalized. When several
    common substrings share the maximal length, the one occurring earliest in the first
    string wins. With no common substring the result is `humps.pascalize("")`.

    Args:
        strings: non-empty list of strings (e.g. class names).

    Raises:
        ValueError: if `strings` is empty.
    """
    if not strings:
        raise ValueError("longest_common_substring_among_all requires at least one string")
    first, *others = [s.lower() for s in strings]
    # The previous approach folded a pairwise longest-common-substring over the list. That is
    # only a heuristic: the longest substring shared by the first two strings may be missing
    # from the third, while a shorter substring is shared by all. Instead, search the first
    # string's substrings longest-first and return the first one contained in every string.
    max_len = min(len(s) for s in (first, *others))
    for length in range(max_len, 0, -1):
        for start in range(len(first) - length + 1):
            candidate = first[start : start + length]
            if all(candidate in other for other in others):
                return humps.pascalize(candidate.strip("_"))
    return humps.pascalize("")
225
+
226
+
227
# Exact text of the script entry-point guard that move_main_call_to_end relocates.
_main_call = """
if __name__ == "__main__":
    main()
"""


def move_main_call_to_end(file_path: Path) -> None:
    """Strip every exact occurrence of `_main_call` from `file_path` and append one copy at the end.

    Guards formatted differently from `_main_call` are not recognized and are left in place.
    """
    without_guard = file_path.read_text().replace(_main_call, "")
    file_path.write_text(f"{without_guard}{_main_call}")
237
+
238
+
239
class DataclassMatch(NamedTuple):
    """Location of one `@dataclass` class definition found in source code by `dataclass_matches`."""

    # Name of the matched class.
    cls_name: str
    # Source text around the match start, for diagnostics.
    match_context: str
    # Offset of the "@dataclass" decorator in the source.
    index_start: int
    # Offset just past the separator of two blank lines that ends the definition.
    index_end: int
244
+
245
+
246
def dataclass_matches(code: str, cls_name: str) -> Iterable[DataclassMatch]:
    """Yield every `@dataclass` definition of `cls_name` found in `code`.

    A definition ends just after the first run of two consecutive blank lines following its
    decorator. That trailing separator is included in the span. When no such separator
    follows (the class is last in the file), the definition extends to the end of `code`.
    """
    for match in dataclass_pattern(cls_name).finditer(code):
        start = match.start()
        separator = code.find("\n\n\n", start)
        end = len(code) if separator == -1 else separator + 3
        # Clamp the context start at 0: a negative slice start would wrap around to the end of
        # `code` and produce an empty or unrelated context for matches near the top of the file.
        context = code[max(start - 20, 0) : start + 20]
        yield DataclassMatch(cls_name, context, start, end)
252
+
253
+
254
def dataclass_pattern(cls_name: str) -> re.Pattern:
    """Compile a regex matching `@dataclass` directly followed by the `class <cls_name>` header.

    Only a bare `@dataclass` decorator and at most one simple base class (captured as the
    "base" group, parentheses included) are recognized.
    """
    header = rf"@dataclass\nclass {cls_name}"
    optional_base = r"(?P<base>\(\w+\))?"
    return re.compile(header + optional_base + ":")
256
+
257
+
258
def dataclass_indexes(code: str, cls_name: str) -> tuple[int, int]:
    """Return the `(start, end)` offsets of the single `@dataclass` definition of `cls_name` in `code`.

    Raises:
        AssertionError: unless exactly one definition matches.
    """
    found = [*dataclass_matches(code, cls_name)]
    assert len(found) == 1, f"expected exactly one dataclass match for {cls_name}, got {len(found)}"
    only = found[0]
    return only.index_start, only.index_end
262
+
263
+
264
def make_post_init(lines: list[str]) -> str:
    """Wrap already-indented body `lines` in a `def __post_init__(self):` method at class-body indentation.

    The body lines are joined with newlines; no trailing newline is added.
    """
    header = "    def __post_init__(self):\n"
    body = "\n".join(lines)
    return header + body
266
+
267
+
268
def module_dataclasses(module: ModuleType) -> dict[str, type]:
    """Return the dataclass classes (not instances) in `module`'s namespace, keyed by attribute name.

    Every name bound in the namespace is considered, including dataclasses imported from
    other modules.
    """
    found: dict[str, type] = {}
    for attr_name, value in vars(module).items():
        if is_dataclass(value) and inspect.isclass(value):
            found[attr_name] = value
    return found
274
+
275
+
276
def ensure_dataclass_use_conversion(dataclasses: dict[str, type], file_path: Path, skip_filter: set[str]) -> None:
    """Insert generated `__post_init__` methods so nested dataclass fields are converted on init.

    Classes named in `skip_filter` are skipped. So are classes whose fields produce no
    conversion code (see `make_post_init_line_from_field`) and classes whose source already
    contains `def __post_init__`. For every other class, a `__post_init__` is inserted at the
    first blank line of its definition in `file_path`. The file is always rewritten at the end.
    """
    source = file_path.read_text()
    for cls_name, dc_cls in dataclasses.items():
        if cls_name in skip_filter:
            continue
        body_lines = []
        for dc_field in fields(dc_cls):
            line = make_post_init_line_from_field(dc_field)
            if line:
                body_lines.append(line)
        if not body_lines:
            continue
        start, end = dataclass_indexes(source, cls_name)
        cls_source = source[start:end]
        if "def __post_init__" in cls_source:
            continue  # keep an existing hand-written __post_init__
        split_at = cls_source.find("\n\n")
        assert split_at > 0
        patched = cls_source[:split_at] + "\n\n" + make_post_init(body_lines) + cls_source[split_at:]
        source = source.replace(cls_source, patched)
    file_path.write_text(source)