atlas-init 0.1.1__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_init/__init__.py +3 -3
- atlas_init/atlas_init.yaml +18 -1
- atlas_init/cli.py +62 -70
- atlas_init/cli_cfn/app.py +40 -117
- atlas_init/cli_cfn/{cfn.py → aws.py} +129 -14
- atlas_init/cli_cfn/cfn_parameter_finder.py +89 -6
- atlas_init/cli_cfn/example.py +203 -0
- atlas_init/cli_cfn/files.py +63 -0
- atlas_init/cli_helper/run.py +18 -2
- atlas_init/cli_helper/tf_runner.py +4 -6
- atlas_init/cli_root/__init__.py +0 -0
- atlas_init/cli_root/trigger.py +153 -0
- atlas_init/cli_tf/app.py +211 -4
- atlas_init/cli_tf/changelog.py +103 -0
- atlas_init/cli_tf/debug_logs.py +221 -0
- atlas_init/cli_tf/debug_logs_test_data.py +253 -0
- atlas_init/cli_tf/github_logs.py +229 -0
- atlas_init/cli_tf/go_test_run.py +194 -0
- atlas_init/cli_tf/go_test_run_format.py +31 -0
- atlas_init/cli_tf/go_test_summary.py +144 -0
- atlas_init/cli_tf/hcl/__init__.py +0 -0
- atlas_init/cli_tf/hcl/cli.py +161 -0
- atlas_init/cli_tf/hcl/cluster_mig.py +348 -0
- atlas_init/cli_tf/hcl/parser.py +140 -0
- atlas_init/cli_tf/schema.py +222 -18
- atlas_init/cli_tf/schema_go_parser.py +236 -0
- atlas_init/cli_tf/schema_table.py +150 -0
- atlas_init/cli_tf/schema_table_models.py +155 -0
- atlas_init/cli_tf/schema_v2.py +599 -0
- atlas_init/cli_tf/schema_v2_api_parsing.py +298 -0
- atlas_init/cli_tf/schema_v2_sdk.py +361 -0
- atlas_init/cli_tf/schema_v3.py +222 -0
- atlas_init/cli_tf/schema_v3_sdk.py +279 -0
- atlas_init/cli_tf/schema_v3_sdk_base.py +68 -0
- atlas_init/cli_tf/schema_v3_sdk_create.py +216 -0
- atlas_init/humps.py +253 -0
- atlas_init/repos/cfn.py +6 -1
- atlas_init/repos/path.py +3 -3
- atlas_init/settings/config.py +14 -4
- atlas_init/settings/env_vars.py +16 -1
- atlas_init/settings/path.py +12 -1
- atlas_init/settings/rich_utils.py +2 -0
- atlas_init/terraform.yaml +77 -1
- atlas_init/tf/.terraform.lock.hcl +59 -83
- atlas_init/tf/always.tf +7 -0
- atlas_init/tf/main.tf +3 -0
- atlas_init/tf/modules/aws_s3/provider.tf +1 -1
- atlas_init/tf/modules/aws_vars/aws_vars.tf +2 -0
- atlas_init/tf/modules/aws_vpc/provider.tf +4 -1
- atlas_init/tf/modules/cfn/cfn.tf +47 -33
- atlas_init/tf/modules/cfn/kms.tf +54 -0
- atlas_init/tf/modules/cfn/resource_actions.yaml +1 -0
- atlas_init/tf/modules/cfn/variables.tf +31 -0
- atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -0
- atlas_init/tf/modules/cloud_provider/provider.tf +1 -1
- atlas_init/tf/modules/cluster/cluster.tf +34 -24
- atlas_init/tf/modules/cluster/provider.tf +1 -1
- atlas_init/tf/modules/federated_vars/federated_vars.tf +3 -0
- atlas_init/tf/modules/federated_vars/provider.tf +1 -1
- atlas_init/tf/modules/project_extra/project_extra.tf +15 -1
- atlas_init/tf/modules/stream_instance/stream_instance.tf +1 -1
- atlas_init/tf/modules/vpc_peering/vpc_peering.tf +1 -1
- atlas_init/tf/modules/vpc_privatelink/versions.tf +1 -1
- atlas_init/tf/outputs.tf +11 -3
- atlas_init/tf/providers.tf +2 -1
- atlas_init/tf/variables.tf +12 -0
- atlas_init/typer_app.py +76 -0
- {atlas_init-0.1.1.dist-info → atlas_init-0.1.8.dist-info}/METADATA +36 -18
- atlas_init-0.1.8.dist-info/RECORD +91 -0
- {atlas_init-0.1.1.dist-info → atlas_init-0.1.8.dist-info}/WHEEL +1 -1
- atlas_init-0.1.1.dist-info/RECORD +0 -62
- /atlas_init/tf/modules/aws_vpc/{aws-vpc.tf → aws_vpc.tf} +0 -0
- {atlas_init-0.1.1.dist-info → atlas_init-0.1.8.dist-info}/entry_points.txt +0 -0
atlas_init/cli_tf/schema.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1
1
|
import logging
|
2
|
+
from collections.abc import Iterable
|
3
|
+
from functools import singledispatch
|
2
4
|
from pathlib import Path
|
3
|
-
from typing import Literal
|
5
|
+
from typing import Annotated, Literal, NamedTuple
|
4
6
|
|
5
7
|
import pydantic
|
6
8
|
import requests
|
7
9
|
from model_lib import Entity, dump, field_names, parse_model
|
10
|
+
from zero_3rdparty import dict_nested
|
11
|
+
from zero_3rdparty.enum_utils import StrEnum
|
8
12
|
|
9
13
|
logger = logging.getLogger(__name__)
|
10
14
|
|
@@ -23,10 +27,53 @@ class ProviderSpecAttribute(Entity):
|
|
23
27
|
return self.model_dump(exclude_none=True)
|
24
28
|
|
25
29
|
|
30
|
+
class IgnoreNested(Entity):
|
31
|
+
type: Literal["ignore_nested"] = "ignore_nested"
|
32
|
+
path: str
|
33
|
+
|
34
|
+
@property
|
35
|
+
def use_wildcard(self) -> bool:
|
36
|
+
return "*" in self.path
|
37
|
+
|
38
|
+
|
39
|
+
class RenameAttribute(Entity):
|
40
|
+
type: Literal["rename_attribute"] = "rename_attribute"
|
41
|
+
from_name: str
|
42
|
+
to_name: str
|
43
|
+
|
44
|
+
|
45
|
+
class ComputedOptionalRequired(StrEnum):
|
46
|
+
COMPUTED_OPTIONAL = "computed_optional"
|
47
|
+
REQUIRED = "required"
|
48
|
+
COMPUTED = "computed"
|
49
|
+
OPTIONAL = "optional"
|
50
|
+
|
51
|
+
|
52
|
+
class ChangeAttributeType(Entity):
|
53
|
+
type: Literal["change_attribute_type"] = "change_attribute_type"
|
54
|
+
path: str
|
55
|
+
new_value: ComputedOptionalRequired
|
56
|
+
|
57
|
+
@classmethod
|
58
|
+
def read_value(cls, attribute_dict: dict) -> str:
|
59
|
+
return attribute_dict["string"]["computed_optional_required"]
|
60
|
+
|
61
|
+
def update_value(self, attribute_dict: dict) -> None:
|
62
|
+
attribute_dict["string"]["computed_optional_required"] = self.new_value
|
63
|
+
|
64
|
+
|
65
|
+
class SkipValidators(Entity):
|
66
|
+
type: Literal["skip_validators"] = "skip_validators"
|
67
|
+
|
68
|
+
|
69
|
+
Extension = Annotated[IgnoreNested | RenameAttribute | ChangeAttributeType | SkipValidators, pydantic.Field("type")]
|
70
|
+
|
71
|
+
|
26
72
|
class TFResource(Entity):
|
27
73
|
model_config = pydantic.ConfigDict(extra="allow")
|
28
74
|
name: str
|
29
|
-
|
75
|
+
extensions: list[Extension] = pydantic.Field(default_factory=list)
|
76
|
+
provider_spec_attributes: list[ProviderSpecAttribute] = pydantic.Field(default_factory=list)
|
30
77
|
|
31
78
|
def dump_generator_config(self) -> dict:
|
32
79
|
names = field_names(self)
|
@@ -35,6 +82,7 @@ class TFResource(Entity):
|
|
35
82
|
|
36
83
|
class PyTerraformSchema(Entity):
|
37
84
|
resources: list[TFResource]
|
85
|
+
data_sources: list[TFResource] = pydantic.Field(default_factory=list)
|
38
86
|
|
39
87
|
def resource(self, resource: str) -> TFResource:
|
40
88
|
return next(r for r in self.resources if r.name == resource)
|
@@ -48,27 +96,89 @@ def dump_generator_config(schema: PyTerraformSchema) -> str:
|
|
48
96
|
resources = {}
|
49
97
|
for resource in schema.resources:
|
50
98
|
resources[resource.name] = resource.dump_generator_config()
|
99
|
+
data_sources = {ds.name: ds.dump_generator_config() for ds in schema.data_sources}
|
51
100
|
generator_config = {
|
52
101
|
"provider": {"name": "mongodbatlas"},
|
53
102
|
"resources": resources,
|
103
|
+
"data_sources": data_sources,
|
54
104
|
}
|
55
105
|
return dump(generator_config, "yaml")
|
56
106
|
|
57
107
|
|
108
|
+
class AttributeTuple(NamedTuple):
|
109
|
+
name: str
|
110
|
+
path: str
|
111
|
+
attribute_dict: dict
|
112
|
+
|
113
|
+
@property
|
114
|
+
def attribute_path(self) -> str:
|
115
|
+
return f"{self.path}.{self.name}" if self.path else self.name
|
116
|
+
|
117
|
+
|
58
118
|
class ProviderCodeSpec(Entity):
|
59
119
|
model_config = pydantic.ConfigDict(extra="allow")
|
60
120
|
provider: dict
|
61
121
|
resources: list[dict]
|
122
|
+
datasources: list[dict] = pydantic.Field(default_factory=list)
|
62
123
|
version: str
|
63
124
|
|
64
|
-
def
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
125
|
+
def root_dict(self, name: str, is_datasource: bool = False) -> dict: # noqa: FBT002
|
126
|
+
resources = self.datasources if is_datasource else self.resources
|
127
|
+
root_value = next((r for r in resources if r["name"] == name), None)
|
128
|
+
if root_value is None:
|
129
|
+
raise ValueError(f"{self.root_name(name, is_datasource)} not found!")
|
130
|
+
return root_value
|
131
|
+
|
132
|
+
def schema_attributes(self, name: str, is_datasource: bool = False) -> list: # noqa: FBT002
|
133
|
+
root_dict = self.root_dict(name, is_datasource)
|
134
|
+
return root_dict["schema"]["attributes"]
|
135
|
+
|
136
|
+
def _type_name(self, is_datasource: bool):
|
137
|
+
return "datasource" if is_datasource else "resource"
|
69
138
|
|
70
|
-
def
|
71
|
-
return
|
139
|
+
def root_name(self, name: str, is_datasource: bool):
|
140
|
+
return f"{self._type_name(is_datasource)}.{name}"
|
141
|
+
|
142
|
+
def attribute_names(self, name: str, is_datasource: bool = False) -> list[str]: # noqa: FBT002
|
143
|
+
return [a["name"] for a in self.schema_attributes(name, is_datasource=is_datasource)]
|
144
|
+
|
145
|
+
def iter_all_attributes(self, name: str, is_datasource: bool = False) -> Iterable[AttributeTuple]: # noqa: FBT002
|
146
|
+
for attribute in self.schema_attributes(name=name, is_datasource=is_datasource):
|
147
|
+
yield AttributeTuple(attribute["name"], "", attribute)
|
148
|
+
yield from self.iter_nested_attributes(name, is_datasource=is_datasource)
|
149
|
+
|
150
|
+
def iter_nested_attributes(self, name: str, is_datasource: bool = False) -> Iterable[AttributeTuple]: # noqa: FBT002
|
151
|
+
for i, attribute in enumerate(self.schema_attributes(name=name, is_datasource=is_datasource)):
|
152
|
+
for path, attr_dict in dict_nested.iter_nested_key_values(
|
153
|
+
attribute, type_filter=dict, include_list_indexes=True
|
154
|
+
):
|
155
|
+
full_path = f"[{i}].{path}"
|
156
|
+
if name := attr_dict.get("name", ""):
|
157
|
+
yield AttributeTuple(name, full_path, attr_dict)
|
158
|
+
|
159
|
+
def remove_nested_attribute(self, name: str, path: str, is_datasource: bool = False) -> None: # noqa: FBT002
|
160
|
+
root_name = self.root_name(name, is_datasource)
|
161
|
+
logger.info(f"will remove attribute from {root_name} with path: {path}")
|
162
|
+
root_attributes = self.root_dict(name, is_datasource)
|
163
|
+
full_path = f"schema.attributes.{path}"
|
164
|
+
popped = dict_nested.pop_nested(root_attributes, full_path, "")
|
165
|
+
if popped == "":
|
166
|
+
raise ValueError(f"failed to remove attribute from resource {name} with path: {full_path}")
|
167
|
+
assert isinstance(popped, dict), f"expected removed attribute to be a dict, got: {popped}"
|
168
|
+
logger.info(f"removal ok, attribute_name: '{root_name}.{popped.get('name')}'")
|
169
|
+
|
170
|
+
def read_attribute(self, name: str, path: str, *, is_datasource: bool = False) -> dict:
|
171
|
+
if "." not in path:
|
172
|
+
attribute_dict = next((a for a in self.schema_attributes(name, is_datasource) if a["name"] == path), None)
|
173
|
+
else:
|
174
|
+
root_dict = self.root_dict(name, is_datasource)
|
175
|
+
attribute_dict = dict_nested.read_nested_or_none(root_dict, f"schema.attributes.{path}")
|
176
|
+
if attribute_dict is None:
|
177
|
+
raise ValueError(f"attribute {path} not found in {self.root_name(name, is_datasource)}")
|
178
|
+
assert isinstance(
|
179
|
+
attribute_dict, dict
|
180
|
+
), f"expected attribute to be a dict, got: {attribute_dict} @ {path} for resource={name}"
|
181
|
+
return attribute_dict
|
72
182
|
|
73
183
|
|
74
184
|
def update_provider_code_spec(schema: PyTerraformSchema, provider_code_spec_path: Path) -> str:
|
@@ -76,21 +186,115 @@ def update_provider_code_spec(schema: PyTerraformSchema, provider_code_spec_path
|
|
76
186
|
for resource in schema.resources:
|
77
187
|
resource_name = resource.name
|
78
188
|
if extra_spec_attributes := resource.provider_spec_attributes:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
189
|
+
add_explicit_attributes(spec, resource_name, extra_spec_attributes)
|
190
|
+
for extension in resource.extensions:
|
191
|
+
apply_extension(extension, spec, resource_name)
|
192
|
+
for data_source in schema.data_sources:
|
193
|
+
data_source_name = data_source.name
|
194
|
+
if extra_spec_attributes := data_source.provider_spec_attributes:
|
195
|
+
add_explicit_attributes(spec, data_source_name, extra_spec_attributes, is_datasource=True)
|
196
|
+
for extension in data_source.extensions:
|
197
|
+
apply_extension(extension, spec, data_source_name, is_datasource=True)
|
85
198
|
return dump(spec, "json")
|
86
199
|
|
87
200
|
|
201
|
+
def add_explicit_attributes(
|
202
|
+
spec: ProviderCodeSpec, name: str, extra_spec_attributes: list[ProviderSpecAttribute], *, is_datasource=False
|
203
|
+
):
|
204
|
+
resource_attributes = spec.schema_attributes(name, is_datasource=is_datasource)
|
205
|
+
existing_names = spec.attribute_names(name, is_datasource=is_datasource)
|
206
|
+
new_names = [extra.name for extra in extra_spec_attributes]
|
207
|
+
if both := set(existing_names) & set(new_names):
|
208
|
+
raise ValueError(f"resource: {name}, has already: {both} attributes")
|
209
|
+
resource_attributes.extend(extra.dump_provider_code_spec() for extra in extra_spec_attributes)
|
210
|
+
|
211
|
+
|
212
|
+
@singledispatch
|
213
|
+
def apply_extension(extension: object, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False): # noqa: ARG001
|
214
|
+
raise NotImplementedError(f"unsupported extension: {extension!r}")
|
215
|
+
|
216
|
+
|
217
|
+
@apply_extension.register # type: ignore
|
218
|
+
def _ignore_nested(extension: IgnoreNested, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False):
|
219
|
+
if extension.use_wildcard:
|
220
|
+
name_to_remove = extension.path.removeprefix("*.")
|
221
|
+
assert "*" not in name_to_remove, f"only prefix *. is allowed for wildcard in path {extension.path}"
|
222
|
+
found_paths = [
|
223
|
+
path
|
224
|
+
for name, path, attribute_dict in spec.iter_nested_attributes(resource_name, is_datasource=is_datasource)
|
225
|
+
if name == name_to_remove
|
226
|
+
]
|
227
|
+
while found_paths:
|
228
|
+
next_to_remove = found_paths.pop()
|
229
|
+
spec.remove_nested_attribute(resource_name, next_to_remove, is_datasource=is_datasource)
|
230
|
+
found_paths = [
|
231
|
+
path
|
232
|
+
for name, path, attribute_dict in spec.iter_nested_attributes(
|
233
|
+
resource_name, is_datasource=is_datasource
|
234
|
+
)
|
235
|
+
if name == name_to_remove
|
236
|
+
]
|
237
|
+
else:
|
238
|
+
err_msg = "only wildcard path is supported"
|
239
|
+
raise NotImplementedError(err_msg)
|
240
|
+
|
241
|
+
|
242
|
+
@apply_extension.register # type: ignore
|
243
|
+
def _rename_attribute(
|
244
|
+
extension: RenameAttribute, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False
|
245
|
+
):
|
246
|
+
for attribute_dict in spec.schema_attributes(resource_name, is_datasource=is_datasource):
|
247
|
+
if attribute_dict.get("name") == extension.from_name:
|
248
|
+
logger.info(
|
249
|
+
f"renaming attribute for {spec.root_name(resource_name, is_datasource)}: {extension.from_name} -> {extension.to_name}"
|
250
|
+
)
|
251
|
+
attribute_dict["name"] = extension.to_name
|
252
|
+
|
253
|
+
|
254
|
+
@apply_extension.register # type: ignore
|
255
|
+
def _change_attribute_type(
|
256
|
+
extension: ChangeAttributeType, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False
|
257
|
+
):
|
258
|
+
attribute_dict = spec.read_attribute(resource_name, extension.path, is_datasource=is_datasource)
|
259
|
+
old_value = extension.read_value(attribute_dict)
|
260
|
+
if old_value == extension.new_value:
|
261
|
+
logger.info(
|
262
|
+
f"no change for '{spec.root_name(resource_name, is_datasource)}': {extension.path} -> {extension.new_value}"
|
263
|
+
)
|
264
|
+
return
|
265
|
+
|
266
|
+
logger.info(
|
267
|
+
f"changing attribute type for '{spec.root_name(resource_name, is_datasource)}.{extension.path}': {old_value} -> {extension.new_value}"
|
268
|
+
)
|
269
|
+
extension.update_value(attribute_dict)
|
270
|
+
|
271
|
+
|
272
|
+
@apply_extension.register # type: ignore
|
273
|
+
def _skip_validators(_: SkipValidators, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False):
|
274
|
+
for attr_tuple in spec.iter_all_attributes(resource_name, is_datasource=is_datasource):
|
275
|
+
attribute_dict = attr_tuple.attribute_dict
|
276
|
+
paths_to_pop = [
|
277
|
+
f"{path}.validators"
|
278
|
+
for path, nested_dict in dict_nested.iter_nested_key_values(attribute_dict, type_filter=dict)
|
279
|
+
if "validators" in nested_dict
|
280
|
+
]
|
281
|
+
if paths_to_pop:
|
282
|
+
logger.info(f"popping validators from '{attr_tuple.attribute_path}'")
|
283
|
+
for path in paths_to_pop:
|
284
|
+
dict_nested.pop_nested(attribute_dict, path)
|
285
|
+
|
286
|
+
|
88
287
|
# reusing url from terraform-provider-mongodbatlas/scripts/schema-scaffold.sh
|
89
288
|
ADMIN_API_URL = "https://raw.githubusercontent.com/mongodb/atlas-sdk-go/main/openapi/atlas-api-transformed.yaml"
|
90
289
|
|
91
290
|
|
92
|
-
def
|
93
|
-
|
94
|
-
|
291
|
+
def admin_api_url(branch: str) -> str:
|
292
|
+
return ADMIN_API_URL.replace("/main/", f"/{branch}/")
|
293
|
+
|
294
|
+
|
295
|
+
def download_admin_api(dest: Path, branch: str = "main") -> None:
|
296
|
+
url = admin_api_url(branch)
|
297
|
+
logger.info(f"downloading admin api to {dest} from {url}")
|
298
|
+
response = requests.get(url, timeout=10)
|
95
299
|
response.raise_for_status()
|
96
300
|
dest.write_bytes(response.content)
|
@@ -0,0 +1,236 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import re
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import NamedTuple
|
7
|
+
|
8
|
+
from model_lib import Entity
|
9
|
+
from pydantic import Field
|
10
|
+
|
11
|
+
from atlas_init.cli_tf.schema_table_models import (
|
12
|
+
AttrRefLine,
|
13
|
+
FuncCallLine,
|
14
|
+
TFSchemaAttribute,
|
15
|
+
)
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
def parse_attribute_ref(
|
21
|
+
name: str, rest: str, go_code: str, code_lines: list[str], ref_line_nr: int
|
22
|
+
) -> TFSchemaAttribute | None:
|
23
|
+
attr_ref = rest.lstrip("&").rstrip(",").strip()
|
24
|
+
if not attr_ref.isidentifier():
|
25
|
+
return None
|
26
|
+
try:
|
27
|
+
_instantiate_regex = re.compile(rf"{attr_ref}\s=\sschema\.\w+\{{$", re.M)
|
28
|
+
except re.error:
|
29
|
+
return None
|
30
|
+
instantiate_match = _instantiate_regex.search(go_code)
|
31
|
+
if not instantiate_match:
|
32
|
+
return None
|
33
|
+
line_start_nr = go_code[: instantiate_match.start()].count("\n") + 1
|
34
|
+
line_start = code_lines[line_start_nr]
|
35
|
+
attribute = parse_attribute_lines(code_lines, line_start_nr, line_start, name, is_attr_ref=True)
|
36
|
+
attribute.attr_ref_line = AttrRefLine(line_nr=ref_line_nr, attr_ref=attr_ref)
|
37
|
+
return attribute
|
38
|
+
|
39
|
+
|
40
|
+
def parse_func_call_line(
|
41
|
+
name: str, rest: str, lines: list[str], go_code: str, call_line_nr: int
|
42
|
+
) -> TFSchemaAttribute | None:
|
43
|
+
func_def_line = _function_line(rest, go_code)
|
44
|
+
if not func_def_line:
|
45
|
+
return None
|
46
|
+
func_name, args = rest.split("(", maxsplit=1)
|
47
|
+
func_start, func_end = _func_lines(name, lines, func_def_line)
|
48
|
+
call = FuncCallLine(
|
49
|
+
call_line_nr=call_line_nr,
|
50
|
+
func_name=func_name.strip(),
|
51
|
+
args=args.removesuffix("),").strip(),
|
52
|
+
func_line_start=func_start,
|
53
|
+
func_line_end=func_end,
|
54
|
+
)
|
55
|
+
return TFSchemaAttribute(
|
56
|
+
name=name,
|
57
|
+
lines=lines[func_start:func_end],
|
58
|
+
line_start=func_start,
|
59
|
+
line_end=func_end,
|
60
|
+
func_call=call,
|
61
|
+
indent="\t",
|
62
|
+
)
|
63
|
+
|
64
|
+
|
65
|
+
def _func_lines(name: str, lines: list[str], func_def_line: str) -> tuple[int, int]:
|
66
|
+
start_line = lines.index(func_def_line)
|
67
|
+
for line_nr, line in enumerate(lines[start_line + 1 :], start=start_line + 1):
|
68
|
+
if line.rstrip() == "}":
|
69
|
+
return start_line, line_nr
|
70
|
+
raise ValueError(f"no end line found for {name} on line {start_line}: {func_def_line}")
|
71
|
+
|
72
|
+
|
73
|
+
def _function_line(rest: str, go_code: str) -> str:
|
74
|
+
function_name = rest.split("(")[0].strip()
|
75
|
+
pattern = re.compile(rf"func {function_name}\(.*\) \*?schema\.\w+ \{{$", re.M)
|
76
|
+
match = pattern.search(go_code)
|
77
|
+
if not match:
|
78
|
+
return ""
|
79
|
+
return go_code[match.start() : match.end()]
|
80
|
+
|
81
|
+
|
82
|
+
def parse_attribute_lines(
|
83
|
+
lines: list[str],
|
84
|
+
line_nr: int,
|
85
|
+
line: str,
|
86
|
+
name: str,
|
87
|
+
*,
|
88
|
+
is_attr_ref: bool = False,
|
89
|
+
) -> TFSchemaAttribute:
|
90
|
+
indents = len(line) - len(line.lstrip())
|
91
|
+
indent = indents * "\t"
|
92
|
+
end_line = f"{indent}}}" if is_attr_ref else f"{indent}}},"
|
93
|
+
for extra_lines, next_line in enumerate(lines[line_nr + 1 :], start=1):
|
94
|
+
if next_line == end_line:
|
95
|
+
return TFSchemaAttribute(
|
96
|
+
name=name,
|
97
|
+
lines=lines[line_nr : line_nr + extra_lines],
|
98
|
+
line_start=line_nr,
|
99
|
+
line_end=line_nr + extra_lines,
|
100
|
+
indent=indent,
|
101
|
+
)
|
102
|
+
raise ValueError(f"no end line found for {name}, starting on line {line_nr}")
|
103
|
+
|
104
|
+
|
105
|
+
_schema_attribute_go_regex = re.compile(
|
106
|
+
r'^\s+"(?P<name>[^"]+)":\s(?P<rest>.+)$',
|
107
|
+
)
|
108
|
+
|
109
|
+
|
110
|
+
def find_attributes(go_code: str) -> list[TFSchemaAttribute]:
|
111
|
+
lines = ["", *go_code.splitlines()] # support line_nr indexing
|
112
|
+
attributes = []
|
113
|
+
for line_nr, line in enumerate(lines):
|
114
|
+
match = _schema_attribute_go_regex.match(line)
|
115
|
+
if not match:
|
116
|
+
continue
|
117
|
+
name = match.group("name")
|
118
|
+
rest = match.group("rest")
|
119
|
+
if rest.endswith("),"):
|
120
|
+
if attr := parse_func_call_line(name, rest, lines, go_code, line_nr):
|
121
|
+
attributes.append(attr)
|
122
|
+
elif attr := parse_attribute_ref(name, rest, go_code, lines, line_nr):
|
123
|
+
attributes.append(attr)
|
124
|
+
else:
|
125
|
+
try:
|
126
|
+
attr = parse_attribute_lines(lines, line_nr, line, name)
|
127
|
+
except ValueError as e:
|
128
|
+
logger.warning(e)
|
129
|
+
continue
|
130
|
+
if not attr.type:
|
131
|
+
continue
|
132
|
+
attributes.append(attr)
|
133
|
+
set_attribute_paths(attributes)
|
134
|
+
return attributes
|
135
|
+
|
136
|
+
|
137
|
+
class StartEnd(NamedTuple):
|
138
|
+
start: int
|
139
|
+
end: int
|
140
|
+
name: str
|
141
|
+
func_call_line: FuncCallLine | None
|
142
|
+
|
143
|
+
def has_parent(self, other: StartEnd) -> bool:
|
144
|
+
if self.name == other.name:
|
145
|
+
return False
|
146
|
+
if func_call := self.func_call_line:
|
147
|
+
func_call_line = func_call.call_line_nr
|
148
|
+
return other.start < func_call_line < other.end
|
149
|
+
return self.start > other.start and self.end < other.end
|
150
|
+
|
151
|
+
|
152
|
+
def set_attribute_paths(attributes: list[TFSchemaAttribute]) -> list[TFSchemaAttribute]:
|
153
|
+
start_stops = [StartEnd(a.line_start, a.line_end, a.name, a.func_call) for a in attributes]
|
154
|
+
overlaps = [
|
155
|
+
(attribute, [other for other in start_stops if start_stop.has_parent(other)])
|
156
|
+
for attribute, start_stop in zip(attributes, start_stops, strict=False)
|
157
|
+
]
|
158
|
+
for attribute, others in overlaps:
|
159
|
+
if not others:
|
160
|
+
attribute.attribute_path = attribute.name
|
161
|
+
continue
|
162
|
+
overlaps = defaultdict(list)
|
163
|
+
for other in others:
|
164
|
+
overlaps[(other.start, other.end)].append(other.name)
|
165
|
+
paths = []
|
166
|
+
for names in overlaps.values():
|
167
|
+
if len(names) == 1:
|
168
|
+
paths.append(names[0])
|
169
|
+
else:
|
170
|
+
paths.append(f"({'|'.join(names)})")
|
171
|
+
paths.append(attribute.name)
|
172
|
+
attribute.attribute_path = ".".join(paths)
|
173
|
+
return attributes
|
174
|
+
|
175
|
+
|
176
|
+
class GoSchemaFunc(Entity):
|
177
|
+
name: str
|
178
|
+
line_start: int
|
179
|
+
line_end: int
|
180
|
+
call_attributes: list[TFSchemaAttribute] = Field(default_factory=list)
|
181
|
+
attributes: list[TFSchemaAttribute] = Field(default_factory=list)
|
182
|
+
|
183
|
+
@property
|
184
|
+
def attribute_names(self) -> set[str]:
|
185
|
+
return {a.name for a in self.call_attributes}
|
186
|
+
|
187
|
+
@property
|
188
|
+
def attribute_paths(self) -> str:
|
189
|
+
paths = set()
|
190
|
+
for a in self.call_attributes:
|
191
|
+
path = ".".join(a.parent_attribute_names())
|
192
|
+
paths.add(path)
|
193
|
+
return f"({'|'.join(paths)})" if len(paths) > 1 else paths.pop()
|
194
|
+
|
195
|
+
def contains_attribute(self, attribute: TFSchemaAttribute) -> bool:
|
196
|
+
names = self.attribute_names
|
197
|
+
return any(parent_attribute in names for parent_attribute in attribute.parent_attribute_names())
|
198
|
+
|
199
|
+
|
200
|
+
def find_schema_functions(attributes: list[TFSchemaAttribute]) -> list[GoSchemaFunc]:
|
201
|
+
function_call_attributes = defaultdict(list)
|
202
|
+
for a in attributes:
|
203
|
+
if a.is_function_call:
|
204
|
+
call = a.func_call
|
205
|
+
assert call
|
206
|
+
function_call_attributes[call.func_name].append(a)
|
207
|
+
root_function = GoSchemaFunc(name="", line_start=0, line_end=0)
|
208
|
+
functions: list[GoSchemaFunc] = [
|
209
|
+
GoSchemaFunc(
|
210
|
+
name=name,
|
211
|
+
line_start=func_attributes[0].line_start,
|
212
|
+
line_end=func_attributes[0].line_end,
|
213
|
+
call_attributes=func_attributes,
|
214
|
+
)
|
215
|
+
for name, func_attributes in function_call_attributes.items()
|
216
|
+
]
|
217
|
+
for attribute in attributes:
|
218
|
+
if match_functions := [func for func in functions if func.contains_attribute(attribute)]:
|
219
|
+
func_names = [func.name for func in match_functions]
|
220
|
+
err_msg = f"multiple functions found for {attribute.name}, {func_names}"
|
221
|
+
assert len(match_functions) == 1, err_msg
|
222
|
+
function = match_functions[0]
|
223
|
+
function.attributes.append(attribute)
|
224
|
+
attribute.absolute_attribute_path = f"{function.attribute_paths}.{attribute.attribute_path}".lstrip(".")
|
225
|
+
else:
|
226
|
+
root_function.attributes.append(attribute)
|
227
|
+
attribute.absolute_attribute_path = attribute.attribute_path
|
228
|
+
return [root_function, *functions]
|
229
|
+
|
230
|
+
|
231
|
+
def parse_schema_functions(
|
232
|
+
go_code: str,
|
233
|
+
) -> tuple[list[TFSchemaAttribute], list[GoSchemaFunc]]:
|
234
|
+
attributes = find_attributes(go_code)
|
235
|
+
functions = find_schema_functions(attributes)
|
236
|
+
return sorted(attributes), functions
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# import typer
|
2
|
+
|
3
|
+
|
4
|
+
from collections import defaultdict
|
5
|
+
from collections.abc import Iterable
|
6
|
+
from functools import total_ordering
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Literal, TypeAlias
|
9
|
+
|
10
|
+
from model_lib import Entity, Event
|
11
|
+
from pydantic import Field, model_validator
|
12
|
+
from zero_3rdparty import iter_utils
|
13
|
+
|
14
|
+
from atlas_init.cli_tf.schema_go_parser import parse_schema_functions
|
15
|
+
from atlas_init.cli_tf.schema_table_models import TFSchemaAttribute, TFSchemaTableColumn
|
16
|
+
from atlas_init.settings.path import default_factory_cwd
|
17
|
+
|
18
|
+
|
19
|
+
def default_table_columns() -> list[TFSchemaTableColumn]:
|
20
|
+
return [TFSchemaTableColumn.Computability]
|
21
|
+
|
22
|
+
|
23
|
+
def file_name_path(file: str) -> tuple[str, Path]:
|
24
|
+
if ":" in file:
|
25
|
+
file, path = file.split(":", 1)
|
26
|
+
return file, Path(path)
|
27
|
+
path = Path(file)
|
28
|
+
return f"{path.parent.name}/{path.stem}"[:20], path
|
29
|
+
|
30
|
+
|
31
|
+
@total_ordering
|
32
|
+
class TFSchemaSrc(Event):
|
33
|
+
name: str
|
34
|
+
file_path: Path | None = None
|
35
|
+
url: str = ""
|
36
|
+
|
37
|
+
@model_validator(mode="after")
|
38
|
+
def validate(self):
|
39
|
+
assert self.file_path or self.url, "must provide file path or url"
|
40
|
+
if self.file_path:
|
41
|
+
assert self.file_path.exists(), f"file does not exist for {self.name}: {self.file_path}"
|
42
|
+
return self
|
43
|
+
|
44
|
+
def __lt__(self, other) -> bool:
|
45
|
+
if not isinstance(other, TFSchemaSrc):
|
46
|
+
raise TypeError
|
47
|
+
return self.name < other.name
|
48
|
+
|
49
|
+
def go_code(self) -> str:
|
50
|
+
if path := self.file_path:
|
51
|
+
return path.read_text()
|
52
|
+
raise NotImplementedError
|
53
|
+
|
54
|
+
|
55
|
+
TableOutputFormats: TypeAlias = Literal["md"]
|
56
|
+
|
57
|
+
|
58
|
+
class TFSchemaTableInput(Entity):
|
59
|
+
sources: list[TFSchemaSrc] = Field(default_factory=list)
|
60
|
+
output_format: TableOutputFormats = "md"
|
61
|
+
output_path: Path = Field(default_factory=default_factory_cwd("schema_table.md"))
|
62
|
+
columns: list[TFSchemaTableColumn] = Field(default_factory=default_table_columns)
|
63
|
+
explode_rows: bool = False
|
64
|
+
|
65
|
+
@model_validator(mode="after")
|
66
|
+
def validate(self):
|
67
|
+
assert self.columns, "must provide at least 1 column"
|
68
|
+
self.columns = sorted(self.columns)
|
69
|
+
assert self.sources, "must provide at least 1 source"
|
70
|
+
self.sources = sorted(self.sources)
|
71
|
+
assert len(self.sources) == len(set(self.sources)), f"duplicate source names: {self.source_names}"
|
72
|
+
return self
|
73
|
+
|
74
|
+
@property
|
75
|
+
def source_names(self) -> list[str]:
|
76
|
+
return [s.name for s in self.sources]
|
77
|
+
|
78
|
+
def headers(self) -> list[str]:
|
79
|
+
return ["Attribute Name"] + [f"{name}-{col}" for name in self.source_names for col in self.columns]
|
80
|
+
|
81
|
+
|
82
|
+
@total_ordering
|
83
|
+
class TFSchemaTableData(Event):
|
84
|
+
source: TFSchemaSrc
|
85
|
+
schema_path: str = "" # e.g., "" is root, "replication_specs.region_config"
|
86
|
+
attributes: list[TFSchemaAttribute] = Field(default_factory=list)
|
87
|
+
|
88
|
+
@property
|
89
|
+
def id(self) -> str:
|
90
|
+
return f"{self.schema_path}:{self.source.name}"
|
91
|
+
|
92
|
+
def __lt__(self, other) -> bool:
|
93
|
+
if not isinstance(other, TFSchemaTableData):
|
94
|
+
raise TypeError
|
95
|
+
return self.id < other.id
|
96
|
+
|
97
|
+
|
98
|
+
def sorted_schema_paths(schema_paths: Iterable[str]) -> list[str]:
|
99
|
+
return sorted(schema_paths, key=lambda x: (x.count("."), x.split(".")[-1]))
|
100
|
+
|
101
|
+
|
102
|
+
class RawTable(Event):
|
103
|
+
columns: list[str]
|
104
|
+
rows: list[list[str]]
|
105
|
+
|
106
|
+
|
107
|
+
def merge_tables(config: TFSchemaTableInput, schema_path: str, tables: list[TFSchemaTableData]) -> RawTable:
|
108
|
+
if schema_path != "":
|
109
|
+
raise NotImplementedError
|
110
|
+
columns = config.headers()
|
111
|
+
if len(tables) > 1:
|
112
|
+
err_msg = "only 1 table per schema path supported"
|
113
|
+
raise NotImplementedError(err_msg)
|
114
|
+
table = tables[0]
|
115
|
+
rows = [[attr.absolute_attribute_path, *attr.row(config.columns)] for attr in table.attributes]
|
116
|
+
return RawTable(columns=columns, rows=rows)
|
117
|
+
|
118
|
+
|
119
|
+
def format_table(table: RawTable, table_format: TableOutputFormats) -> list[str]:
|
120
|
+
# sourcery skip: merge-list-append
|
121
|
+
assert table_format == "md", "only markdown format supported"
|
122
|
+
lines = []
|
123
|
+
lines.append("|".join(table.columns))
|
124
|
+
lines.append("|".join(["---"] * len(table.columns)))
|
125
|
+
lines.extend("|".join(row) for row in table.rows)
|
126
|
+
return lines
|
127
|
+
|
128
|
+
|
129
|
+
def explode_attributes(attributes: list[TFSchemaAttribute]) -> list[TFSchemaAttribute]:
|
130
|
+
return sorted(iter_utils.flat_map(attr.explode() for attr in attributes))
|
131
|
+
|
132
|
+
|
133
|
+
def schema_table(config: TFSchemaTableInput) -> str:
|
134
|
+
path_tables: dict[str, list[TFSchemaTableData]] = defaultdict(list)
|
135
|
+
for source in config.sources:
|
136
|
+
go_code = source.go_code()
|
137
|
+
attributes, functions = parse_schema_functions(go_code)
|
138
|
+
if config.explode_rows:
|
139
|
+
attributes = explode_attributes(attributes)
|
140
|
+
schema_path = "" # using only root for now
|
141
|
+
path_tables[schema_path].append(
|
142
|
+
TFSchemaTableData(source=source, attributes=attributes, schema_path=schema_path)
|
143
|
+
)
|
144
|
+
output_lines = []
|
145
|
+
for schema_path in sorted_schema_paths(path_tables.keys()):
|
146
|
+
tables = path_tables[schema_path]
|
147
|
+
table = merge_tables(config, schema_path, tables)
|
148
|
+
output_lines.extend(["", f"## {schema_path or 'Root'}", ""])
|
149
|
+
output_lines.extend(format_table(table, table_format=config.output_format))
|
150
|
+
return "\n".join(output_lines)
|