atlas-init 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. atlas_init/__init__.py +3 -3
  2. atlas_init/atlas_init.yaml +51 -34
  3. atlas_init/cli.py +76 -72
  4. atlas_init/cli_cfn/app.py +40 -117
  5. atlas_init/cli_cfn/{cfn.py → aws.py} +129 -14
  6. atlas_init/cli_cfn/cfn_parameter_finder.py +89 -6
  7. atlas_init/cli_cfn/example.py +203 -0
  8. atlas_init/cli_cfn/files.py +63 -0
  9. atlas_init/cli_helper/go.py +6 -3
  10. atlas_init/cli_helper/run.py +18 -2
  11. atlas_init/cli_helper/tf_runner.py +12 -21
  12. atlas_init/cli_root/__init__.py +0 -0
  13. atlas_init/cli_root/trigger.py +153 -0
  14. atlas_init/cli_tf/app.py +211 -4
  15. atlas_init/cli_tf/changelog.py +103 -0
  16. atlas_init/cli_tf/debug_logs.py +221 -0
  17. atlas_init/cli_tf/debug_logs_test_data.py +253 -0
  18. atlas_init/cli_tf/github_logs.py +229 -0
  19. atlas_init/cli_tf/go_test_run.py +194 -0
  20. atlas_init/cli_tf/go_test_run_format.py +31 -0
  21. atlas_init/cli_tf/go_test_summary.py +144 -0
  22. atlas_init/cli_tf/hcl/__init__.py +0 -0
  23. atlas_init/cli_tf/hcl/cli.py +161 -0
  24. atlas_init/cli_tf/hcl/cluster_mig.py +348 -0
  25. atlas_init/cli_tf/hcl/parser.py +140 -0
  26. atlas_init/cli_tf/schema.py +222 -18
  27. atlas_init/cli_tf/schema_go_parser.py +236 -0
  28. atlas_init/cli_tf/schema_table.py +150 -0
  29. atlas_init/cli_tf/schema_table_models.py +155 -0
  30. atlas_init/cli_tf/schema_v2.py +599 -0
  31. atlas_init/cli_tf/schema_v2_api_parsing.py +298 -0
  32. atlas_init/cli_tf/schema_v2_sdk.py +361 -0
  33. atlas_init/cli_tf/schema_v3.py +222 -0
  34. atlas_init/cli_tf/schema_v3_sdk.py +279 -0
  35. atlas_init/cli_tf/schema_v3_sdk_base.py +68 -0
  36. atlas_init/cli_tf/schema_v3_sdk_create.py +216 -0
  37. atlas_init/humps.py +253 -0
  38. atlas_init/repos/cfn.py +6 -1
  39. atlas_init/repos/path.py +3 -3
  40. atlas_init/settings/config.py +30 -11
  41. atlas_init/settings/env_vars.py +29 -3
  42. atlas_init/settings/path.py +12 -1
  43. atlas_init/settings/rich_utils.py +39 -2
  44. atlas_init/terraform.yaml +77 -1
  45. atlas_init/tf/.terraform.lock.hcl +125 -0
  46. atlas_init/tf/always.tf +11 -2
  47. atlas_init/tf/main.tf +3 -0
  48. atlas_init/tf/modules/aws_s3/provider.tf +1 -1
  49. atlas_init/tf/modules/aws_vars/aws_vars.tf +2 -0
  50. atlas_init/tf/modules/aws_vpc/provider.tf +4 -1
  51. atlas_init/tf/modules/cfn/cfn.tf +47 -33
  52. atlas_init/tf/modules/cfn/kms.tf +54 -0
  53. atlas_init/tf/modules/cfn/resource_actions.yaml +1 -0
  54. atlas_init/tf/modules/cfn/variables.tf +31 -0
  55. atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -0
  56. atlas_init/tf/modules/cloud_provider/provider.tf +1 -1
  57. atlas_init/tf/modules/cluster/cluster.tf +34 -24
  58. atlas_init/tf/modules/cluster/provider.tf +1 -1
  59. atlas_init/tf/modules/federated_vars/federated_vars.tf +3 -0
  60. atlas_init/tf/modules/federated_vars/provider.tf +1 -1
  61. atlas_init/tf/modules/project_extra/project_extra.tf +15 -1
  62. atlas_init/tf/modules/stream_instance/stream_instance.tf +1 -1
  63. atlas_init/tf/modules/vpc_peering/vpc_peering.tf +1 -1
  64. atlas_init/tf/modules/vpc_privatelink/versions.tf +1 -1
  65. atlas_init/tf/outputs.tf +11 -3
  66. atlas_init/tf/providers.tf +2 -1
  67. atlas_init/tf/variables.tf +17 -0
  68. atlas_init/typer_app.py +76 -0
  69. {atlas_init-0.1.0.dist-info → atlas_init-0.1.4.dist-info}/METADATA +58 -21
  70. atlas_init-0.1.4.dist-info/RECORD +91 -0
  71. {atlas_init-0.1.0.dist-info → atlas_init-0.1.4.dist-info}/WHEEL +1 -1
  72. atlas_init-0.1.0.dist-info/RECORD +0 -61
  73. /atlas_init/tf/modules/aws_vpc/{aws-vpc.tf → aws_vpc.tf} +0 -0
  74. {atlas_init-0.1.0.dist-info → atlas_init-0.1.4.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,14 @@
1
1
  import logging
2
+ from collections.abc import Iterable
3
+ from functools import singledispatch
2
4
  from pathlib import Path
3
- from typing import Literal
5
+ from typing import Annotated, Literal, NamedTuple
4
6
 
5
7
  import pydantic
6
8
  import requests
7
9
  from model_lib import Entity, dump, field_names, parse_model
10
+ from zero_3rdparty import dict_nested
11
+ from zero_3rdparty.enum_utils import StrEnum
8
12
 
9
13
  logger = logging.getLogger(__name__)
10
14
 
@@ -23,10 +27,53 @@ class ProviderSpecAttribute(Entity):
23
27
  return self.model_dump(exclude_none=True)
24
28
 
25
29
 
30
class IgnoreNested(Entity):
    """Extension that drops nested attributes selected by ``path`` (only ``*.<name>`` is handled)."""

    type: Literal["ignore_nested"] = "ignore_nested"
    path: str

    @property
    def use_wildcard(self) -> bool:
        """True when ``path`` contains a ``*`` anywhere."""
        return self.path.find("*") != -1
37
+
38
+
39
class RenameAttribute(Entity):
    """Extension that renames the attribute called ``from_name`` to ``to_name``.

    Applied by ``_rename_attribute``, which only looks at top-level schema attributes.
    """

    type: Literal["rename_attribute"] = "rename_attribute"  # tag identifying this extension kind
    from_name: str  # attribute name currently in the spec
    to_name: str  # name to write instead
43
+
44
+
45
class ComputedOptionalRequired(StrEnum):
    """Values accepted for an attribute's ``computed_optional_required`` setting (used by ChangeAttributeType)."""

    COMPUTED_OPTIONAL = "computed_optional"
    REQUIRED = "required"
    COMPUTED = "computed"
    OPTIONAL = "optional"
50
+
51
+
52
def _attribute_type_dict(attribute_dict: dict) -> dict:
    """Return the type-specific sub-dict of a spec attribute (e.g. ``attribute_dict["string"]``).

    ``"string"`` is checked first. Otherwise the first dict value that holds a
    ``computed_optional_required`` key is used, which lets the change also work for
    non-string attributes.

    Raises:
        KeyError: when no such sub-dict exists.
    """
    if "string" in attribute_dict:
        return attribute_dict["string"]
    for value in attribute_dict.values():
        if isinstance(value, dict) and "computed_optional_required" in value:
            return value
    raise KeyError("computed_optional_required")


class ChangeAttributeType(Entity):
    """Extension that sets ``computed_optional_required`` of the attribute at ``path`` to ``new_value``."""

    type: Literal["change_attribute_type"] = "change_attribute_type"
    path: str
    new_value: ComputedOptionalRequired

    @classmethod
    def read_value(cls, attribute_dict: dict) -> str:
        """Return the current ``computed_optional_required`` value stored in *attribute_dict*."""
        return _attribute_type_dict(attribute_dict)["computed_optional_required"]

    def update_value(self, attribute_dict: dict) -> None:
        """Overwrite ``computed_optional_required`` in *attribute_dict* (in place) with ``new_value``."""
        _attribute_type_dict(attribute_dict)["computed_optional_required"] = self.new_value
63
+
64
+
65
class SkipValidators(Entity):
    """Option-less extension; ``_skip_validators`` pops every ``validators`` entry from the attributes."""

    type: Literal["skip_validators"] = "skip_validators"  # tag identifying this extension kind
67
+
68
+
69
# Union of all supported extension models; each model carries a literal `type` tag.
# NOTE(review): the first positional argument of pydantic.Field is `default`, not `discriminator`, so
# this does not declare a discriminated union and pydantic uses its regular union validation instead.
# If a tagged union was intended, use `pydantic.Field(discriminator="type")` -- but first confirm that
# every config sets `type` explicitly, because a discriminated union rejects inputs without the tag.
Extension = Annotated[IgnoreNested | RenameAttribute | ChangeAttributeType | SkipValidators, pydantic.Field("type")]
70
+
71
+
26
72
  class TFResource(Entity):
27
73
  model_config = pydantic.ConfigDict(extra="allow")
28
74
  name: str
29
- provider_spec_attributes: list[ProviderSpecAttribute]
75
+ extensions: list[Extension] = pydantic.Field(default_factory=list)
76
+ provider_spec_attributes: list[ProviderSpecAttribute] = pydantic.Field(default_factory=list)
30
77
 
31
78
  def dump_generator_config(self) -> dict:
32
79
  names = field_names(self)
@@ -35,6 +82,7 @@ class TFResource(Entity):
35
82
 
36
83
  class PyTerraformSchema(Entity):
37
84
  resources: list[TFResource]
85
+ data_sources: list[TFResource] = pydantic.Field(default_factory=list)
38
86
 
39
87
  def resource(self, resource: str) -> TFResource:
40
88
  return next(r for r in self.resources if r.name == resource)
@@ -48,27 +96,89 @@ def dump_generator_config(schema: PyTerraformSchema) -> str:
48
96
  resources = {}
49
97
  for resource in schema.resources:
50
98
  resources[resource.name] = resource.dump_generator_config()
99
+ data_sources = {ds.name: ds.dump_generator_config() for ds in schema.data_sources}
51
100
  generator_config = {
52
101
  "provider": {"name": "mongodbatlas"},
53
102
  "resources": resources,
103
+ "data_sources": data_sources,
54
104
  }
55
105
  return dump(generator_config, "yaml")
56
106
 
57
107
 
108
class AttributeTuple(NamedTuple):
    """An attribute found in a spec: its name, the path of its container and the raw attribute dict."""

    name: str
    path: str
    attribute_dict: dict

    @property
    def attribute_path(self) -> str:
        """``<path>.<name>``, or just ``<name>`` for a top-level attribute (empty ``path``)."""
        if not self.path:
            return self.name
        return ".".join((self.path, self.name))
116
+
117
+
58
118
class ProviderCodeSpec(Entity):
    """Provider code spec with helpers to read and edit resource/datasource schemas.

    Resources and datasources are kept as raw dicts shaped like
    ``{"name": ..., "schema": {"attributes": [...]}}``. The helpers return and mutate those dicts
    directly, so edits are kept when the spec is dumped again.
    """

    model_config = pydantic.ConfigDict(extra="allow")
    provider: dict
    resources: list[dict]
    datasources: list[dict] = pydantic.Field(default_factory=list)
    version: str

    def root_dict(self, name: str, is_datasource: bool = False) -> dict:  # noqa: FBT002
        """Return the resource (or datasource) entry named *name*.

        Raises:
            ValueError: if no entry with that name exists.
        """
        resources = self.datasources if is_datasource else self.resources
        root_value = next((r for r in resources if r["name"] == name), None)
        if root_value is None:
            raise ValueError(f"{self.root_name(name, is_datasource)} not found!")
        return root_value

    def schema_attributes(self, name: str, is_datasource: bool = False) -> list:  # noqa: FBT002
        """Return the top-level ``schema.attributes`` list itself (not a copy) of entry *name*."""
        root_dict = self.root_dict(name, is_datasource)
        return root_dict["schema"]["attributes"]

    def _type_name(self, is_datasource: bool) -> str:
        return "datasource" if is_datasource else "resource"

    def root_name(self, name: str, is_datasource: bool) -> str:
        """Label such as ``resource.<name>`` / ``datasource.<name>`` used in logs and errors."""
        return f"{self._type_name(is_datasource)}.{name}"

    def attribute_names(self, name: str, is_datasource: bool = False) -> list[str]:  # noqa: FBT002
        """Names of the top-level attributes of entry *name*."""
        return [a["name"] for a in self.schema_attributes(name, is_datasource=is_datasource)]

    def iter_all_attributes(self, name: str, is_datasource: bool = False) -> Iterable[AttributeTuple]:  # noqa: FBT002
        """Yield the top-level attributes (with an empty path) and then every nested attribute."""
        for attribute in self.schema_attributes(name=name, is_datasource=is_datasource):
            yield AttributeTuple(attribute["name"], "", attribute)
        yield from self.iter_nested_attributes(name, is_datasource=is_datasource)

    def iter_nested_attributes(self, name: str, is_datasource: bool = False) -> Iterable[AttributeTuple]:  # noqa: FBT002
        """Yield every dict with a non-empty ``"name"`` found inside the top-level attributes.

        Paths come from ``dict_nested.iter_nested_key_values`` (list indexes included). Each is
        prefixed with ``[i].``, where ``i`` is the index of the enclosing top-level attribute, so it
        can be passed to ``remove_nested_attribute``.
        """
        for i, attribute in enumerate(self.schema_attributes(name=name, is_datasource=is_datasource)):
            for path, attr_dict in dict_nested.iter_nested_key_values(
                attribute, type_filter=dict, include_list_indexes=True
            ):
                # `attr_name` is kept distinct from the `name` parameter (the resource/datasource name)
                if attr_name := attr_dict.get("name", ""):
                    yield AttributeTuple(attr_name, f"[{i}].{path}", attr_dict)

    def remove_nested_attribute(self, name: str, path: str, is_datasource: bool = False) -> None:  # noqa: FBT002
        """Pop the attribute at *path* (relative to ``schema.attributes``) from entry *name*.

        Raises:
            ValueError: if ``pop_nested`` returned its ``""`` default, i.e. nothing was removed.
            AssertionError: if the removed value is not a dict.
        """
        root_name = self.root_name(name, is_datasource)
        logger.info(f"will remove attribute from {root_name} with path: {path}")
        root = self.root_dict(name, is_datasource)
        full_path = f"schema.attributes.{path}"
        popped = dict_nested.pop_nested(root, full_path, "")
        if popped == "":
            raise ValueError(f"failed to remove attribute from {root_name} with path: {full_path}")
        assert isinstance(popped, dict), f"expected removed attribute to be a dict, got: {popped}"
        logger.info(f"removal ok, attribute_name: '{root_name}.{popped.get('name')}'")

    def read_attribute(self, name: str, path: str, *, is_datasource: bool = False) -> dict:
        """Return the attribute dict at *path* in entry *name*.

        A *path* without ``.`` is matched against top-level attribute names. Anything else is read
        as a nested path below ``schema.attributes``.

        Raises:
            ValueError: if nothing is found at *path*.
            AssertionError: if the value found is not a dict.
        """
        if "." not in path:
            attribute_dict = next((a for a in self.schema_attributes(name, is_datasource) if a["name"] == path), None)
        else:
            root_dict = self.root_dict(name, is_datasource)
            attribute_dict = dict_nested.read_nested_or_none(root_dict, f"schema.attributes.{path}")
        if attribute_dict is None:
            raise ValueError(f"attribute {path} not found in {self.root_name(name, is_datasource)}")
        assert isinstance(
            attribute_dict, dict
        ), f"expected attribute to be a dict, got: {attribute_dict} @ {path} for resource={name}"
        return attribute_dict
72
182
 
73
183
 
74
184
  def update_provider_code_spec(schema: PyTerraformSchema, provider_code_spec_path: Path) -> str:
@@ -76,21 +186,115 @@ def update_provider_code_spec(schema: PyTerraformSchema, provider_code_spec_path
76
186
  for resource in schema.resources:
77
187
  resource_name = resource.name
78
188
  if extra_spec_attributes := resource.provider_spec_attributes:
79
- resource_attributes = spec.resource_attributes(resource_name)
80
- existing_names = spec.resource_attribute_names(resource_name)
81
- new_names = [extra.name for extra in extra_spec_attributes]
82
- if both := set(existing_names) & set(new_names):
83
- raise ValueError(f"resource: {resource_name}, has already: {both} attributes")
84
- resource_attributes.extend(extra.dump_provider_code_spec() for extra in extra_spec_attributes)
189
+ add_explicit_attributes(spec, resource_name, extra_spec_attributes)
190
+ for extension in resource.extensions:
191
+ apply_extension(extension, spec, resource_name)
192
+ for data_source in schema.data_sources:
193
+ data_source_name = data_source.name
194
+ if extra_spec_attributes := data_source.provider_spec_attributes:
195
+ add_explicit_attributes(spec, data_source_name, extra_spec_attributes, is_datasource=True)
196
+ for extension in data_source.extensions:
197
+ apply_extension(extension, spec, data_source_name, is_datasource=True)
85
198
  return dump(spec, "json")
86
199
 
87
200
 
201
def add_explicit_attributes(
    spec: ProviderCodeSpec,
    name: str,
    extra_spec_attributes: list[ProviderSpecAttribute],
    *,
    is_datasource: bool = False,
):
    """Append *extra_spec_attributes* to the top-level schema attributes of resource/datasource *name*.

    Raises:
        ValueError: if an extra attribute name already exists in the schema, or is given more than
            once in *extra_spec_attributes*. Nothing is appended in either case.
    """
    resource_attributes = spec.schema_attributes(name, is_datasource=is_datasource)
    existing_names = spec.attribute_names(name, is_datasource=is_datasource)
    new_names = [extra.name for extra in extra_spec_attributes]
    root_name = spec.root_name(name, is_datasource)
    if both := set(existing_names) & set(new_names):
        raise ValueError(f"{root_name}: has already: {sorted(both)} attributes")
    if len(set(new_names)) != len(new_names):
        duplicated = sorted({n for n in new_names if new_names.count(n) > 1})
        raise ValueError(f"{root_name}: duplicate extra attribute names: {duplicated}")
    resource_attributes.extend(extra.dump_provider_code_spec() for extra in extra_spec_attributes)
210
+
211
+
212
@singledispatch
def apply_extension(extension: object, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False):  # noqa: ARG001
    """Apply *extension* to resource/datasource *resource_name* of *spec*.

    This is the singledispatch fallback. Handlers for concrete extension types are registered with
    ``@apply_extension.register``.

    Raises:
        NotImplementedError: when no handler is registered for ``type(extension)``.
    """
    unsupported = f"unsupported extension: {extension!r}"
    raise NotImplementedError(unsupported)
215
+
216
+
217
@apply_extension.register  # type: ignore
def _ignore_nested(extension: IgnoreNested, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False):
    """Remove every nested attribute named like ``extension.path`` (``*.<name>`` wildcard only).

    Raises:
        NotImplementedError: for a path without a wildcard.
        AssertionError: if ``*`` appears anywhere other than the ``*.`` prefix.
    """
    if not extension.use_wildcard:
        err_msg = "only wildcard path is supported"
        raise NotImplementedError(err_msg)
    name_to_remove = extension.path.removeprefix("*.")
    assert "*" not in name_to_remove, f"only prefix *. is allowed for wildcard in path {extension.path}"

    def matching_paths() -> list[str]:
        # paths of all nested attributes currently named `name_to_remove`
        return [
            attr.path
            for attr in spec.iter_nested_attributes(resource_name, is_datasource=is_datasource)
            if attr.name == name_to_remove
        ]

    # Paths contain list indexes that can shift after a removal, so remove a single match
    # (the last one found) and re-scan until none are left.
    while found_paths := matching_paths():
        spec.remove_nested_attribute(resource_name, found_paths[-1], is_datasource=is_datasource)
240
+
241
+
242
@apply_extension.register  # type: ignore
def _rename_attribute(
    extension: RenameAttribute, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False
):
    """Rename top-level attributes named ``from_name`` to ``to_name`` (silently a no-op if none match)."""
    top_level = spec.schema_attributes(resource_name, is_datasource=is_datasource)
    matches = (attr for attr in top_level if attr.get("name") == extension.from_name)
    for matched in matches:
        logger.info(
            f"renaming attribute for {spec.root_name(resource_name, is_datasource)}: {extension.from_name} -> {extension.to_name}"
        )
        matched["name"] = extension.to_name
252
+
253
+
254
@apply_extension.register  # type: ignore
def _change_attribute_type(
    extension: ChangeAttributeType, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False
):
    """Set ``computed_optional_required`` of the attribute at ``extension.path``, logging the outcome."""
    target = spec.read_attribute(resource_name, extension.path, is_datasource=is_datasource)
    current = extension.read_value(target)
    root_name = spec.root_name(resource_name, is_datasource)
    if current != extension.new_value:
        logger.info(
            f"changing attribute type for '{root_name}.{extension.path}': {current} -> {extension.new_value}"
        )
        extension.update_value(target)
    else:
        logger.info(f"no change for '{root_name}': {extension.path} -> {extension.new_value}")
270
+
271
+
272
@apply_extension.register  # type: ignore
def _skip_validators(_: SkipValidators, spec: ProviderCodeSpec, resource_name: str, *, is_datasource: bool = False):
    """Pop every ``validators`` key found in dicts nested inside each attribute (top-level and nested)."""
    for attr_tuple in spec.iter_all_attributes(resource_name, is_datasource=is_datasource):
        target = attr_tuple.attribute_dict
        # collect all paths before popping so the traversal never sees a half-edited dict
        validator_paths = [
            f"{nested_path}.validators"
            for nested_path, nested in dict_nested.iter_nested_key_values(target, type_filter=dict)
            if "validators" in nested
        ]
        if not validator_paths:
            continue
        logger.info(f"popping validators from '{attr_tuple.attribute_path}'")
        for validator_path in validator_paths:
            dict_nested.pop_nested(target, validator_path)
285
+
286
+
88
287
# reusing url from terraform-provider-mongodbatlas/scripts/schema-scaffold.sh
ADMIN_API_URL = "https://raw.githubusercontent.com/mongodb/atlas-sdk-go/main/openapi/atlas-api-transformed.yaml"


def admin_api_url(branch: str) -> str:
    """Return ADMIN_API_URL with its ``/main/`` segment replaced by ``/<branch>/``."""
    main_segment, branch_segment = "/main/", f"/{branch}/"
    return ADMIN_API_URL.replace(main_segment, branch_segment)


def download_admin_api(dest: Path, branch: str = "main") -> None:
    """Download the admin API spec from *branch* of atlas-sdk-go and write the raw bytes to *dest*.

    Raises:
        requests.HTTPError: when the response has an error status (via ``raise_for_status``).
    """
    source_url = admin_api_url(branch)
    logger.info(f"downloading admin api to {dest} from {source_url}")
    response = requests.get(source_url, timeout=10)
    response.raise_for_status()
    dest.write_bytes(response.content)
@@ -0,0 +1,236 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import re
5
+ from collections import defaultdict
6
+ from typing import NamedTuple
7
+
8
+ from model_lib import Entity
9
+ from pydantic import Field
10
+
11
+ from atlas_init.cli_tf.schema_table_models import (
12
+ AttrRefLine,
13
+ FuncCallLine,
14
+ TFSchemaAttribute,
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def parse_attribute_ref(
    name: str, rest: str, go_code: str, code_lines: list[str], ref_line_nr: int
) -> TFSchemaAttribute | None:
    """Resolve an attribute given by reference, e.g. ``"name": &myAttr,``.

    Looks for a line ending in ``myAttr = schema.<Type>{`` in *go_code* and parses the definition
    starting there.

    Args:
        name: attribute name (map key).
        rest: text after ``"name": `` on the referencing line.
        go_code: full Go source.
        code_lines: ``["", *go_code.splitlines()]`` so indexes are 1-based line numbers.
        ref_line_nr: line number of the referencing line.

    Returns:
        The parsed attribute, or None if *rest* is not a bare identifier or no definition is found.

    Raises:
        ValueError: from ``parse_attribute_lines`` when the definition's closing line is missing.
    """
    attr_ref = rest.lstrip("&").rstrip(",").strip()
    if not attr_ref.isidentifier():
        return None
    # \b stops e.g. `xmyAttr = schema...` from being taken as the definition of `myAttr`.
    instantiate_regex = re.compile(rf"\b{re.escape(attr_ref)}\s=\sschema\.\w+\{{$", re.M)
    instantiate_match = instantiate_regex.search(go_code)
    if not instantiate_match:
        return None
    line_start_nr = go_code.count("\n", 0, instantiate_match.start()) + 1
    line_start = code_lines[line_start_nr]
    attribute = parse_attribute_lines(code_lines, line_start_nr, line_start, name, is_attr_ref=True)
    attribute.attr_ref_line = AttrRefLine(line_nr=ref_line_nr, attr_ref=attr_ref)
    return attribute
38
+
39
+
40
def parse_func_call_line(
    name: str, rest: str, lines: list[str], go_code: str, call_line_nr: int
) -> TFSchemaAttribute | None:
    """Build an attribute whose schema comes from a helper call, e.g. ``"name": helperSchema(ctx),``.

    Returns None when *go_code* has no ``func <helper>(...) schema.<Type> {`` definition. Otherwise
    the attribute covers the helper's lines ``[start, end)`` and records the call in ``func_call``.
    """
    func_def_line = _function_line(rest, go_code)
    if not func_def_line:
        return None
    called_name, call_args = rest.split("(", maxsplit=1)
    start, end = _func_lines(name, lines, func_def_line)
    func_call = FuncCallLine(
        call_line_nr=call_line_nr,
        func_name=called_name.strip(),
        args=call_args.removesuffix("),").strip(),
        func_line_start=start,
        func_line_end=end,
    )
    body_lines = lines[start:end]
    return TFSchemaAttribute(
        name=name,
        lines=body_lines,
        line_start=start,
        line_end=end,
        func_call=func_call,
        indent="\t",
    )
63
+
64
+
65
+ def _func_lines(name: str, lines: list[str], func_def_line: str) -> tuple[int, int]:
66
+ start_line = lines.index(func_def_line)
67
+ for line_nr, line in enumerate(lines[start_line + 1 :], start=start_line + 1):
68
+ if line.rstrip() == "}":
69
+ return start_line, line_nr
70
+ raise ValueError(f"no end line found for {name} on line {start_line}: {func_def_line}")
71
+
72
+
73
+ def _function_line(rest: str, go_code: str) -> str:
74
+ function_name = rest.split("(")[0].strip()
75
+ pattern = re.compile(rf"func {function_name}\(.*\) \*?schema\.\w+ \{{$", re.M)
76
+ match = pattern.search(go_code)
77
+ if not match:
78
+ return ""
79
+ return go_code[match.start() : match.end()]
80
+
81
+
82
def parse_attribute_lines(
    lines: list[str],
    line_nr: int,
    line: str,
    name: str,
    *,
    is_attr_ref: bool = False,
) -> TFSchemaAttribute:
    """Parse the attribute definition that starts at ``lines[line_nr]`` (== *line*).

    The definition ends at the first later line equal to the start line's indentation followed by
    ``},`` (or ``}`` when *is_attr_ref*). The returned attribute covers ``lines[line_nr:end]``.

    Raises:
        ValueError: if no closing line is found.
    """
    # Reuse the start line's exact leading whitespace rather than rebuilding it as N tabs,
    # so the closing line is also found in non-tab-indented code.
    indent = line[: len(line) - len(line.lstrip())]
    end_line = f"{indent}}}" if is_attr_ref else f"{indent}}},"
    for extra_lines, next_line in enumerate(lines[line_nr + 1 :], start=1):
        if next_line == end_line:
            return TFSchemaAttribute(
                name=name,
                lines=lines[line_nr : line_nr + extra_lines],
                line_start=line_nr,
                line_end=line_nr + extra_lines,
                indent=indent,
            )
    raise ValueError(f"no end line found for {name}, starting on line {line_nr}")
103
+
104
+
105
# Matches a schema attribute line such as `\t\t"name": schema.StringAttribute{`:
# required leading whitespace, a double-quoted attribute name, `:` plus one whitespace char,
# then the rest of the line (captured as `rest`).
_schema_attribute_go_regex = re.compile(
    r'^\s+"(?P<name>[^"]+)":\s(?P<rest>.+)$',
)
108
+
109
+
110
def find_attributes(go_code: str) -> list[TFSchemaAttribute]:
    """Find all schema attributes in *go_code* and set their ``attribute_path``.

    Each ``"name": ...`` line is parsed in one of three ways:
    - a helper-function call (line ends with ``),``);
    - a reference to a variable holding the schema (``&myAttr,``);
    - an inline definition. Inline definitions without a closing line are logged and skipped,
      as are those with an empty ``type``.
    """
    # leading "" so list indexes equal 1-based line numbers
    lines = ["", *go_code.splitlines()]
    found: list[TFSchemaAttribute] = []
    for line_nr, line in enumerate(lines):
        match = _schema_attribute_go_regex.match(line)
        if match is None:
            continue
        attr_name, rest = match.group("name", "rest")
        if rest.endswith("),"):
            func_attr = parse_func_call_line(attr_name, rest, lines, go_code, line_nr)
            if func_attr:
                found.append(func_attr)
            continue
        ref_attr = parse_attribute_ref(attr_name, rest, go_code, lines, line_nr)
        if ref_attr:
            found.append(ref_attr)
            continue
        try:
            inline_attr = parse_attribute_lines(lines, line_nr, line, attr_name)
        except ValueError as e:
            logger.warning(e)
            continue
        if inline_attr.type:
            found.append(inline_attr)
    set_attribute_paths(found)
    return found
135
+
136
+
137
class StartEnd(NamedTuple):
    """Line span of a parsed attribute, used to decide which attributes enclose which."""

    start: int
    end: int
    name: str
    func_call_line: FuncCallLine | None

    def has_parent(self, other: StartEnd) -> bool:
        """True if *other* encloses this attribute; same-named attributes are never parents.

        An attribute built by a helper function is placed by the line that calls the function.
        Any other attribute must lie strictly inside *other*'s span.
        """
        if self.name == other.name:
            return False
        call = self.func_call_line
        if not call:
            return other.start < self.start and other.end > self.end
        return other.start < call.call_line_nr < other.end
150
+
151
+
152
def set_attribute_paths(attributes: list[TFSchemaAttribute]) -> list[TFSchemaAttribute]:
    """Set ``attribute_path`` on every attribute from the attributes enclosing it; returns *attributes*.

    Parents are ordered as they appear in *attributes*. Parents sharing the exact same
    (start, end) span are rendered as one alternation group, e.g. ``(a|b).child``. This happens
    when several attributes are built by the same helper function, since they then share the
    helper's line span.
    """
    spans = [StartEnd(a.line_start, a.line_end, a.name, a.func_call) for a in attributes]
    for attribute, span in zip(attributes, spans, strict=True):
        parents = [other for other in spans if span.has_parent(other)]
        if not parents:
            attribute.attribute_path = attribute.name
            continue
        # group parent names by span (dict keeps first-seen order)
        names_by_span: dict[tuple[int, int], list[str]] = defaultdict(list)
        for parent in parents:
            names_by_span[(parent.start, parent.end)].append(parent.name)
        path_parts = [
            names[0] if len(names) == 1 else f"({'|'.join(names)})" for names in names_by_span.values()
        ]
        path_parts.append(attribute.name)
        attribute.attribute_path = ".".join(path_parts)
    return attributes
174
+
175
+
176
class GoSchemaFunc(Entity):
    """A Go function that builds part of a schema.

    Holds the attributes that call the function and the attributes it defines.
    ``find_schema_functions`` represents the root schema as ``name=""`` with line_start=line_end=0.
    """

    name: str
    line_start: int
    line_end: int
    call_attributes: list[TFSchemaAttribute] = Field(default_factory=list)
    attributes: list[TFSchemaAttribute] = Field(default_factory=list)

    @property
    def attribute_names(self) -> set[str]:
        """Names of the attributes that call this function."""
        return {a.name for a in self.call_attributes}

    @property
    def attribute_paths(self) -> str:
        """Parent path(s) of the calling attributes.

        Returns a single path, an alternation ``(p1|p2)`` when several distinct paths exist, or ``""``
        when nothing calls the function. Paths are sorted so the result is deterministic (a plain
        ``set`` iterates in hash-seed dependent order).
        """
        paths = sorted({".".join(a.parent_attribute_names()) for a in self.call_attributes})
        if not paths:
            return ""
        if len(paths) == 1:
            return paths[0]
        return f"({'|'.join(paths)})"

    def contains_attribute(self, attribute: TFSchemaAttribute) -> bool:
        """True if any parent attribute name of *attribute* is one of this function's callers."""
        names = self.attribute_names
        return any(parent_attribute in names for parent_attribute in attribute.parent_attribute_names())
198
+
199
+
200
def find_schema_functions(attributes: list[TFSchemaAttribute]) -> list[GoSchemaFunc]:
    """Group *attributes* by the schema function that defines them and set ``absolute_attribute_path``.

    One GoSchemaFunc is created per called helper function (spanning the first call's lines).
    Attributes not inside any helper go to the root function (``name=""``), which is returned first.

    Raises:
        AssertionError: if an attribute matches more than one helper function.
    """
    calls_by_func: dict[str, list[TFSchemaAttribute]] = defaultdict(list)
    for attr in attributes:
        if not attr.is_function_call:
            continue
        func_call = attr.func_call
        assert func_call
        calls_by_func[func_call.func_name].append(attr)
    root_function = GoSchemaFunc(name="", line_start=0, line_end=0)
    functions: list[GoSchemaFunc] = []
    for func_name, call_attrs in calls_by_func.items():
        first_call = call_attrs[0]
        functions.append(
            GoSchemaFunc(
                name=func_name,
                line_start=first_call.line_start,
                line_end=first_call.line_end,
                call_attributes=call_attrs,
            )
        )
    for attr in attributes:
        owners = [func for func in functions if func.contains_attribute(attr)]
        if not owners:
            root_function.attributes.append(attr)
            attr.absolute_attribute_path = attr.attribute_path
            continue
        assert len(owners) == 1, f"multiple functions found for {attr.name}, {[func.name for func in owners]}"
        owner = owners[0]
        owner.attributes.append(attr)
        attr.absolute_attribute_path = f"{owner.attribute_paths}.{attr.attribute_path}".lstrip(".")
    return [root_function, *functions]
229
+
230
+
231
def parse_schema_functions(go_code: str) -> tuple[list[TFSchemaAttribute], list[GoSchemaFunc]]:
    """Parse *go_code* and return (sorted attributes, schema functions with the root function first)."""
    parsed = find_attributes(go_code)
    schema_funcs = find_schema_functions(parsed)
    parsed_sorted = sorted(parsed)
    return parsed_sorted, schema_funcs
@@ -0,0 +1,150 @@
1
+ # import typer
2
+
3
+
4
+ from collections import defaultdict
5
+ from collections.abc import Iterable
6
+ from functools import total_ordering
7
+ from pathlib import Path
8
+ from typing import Literal, TypeAlias
9
+
10
+ from model_lib import Entity, Event
11
+ from pydantic import Field, model_validator
12
+ from zero_3rdparty import iter_utils
13
+
14
+ from atlas_init.cli_tf.schema_go_parser import parse_schema_functions
15
+ from atlas_init.cli_tf.schema_table_models import TFSchemaAttribute, TFSchemaTableColumn
16
+ from atlas_init.settings.path import default_factory_cwd
17
+
18
+
19
def default_table_columns() -> list[TFSchemaTableColumn]:
    """Default ``columns`` for TFSchemaTableInput: just the computability column."""
    columns = [TFSchemaTableColumn.Computability]
    return columns
21
+
22
+
23
def file_name_path(file: str) -> tuple[str, Path]:
    """Split a source spec into (display name, path).

    ``"name:path"`` is split at the first ``:``. Otherwise the name is ``<parent dir>/<stem>`` of the
    path, truncated to 20 characters.
    """
    name, separator, raw_path = file.partition(":")
    if separator:
        return name, Path(raw_path)
    path = Path(file)
    derived_name = f"{path.parent.name}/{path.stem}"
    return derived_name[:20], path
29
+
30
+
31
@total_ordering
class TFSchemaSrc(Event):
    """A named source of Go schema code: a local file (``url`` is accepted but cannot be read yet).

    Instances order by ``name``.
    """

    name: str
    file_path: Path | None = None
    url: str = ""

    @model_validator(mode="after")
    def validate(self):
        assert self.file_path or self.url, "must provide file path or url"
        if self.file_path:
            assert self.file_path.exists(), f"file does not exist for {self.name}: {self.file_path}"
        return self

    def __lt__(self, other) -> bool:
        # NotImplemented (rather than raising) lets Python raise the standard TypeError itself,
        # and is what functools.total_ordering's derived comparisons expect.
        if not isinstance(other, TFSchemaSrc):
            return NotImplemented
        return self.name < other.name

    def go_code(self) -> str:
        """Return the Go source read from ``file_path``.

        Raises:
            NotImplementedError: for url-only sources.
        """
        if path := self.file_path:
            return path.read_text()
        raise NotImplementedError(f"reading go code from url is not supported: {self.url}")
53
+
54
+
55
# Output formats accepted by TFSchemaTableInput.output_format; only markdown ("md") exists.
TableOutputFormats: TypeAlias = Literal["md"]
56
+
57
+
58
class TFSchemaTableInput(Entity):
    """Configuration for ``schema_table``: the sources to compare and how to render the table."""

    sources: list[TFSchemaSrc] = Field(default_factory=list)
    output_format: TableOutputFormats = "md"
    output_path: Path = Field(default_factory=default_factory_cwd("schema_table.md"))
    columns: list[TFSchemaTableColumn] = Field(default_factory=default_table_columns)
    explode_rows: bool = False

    @model_validator(mode="after")
    def validate(self):
        assert self.columns, "must provide at least 1 column"
        self.columns = sorted(self.columns)
        assert self.sources, "must provide at least 1 source"
        self.sources = sorted(self.sources)
        # Compare names, not whole sources: headers() is built from names, so two sources with the
        # same name but different paths would produce ambiguous columns.
        names = self.source_names
        duplicates = sorted({n for n in names if names.count(n) > 1})
        assert not duplicates, f"duplicate source names: {duplicates}"
        return self

    @property
    def source_names(self) -> list[str]:
        return [s.name for s in self.sources]

    def headers(self) -> list[str]:
        """Table header cells: ``Attribute Name`` followed by ``<source>-<column>`` per source and column."""
        return ["Attribute Name"] + [f"{name}-{col}" for name in self.source_names for col in self.columns]
80
+
81
+
82
@total_ordering
class TFSchemaTableData(Event):
    """Attributes parsed from one source for one schema path; instances order by ``id``."""

    source: TFSchemaSrc
    schema_path: str = ""  # e.g., "" is root, "replication_specs.region_config"
    attributes: list[TFSchemaAttribute] = Field(default_factory=list)

    @property
    def id(self) -> str:
        """``<schema_path>:<source name>``."""
        return f"{self.schema_path}:{self.source.name}"

    def __lt__(self, other) -> bool:
        # NotImplemented lets Python raise the standard TypeError for foreign types
        if not isinstance(other, TFSchemaTableData):
            return NotImplemented
        return self.id < other.id
96
+
97
+
98
def sorted_schema_paths(schema_paths: Iterable[str]) -> list[str]:
    """Order schema paths shallowest first, then by last segment; ties keep their input order."""

    def depth_then_leaf(path: str) -> tuple[int, str]:
        return path.count("."), path.rsplit(".", 1)[-1]

    return sorted(schema_paths, key=depth_then_leaf)
100
+
101
+
102
class RawTable(Event):
    """Format-independent table (header cells plus rows of cells), rendered by ``format_table``."""

    columns: list[str]  # header cells
    rows: list[list[str]]  # one list of cells per row; expected to line up with `columns`
105
+
106
+
107
def merge_tables(config: TFSchemaTableInput, schema_path: str, tables: list[TFSchemaTableData]) -> RawTable:
    """Combine the tables parsed for *schema_path* into one RawTable.

    Only the root path (``""``) with a single table is supported. Each row holds the attribute's
    ``absolute_attribute_path`` followed by its cells for ``config.columns``.

    Raises:
        NotImplementedError: for a non-root *schema_path* or more than one table.
    """
    if schema_path != "":
        raise NotImplementedError
    header_cells = config.headers()
    if len(tables) > 1:
        err_msg = "only 1 table per schema path supported"
        raise NotImplementedError(err_msg)
    only_table = tables[0]
    rows: list[list[str]] = []
    for attr in only_table.attributes:
        rows.append([attr.absolute_attribute_path, *attr.row(config.columns)])
    return RawTable(columns=header_cells, rows=rows)
117
+
118
+
119
def format_table(table: RawTable, table_format: TableOutputFormats) -> list[str]:
    """Render *table* as markdown lines: header, ``---`` separator, then one line per row.

    A literal ``|`` in a cell is escaped as ``\\|``. Attribute paths may contain one (e.g. the
    ``(a|b).child`` alternation groups), and unescaped it would split the cell into extra columns.

    Raises:
        AssertionError: for any format other than ``"md"``.
    """
    assert table_format == "md", "only markdown format supported"

    def md_row(cells: list[str]) -> str:
        return "|".join(cell.replace("|", "\\|") for cell in cells)

    return [
        md_row(table.columns),
        "|".join(["---"] * len(table.columns)),
        *(md_row(row) for row in table.rows),
    ]
127
+
128
+
129
def explode_attributes(attributes: list[TFSchemaAttribute]) -> list[TFSchemaAttribute]:
    """Replace each attribute by the attributes its ``explode()`` returns, flattened and sorted."""
    per_attribute = (attribute.explode() for attribute in attributes)
    return sorted(iter_utils.flat_map(per_attribute))
131
+
132
+
133
def schema_table(config: TFSchemaTableInput) -> str:
    """Render the schema table text for all sources in *config*.

    Each schema path gets a ``## <path>`` heading (``Root`` for ``""``) followed by its table.
    Only the root path is produced for now.
    NOTE(review): merge_tables rejects more than one table per path, so several sources currently
    raise NotImplementedError.
    """
    path_tables: dict[str, list[TFSchemaTableData]] = defaultdict(list)
    for source in config.sources:
        # the schema functions returned alongside the attributes are not used here
        attributes, _ = parse_schema_functions(source.go_code())
        if config.explode_rows:
            attributes = explode_attributes(attributes)
        schema_path = ""  # using only root for now
        path_tables[schema_path].append(
            TFSchemaTableData(source=source, attributes=attributes, schema_path=schema_path)
        )
    output_lines: list[str] = []
    for schema_path in sorted_schema_paths(path_tables.keys()):
        table = merge_tables(config, schema_path, path_tables[schema_path])
        output_lines.extend(["", f"## {schema_path or 'Root'}", ""])
        output_lines.extend(format_table(table, table_format=config.output_format))
    return "\n".join(output_lines)