atlas-init 0.4.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. atlas_init/__init__.py +1 -1
  2. atlas_init/cli.py +2 -0
  3. atlas_init/cli_cfn/app.py +3 -4
  4. atlas_init/cli_cfn/cfn_parameter_finder.py +61 -53
  5. atlas_init/cli_cfn/contract.py +4 -7
  6. atlas_init/cli_cfn/example.py +8 -18
  7. atlas_init/cli_helper/go.py +7 -11
  8. atlas_init/cli_root/mms_released.py +46 -0
  9. atlas_init/cli_root/trigger.py +6 -6
  10. atlas_init/cli_tf/app.py +3 -84
  11. atlas_init/cli_tf/ci_tests.py +493 -0
  12. atlas_init/cli_tf/codegen/__init__.py +0 -0
  13. atlas_init/cli_tf/codegen/models.py +97 -0
  14. atlas_init/cli_tf/codegen/openapi_minimal.py +74 -0
  15. atlas_init/cli_tf/github_logs.py +7 -94
  16. atlas_init/cli_tf/go_test_run.py +385 -132
  17. atlas_init/cli_tf/go_test_summary.py +331 -4
  18. atlas_init/cli_tf/go_test_tf_error.py +380 -0
  19. atlas_init/cli_tf/hcl/modifier.py +14 -12
  20. atlas_init/cli_tf/hcl/modifier2.py +87 -0
  21. atlas_init/cli_tf/mock_tf_log.py +1 -1
  22. atlas_init/cli_tf/{schema_v2_api_parsing.py → openapi.py} +95 -17
  23. atlas_init/cli_tf/schema_v2.py +43 -1
  24. atlas_init/crud/__init__.py +0 -0
  25. atlas_init/crud/mongo_client.py +115 -0
  26. atlas_init/crud/mongo_dao.py +296 -0
  27. atlas_init/crud/mongo_utils.py +239 -0
  28. atlas_init/repos/go_sdk.py +12 -3
  29. atlas_init/repos/path.py +110 -7
  30. atlas_init/settings/config.py +3 -6
  31. atlas_init/settings/env_vars.py +22 -31
  32. atlas_init/settings/interactive2.py +134 -0
  33. atlas_init/tf/.terraform.lock.hcl +59 -59
  34. atlas_init/tf/always.tf +5 -5
  35. atlas_init/tf/main.tf +3 -3
  36. atlas_init/tf/modules/aws_kms/aws_kms.tf +1 -1
  37. atlas_init/tf/modules/aws_s3/provider.tf +2 -1
  38. atlas_init/tf/modules/aws_vpc/provider.tf +2 -1
  39. atlas_init/tf/modules/cfn/cfn.tf +0 -8
  40. atlas_init/tf/modules/cfn/kms.tf +5 -5
  41. atlas_init/tf/modules/cfn/provider.tf +7 -0
  42. atlas_init/tf/modules/cfn/variables.tf +1 -1
  43. atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -1
  44. atlas_init/tf/modules/cloud_provider/provider.tf +2 -1
  45. atlas_init/tf/modules/cluster/cluster.tf +31 -31
  46. atlas_init/tf/modules/cluster/provider.tf +2 -1
  47. atlas_init/tf/modules/encryption_at_rest/provider.tf +2 -1
  48. atlas_init/tf/modules/federated_vars/federated_vars.tf +1 -1
  49. atlas_init/tf/modules/federated_vars/provider.tf +2 -1
  50. atlas_init/tf/modules/project_extra/project_extra.tf +1 -10
  51. atlas_init/tf/modules/project_extra/provider.tf +8 -0
  52. atlas_init/tf/modules/stream_instance/provider.tf +8 -0
  53. atlas_init/tf/modules/stream_instance/stream_instance.tf +0 -9
  54. atlas_init/tf/modules/vpc_peering/provider.tf +10 -0
  55. atlas_init/tf/modules/vpc_peering/vpc_peering.tf +0 -10
  56. atlas_init/tf/modules/vpc_privatelink/versions.tf +2 -1
  57. atlas_init/tf/outputs.tf +1 -0
  58. atlas_init/tf/providers.tf +1 -1
  59. atlas_init/tf/variables.tf +7 -7
  60. atlas_init/typer_app.py +4 -8
  61. {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/METADATA +7 -4
  62. atlas_init-0.6.0.dist-info/RECORD +121 -0
  63. atlas_init-0.4.4.dist-info/RECORD +0 -105
  64. {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/WHEEL +0 -0
  65. {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/entry_points.txt +0 -0
  66. {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,380 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from enum import StrEnum
6
+ from functools import total_ordering
7
+ from typing import ClassVar, Literal, NamedTuple, Self, TypeAlias
8
+
9
+ import humanize
10
+ from model_lib import Entity, utc_datetime_ms
11
+ from pydantic import Field, model_validator
12
+ from zero_3rdparty import iter_utils
13
+ from zero_3rdparty.datetime_utils import utc_now
14
+ from zero_3rdparty.str_utils import instance_repr
15
+
16
+ from atlas_init.cli_tf.go_test_run import GoTestRun
17
+ from atlas_init.repos.go_sdk import ApiSpecPaths
18
+
19
+
20
+ class GoTestErrorClass(StrEnum):
21
+ """Goal of each error class to be actionable."""
22
+
23
+ FLAKY_400 = "flaky_400"
24
+ FLAKY_500 = "flaky_500"
25
+ FLAKY_CHECK = "flaky_check"
26
+ OUT_OF_CAPACITY = "out_of_capacity"
27
+ PROJECT_LIMIT_EXCEEDED = "project_limit_exceeded"
28
+ DANGLING_RESOURCE = "dangling_resource"
29
+ REAL_TEST_FAILURE = "real_test_failure"
30
+ TIMEOUT = "timeout"
31
+ UNKNOWN = "unknown"
32
+ PROVIDER_DOWNLOAD = "provider_download"
33
+ UNCLASSIFIED = "unclassified"
34
+
35
+ __ACTIONS__ = {
36
+ FLAKY_400: "retry",
37
+ FLAKY_500: "retry",
38
+ FLAKY_CHECK: "retry",
39
+ PROVIDER_DOWNLOAD: "retry",
40
+ OUT_OF_CAPACITY: "retry_later",
41
+ PROJECT_LIMIT_EXCEEDED: "clean_project",
42
+ DANGLING_RESOURCE: "update_cleanup_script",
43
+ REAL_TEST_FAILURE: "investigate",
44
+ TIMEOUT: "investigate",
45
+ UNKNOWN: "investigate",
46
+ }
47
+ __CONTAINS_MAPPING__ = {
48
+ OUT_OF_CAPACITY: ("OUT_OF_CAPACITY",),
49
+ FLAKY_500: ("HTTP 500", "UNEXPECTED_ERROR"),
50
+ PROVIDER_DOWNLOAD: [
51
+ "mongodbatlas: failed to retrieve authentication checksums for provider",
52
+ "Error: Failed to install provider github.com: bad response",
53
+ ],
54
+ TIMEOUT: ("timeout while waiting for",),
55
+ }
56
+
57
+ @classmethod
58
+ def auto_classification(cls, output: str) -> GoTestErrorClass | None:
59
+ def contains(output: str, contains_part: str) -> bool:
60
+ if " " in contains_part:
61
+ return all(part in output for part in contains_part.split())
62
+ return contains_part in output
63
+
64
+ return next(
65
+ (
66
+ error_class
67
+ for error_class, contains_list in cls.__CONTAINS_MAPPING__.items()
68
+ if all(contains(output, contains_part) for contains_part in contains_list)
69
+ ),
70
+ None,
71
+ ) # type: ignore
72
+
73
+
74
+ API_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH"]
75
+
76
+
77
class GoTestAPIError(Entity):
    """Error details parsed from test output that contains an API error."""

    type: Literal["api_error"] = "api_error"
    api_error_code_str: str
    api_path: str
    api_method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
    api_response_code: int
    tf_resource_name: str = ""
    tf_resource_type: str = ""
    step_nr: int = -1

    # Filled in by add_info_fields when API spec paths are available.
    api_path_normalized: str = Field(init=False, default="")

    @model_validator(mode="after")
    def strip_path_chars(self) -> GoTestAPIError:
        """Drop trailing ``:`` and ``/`` characters from ``api_path``."""
        cleaned_path = self.api_path.rstrip(":/")
        self.api_path = cleaned_path
        return self

    def add_info_fields(self, info: DetailsInfo) -> None:
        """Set ``api_path_normalized`` when ``info`` carries API spec paths."""
        spec_paths = info.paths
        if not spec_paths:
            return
        self.api_path_normalized = spec_paths.normalize_path(self.api_method, self.api_path)

    def __str__(self) -> str:
        # Prefer the normalized path when it has been set.
        shown_path = self.api_path_normalized or self.api_path
        prefix = f"{self.tf_resource_type} " if self.tf_resource_type else ""
        return f"{prefix}{self.api_error_code_str} {self.api_method} {shown_path} {self.api_response_code}"
103
+
104
+
105
@total_ordering
class CheckError(Entity):
    """A single failed check, ordered by ``(check_nr, attribute)``."""

    attribute: str = ""
    expected: str = ""
    got: str = ""
    check_nr: int = -1

    def __lt__(self, other) -> bool:
        if not isinstance(other, CheckError):
            # Let Python try the reflected operation; it raises TypeError
            # itself when neither side supports the comparison.
            return NotImplemented
        return (self.check_nr, self.attribute) < (other.check_nr, other.attribute)

    def __str__(self) -> str:
        if self.attribute and self.expected and self.got:
            return f"{self.check_nr}({self.attribute}:expected:{self.expected}, got: {self.got})"
        return f"{self.check_nr}"

    @classmethod
    def parse_from_output(cls, output: str) -> list[Self]:
        """Create one instance per ``check_pattern`` match found in ``output``.

        Only the named groups of ``check_pattern`` are passed to the
        constructor; all other fields keep their defaults.
        """
        return [
            cls(**check_match.groupdict())  # type: ignore
            for check_match in check_pattern.finditer(output)
        ]
128
+
129
+
130
class GoTestResourceCheckError(Entity):
    """Check errors that could be tied to a terraform resource name and type."""

    type: Literal["check_error"] = "check_error"
    tf_resource_name: str
    tf_resource_type: str
    step_nr: int = -1
    check_errors: list[CheckError] = Field(default_factory=list)
    test_name: str = ""

    def add_info_fields(self, info: DetailsInfo) -> None:
        """Copy the name of the run in ``info`` into ``test_name``."""
        run = info.run
        self.test_name = run.name

    def __str__(self) -> str:
        parts = (self.tf_resource_type, self.tf_resource_name, self.step_nr, self.check_errors)
        return " ".join(str(part) for part in parts)

    @property
    def check_numbers_str(self) -> str:
        """Comma-separated check numbers, in sorted check order."""
        ordered_checks = sorted(self.check_errors)
        return ",".join(str(ordered.check_nr) for ordered in ordered_checks)

    def check_errors_match(self, other_check_errors: list[CheckError]) -> bool:
        """Return True when both lists have equal length and every own check
        has a counterpart with the same ``check_nr`` and ``attribute``."""
        if len(other_check_errors) != len(self.check_errors):
            return False
        other_keys = {(other.check_nr, other.attribute) for other in other_check_errors}
        for own in self.check_errors:
            if (own.check_nr, own.attribute) not in other_keys:
                return False
        return True
158
+
159
+
160
class GoTestGeneralCheckError(Entity):
    """Check errors that carry an error string but no terraform resource."""

    type: Literal["general_check_error"] = "general_check_error"
    step_nr: int = -1
    check_errors: list[CheckError] = Field(default_factory=list)
    error_check_str: str
    test_name: str = ""

    def add_info_fields(self, info: DetailsInfo) -> None:
        """Copy the name of the run in ``info`` into ``test_name``."""
        run = info.run
        self.test_name = run.name

    def check_errors_str(self) -> str:
        """Comma-separated string forms of the checks, in sorted order."""
        ordered_checks = sorted(self.check_errors)
        return ",".join(map(str, ordered_checks))

    def __str__(self) -> str:
        checks_part = self.check_errors_str()
        return f"Step {self.step_nr} {checks_part}"
175
+
176
+
177
@dataclass
class DetailsInfo:
    """Context handed to the ``add_info_fields`` methods of the error details."""

    # The run the error details were parsed from.
    run: GoTestRun
    # Optional API spec paths; used to normalize API paths when present.
    paths: ApiSpecPaths | None = None
181
+
182
+
183
class GoTestDefaultError(Entity):
    """Fallback error details that only keep the error text."""

    type: Literal["default_error"] = "default_error"
    error_str: str

    def add_info_fields(self, _: DetailsInfo) -> None:
        """No extra fields to fill in for the default error."""
        return None
189
+
190
+
191
+ ErrorDetailsT: TypeAlias = GoTestAPIError | GoTestResourceCheckError | GoTestDefaultError | GoTestGeneralCheckError
192
+
193
+
194
class ErrorClassified(NamedTuple):
    """Errors grouped by error class, plus the ones still unclassified."""

    # Errors keyed by their error class.
    classified: dict[GoTestErrorClass, list[GoTestError]]
    # Errors whose class is UNCLASSIFIED.
    unclassified: list[GoTestError]
197
+
198
+
199
+ class ErrorClassAuthor(StrEnum):
200
+ AUTO = "auto"
201
+ HUMAN = "human"
202
+ LLM = "llm"
203
+ SIMILAR = "similar"
204
+
205
+
206
class GoTestErrorClassification(Entity):
    """A classification of one test error, with its author and confidence."""

    error_class: GoTestErrorClass = GoTestErrorClass.UNCLASSIFIED
    ts: utc_datetime_ms = Field(default_factory=utc_now)
    author: ErrorClassAuthor
    confidence: float = 0.0
    test_output: str = ""
    details: ErrorDetailsT
    run_id: str
    test_name: str

    # Attribute names passed to instance_repr in __str__.
    STR_COLUMNS: ClassVar[list[str]] = ["error_class", "author", "run_id", "confidence", "ts_when"]

    def needs_classification(self, confidence_threshold: float = 1.0) -> bool:
        """Return True when the class is UNCLASSIFIED/UNKNOWN or the
        confidence is below ``confidence_threshold``."""
        undecided = {GoTestErrorClass.UNCLASSIFIED, GoTestErrorClass.UNKNOWN}
        if self.error_class in undecided:
            return True
        return self.confidence < confidence_threshold

    @property
    def ts_when(self) -> str:
        """Human-friendly rendering of ``ts`` via ``humanize.naturaltime``."""
        timestamp = self.ts
        return humanize.naturaltime(timestamp)

    def __str__(self) -> str:
        columns = self.STR_COLUMNS
        return instance_repr(self, columns)
230
+
231
+
232
+ @total_ordering
233
+ class GoTestError(Entity):
234
+ details: ErrorDetailsT
235
+ run: GoTestRun
236
+ bot_error_class: GoTestErrorClass = GoTestErrorClass.UNCLASSIFIED
237
+ human_error_class: GoTestErrorClass = GoTestErrorClass.UNCLASSIFIED
238
+
239
+ def __lt__(self, other) -> bool:
240
+ if not isinstance(other, GoTestError):
241
+ raise TypeError
242
+ return self.run < other.run
243
+
244
+ @property
245
+ def run_id(self) -> str:
246
+ return self.run.id
247
+
248
+ @property
249
+ def run_name(self) -> str:
250
+ return self.run.name
251
+
252
+ @property
253
+ def classifications(self) -> tuple[GoTestErrorClass, GoTestErrorClass] | None:
254
+ if (
255
+ self.bot_error_class != GoTestErrorClass.UNCLASSIFIED
256
+ and self.human_error_class != GoTestErrorClass.UNCLASSIFIED
257
+ ):
258
+ return self.bot_error_class, self.human_error_class
259
+ return None
260
+
261
+ def set_human_and_bot_classification(self, chosen_class: GoTestErrorClass) -> None:
262
+ self.human_error_class = chosen_class
263
+ self.bot_error_class = chosen_class
264
+
265
+ def match(self, other: GoTestError) -> bool:
266
+ if self.run.id == other.run.id:
267
+ return True
268
+ details = self.details
269
+ other_details = other.details
270
+ if type(self.details) is not type(other_details):
271
+ return False
272
+ if isinstance(details, GoTestAPIError):
273
+ assert isinstance(other_details, GoTestAPIError)
274
+ return (
275
+ details.api_path_normalized == other_details.api_path_normalized
276
+ and details.api_response_code == other_details.api_response_code
277
+ and details.api_method == other_details.api_method
278
+ and details.api_response_code == other_details.api_response_code
279
+ )
280
+ if isinstance(details, GoTestResourceCheckError):
281
+ assert isinstance(other_details, GoTestResourceCheckError)
282
+ return (
283
+ details.tf_resource_name == other_details.tf_resource_name
284
+ and details.tf_resource_type == other_details.tf_resource_type
285
+ and details.step_nr == other_details.step_nr
286
+ and details.check_numbers_str == other_details.check_numbers_str
287
+ )
288
+ return False
289
+
290
+ @classmethod
291
+ def group_by_classification(
292
+ cls, errors: list[GoTestError], *, classifier: Literal["bot", "human"] = "human"
293
+ ) -> ErrorClassified:
294
+ def get_classification(error: GoTestError) -> GoTestErrorClass:
295
+ if classifier == "bot":
296
+ return error.bot_error_class
297
+ return error.human_error_class
298
+
299
+ grouped_errors: dict[GoTestErrorClass, list[GoTestError]] = iter_utils.group_by_once(
300
+ errors, key=get_classification
301
+ )
302
+ unclassified = grouped_errors.pop(GoTestErrorClass.UNCLASSIFIED, [])
303
+ return ErrorClassified(grouped_errors, unclassified)
304
+
305
+ @classmethod
306
+ def group_by_name_with_package(cls, errors: list[GoTestError]) -> dict[str, list[GoTestError]]:
307
+ def by_name(error: GoTestError) -> str:
308
+ return error.run.name_with_package
309
+
310
+ return iter_utils.group_by_once(errors, key=by_name)
311
+
312
+ @property
313
+ def short_description(self) -> str:
314
+ match self.details:
315
+ case GoTestGeneralCheckError():
316
+ return str(self.details)
317
+ case GoTestResourceCheckError():
318
+ return f"CheckFailure for {self.details.tf_resource_type}.{self.details.tf_resource_name} at Step: {self.details.step_nr} Checks: {self.details.check_numbers_str}"
319
+ case GoTestAPIError(api_path_normalized=api_path_normalized) if api_path_normalized:
320
+ return f"API Error {self.details.api_error_code_str} {api_path_normalized}"
321
+ case GoTestAPIError(api_path=api_path):
322
+ return f"{self.details.api_error_code_str} {api_path}"
323
+ return ""
324
+
325
+ def header(self, use_ticks: bool = False) -> str:
326
+ name_with_ticks = f"`{self.run.name_with_package}`" if use_ticks else self.run.name_with_package
327
+ if details := self.short_description:
328
+ return f"{name_with_ticks} {details}"
329
+ return f"{name_with_ticks}"
330
+
331
+
332
# Regex alternation of the supported HTTP methods.
one_of_methods = "|".join(API_METHODS)

# "Check <nr>/<total>" as it appears in the test output.
check_pattern_str = r"Check (?P<check_nr>\d+)/\d+"
check_pattern = re.compile(check_pattern_str)
url_pattern = r"https://cloud(-dev|-qa)?\.mongodb\.com(?P<api_path>\S+)"
# A check line followed by "error: <text>" up to the end of that line.
error_check_pattern = re.compile(
    check_pattern_str + r"\s+error:\s(?P<error_check_str>.+)$",
    re.MULTILINE,
)

_step_pattern = re.compile(r"Step (?P<step_nr>\d+)/\d+")
_resource_pattern = re.compile(r"mongodbatlas_(?P<tf_resource_type>[^\.]+)\.(?P<tf_resource_name>[\w_-]+)")
_method_and_code_pattern = re.compile(rf"(?P<api_method>{one_of_methods})" + r": HTTP (?P<api_response_code>\d+)")
_error_code_pattern = re.compile(r'Error code: "(?P<api_error_code_str>[^"]+)"')
_url_pattern_compiled = re.compile(url_pattern)

# Each pattern contributes its named groups to the parsed error details.
detail_patterns: list[re.Pattern] = [
    _step_pattern,
    check_pattern,
    _resource_pattern,
    _method_and_code_pattern,
    _error_code_pattern,
    _url_pattern_compiled,
]

# Fallback for API errors without the "Error code:" detail line, e.g.:
# Error: error creating MongoDB Cluster: POST https://cloud-dev.mongodb.com/api/atlas/v1.0/groups/680ecbc7122f5b15cc627ba5/clusters: 409 (request "OUT_OF_CAPACITY") The requested region is currently out of capacity for the requested instance size.
api_error_pattern_missing_details = re.compile(
    rf"(?P<api_method>{one_of_methods})\s+"
    + url_pattern
    + r'\s+(?P<api_response_code>\d+)\s\(request\s"(?P<api_error_code_str>[^"]+)"\)'
)
354
+
355
+
356
def parse_error_details(run: GoTestRun) -> ErrorDetailsT:
    """Parse the output of ``run`` into the most specific error details type.

    Named groups from every matching entry of ``detail_patterns`` are
    collected, then the first applicable type is built, in this order:
    API error, API error via the fallback pattern, resource check error,
    general check error, and finally the default error with the raw output.
    """
    output = run.output_lines_str
    found: dict = {}
    for detail_pattern in detail_patterns:
        hit = detail_pattern.search(output)
        if hit:
            found.update(hit.groupdict())

    if "api_path" in found:
        if "api_error_code_str" in found:
            return GoTestAPIError(**found)
        # No "Error code:" line; try the combined single-line format.
        fallback_hit = api_error_pattern_missing_details.search(output)
        if fallback_hit:
            found.update(fallback_hit.groupdict())
            return GoTestAPIError(**found)

    if "check_nr" in found:
        if "tf_resource_name" in found and "tf_resource_type" in found:
            # check_nr is not a field of the details; checks are parsed separately.
            del found["check_nr"]
            resource_checks = CheckError.parse_from_output(output)
            return GoTestResourceCheckError(**found, check_errors=resource_checks)
        error_check_hit = error_check_pattern.search(output)
        if error_check_hit:
            del found["check_nr"]
            general_checks = CheckError.parse_from_output(output)
            return GoTestGeneralCheckError(
                **found,
                error_check_str=error_check_hit.group("error_check_str"),
                check_errors=general_checks,
            )

    return GoTestDefaultError(error_str=run.output_lines_str)
@@ -5,7 +5,9 @@ from pathlib import Path
5
5
  from typing import Callable
6
6
 
7
7
  import hcl2
8
- from lark import Token, Tree, UnexpectedToken
8
+ from lark import Token, Tree
9
+
10
+ from atlas_init.cli_tf.hcl.modifier2 import safe_parse
9
11
 
10
12
  logger = logging.getLogger(__name__)
11
13
 
@@ -14,10 +16,14 @@ BLOCK_TYPE_OUTPUT = "output"
14
16
 
15
17
 
16
18
  def process_token(node: Token, indent=0):
17
- logger.debug(f"[{indent}] (token)\t|", " " * indent, node.type, node.value)
19
+ debug_log(f"token:{node.type}:{node.value}", indent)
18
20
  return deepcopy(node)
19
21
 
20
22
 
23
def debug_log(message: str, depth=0):
    """Log ``message`` at debug level, indented by ``depth`` spaces and
    without trailing newlines."""
    indent = " " * depth
    trimmed = message.rstrip("\n")
    logger.debug(indent + trimmed)
25
+
26
+
21
27
  def is_identifier_block_type(tree: Tree | Token, block_type: str) -> bool:
22
28
  if not isinstance(tree, Tree):
23
29
  return False
@@ -43,7 +49,7 @@ def update_description(tree: Tree, new_descriptions: dict[str, str], existing_na
43
49
  existing_names[name].append(old_description)
44
50
  new_description = new_descriptions.get(name, "")
45
51
  if not new_description:
46
- logger.debug(f"no description found for variable {name}")
52
+ debug_log(f"no description found for variable {name}", 0)
47
53
  return tree
48
54
  new_children[2] = update_body_with_description(variable_body, new_description)
49
55
  return Tree(tree.data, new_children)
@@ -112,7 +118,7 @@ def process_generic(
112
118
  depth=0,
113
119
  ):
114
120
  new_children = []
115
- logger.debug(f"[{depth}] (tree)\t|", " " * depth, node.data)
121
+ debug_log(f"tree:{node.data}", depth)
116
122
  for child in node.children:
117
123
  if isinstance(child, Tree):
118
124
  if tree_match(child):
@@ -146,10 +152,8 @@ def process_descriptions(
146
152
 
147
153
 
148
154
  def update_descriptions(tf_path: Path, new_names: dict[str, str], block_type: str) -> tuple[str, dict[str, list[str]]]:
149
- try:
150
- tree = hcl2.parses(tf_path.read_text()) # type: ignore
151
- except UnexpectedToken as e:
152
- logger.warning(f"failed to parse {tf_path}: {e}")
155
+ tree = safe_parse(tf_path)
156
+ if tree is None:
153
157
  return "", {}
154
158
  existing_descriptions = defaultdict(list)
155
159
  new_tree = process_descriptions(
@@ -210,10 +214,8 @@ def _read_object_elem_key(tree_body: Tree) -> str:
210
214
 
211
215
 
212
216
  def read_block_attribute_object_keys(tf_path: Path, block_type: str, block_name: str, block_key: str) -> list[str]:
213
- try:
214
- tree = hcl2.parses(tf_path.read_text()) # type: ignore
215
- except UnexpectedToken as e:
216
- logger.warning(f"failed to parse {tf_path}: {e}")
217
+ tree = safe_parse(tf_path)
218
+ if tree is None:
217
219
  return []
218
220
  env_vars = []
219
221
 
@@ -0,0 +1,87 @@
1
+ import logging
2
+ from contextlib import suppress
3
+ from pathlib import Path
4
+ from typing import NamedTuple
5
+ from lark import Token, Transformer, Tree, UnexpectedToken, v_args
6
+ from hcl2.transformer import Attribute, DictTransformer
7
+ from hcl2.api import reverse_transform, writes, parses
8
+ import rich
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def update_attribute_object_str_value_for_block(
    tree: Tree, block_name: str, block_transformer: DictTransformer
) -> Tree:
    """Return ``tree`` with the block named ``block_name`` replaced.

    The replacement is produced by running ``block_transformer`` over the
    whole ``tree``, reversing the resulting dict back into a tree and taking
    the first block of its body.
    """

    class BlockUpdater(Transformer):
        @v_args(tree=True)
        def block(self, block_tree: Tree) -> Tree:
            if _identifier_name(block_tree) != block_name:
                return block_tree
            # NOTE(review): the outer ``tree`` is transformed and the FIRST
            # block of the rebuilt body is returned -- presumably this assumes
            # the target block is the first one in the file; confirm.
            rebuilt_tree = reverse_transform(block_transformer.transform(tree))
            assert isinstance(rebuilt_tree, Tree)
            rebuilt_body = rebuilt_tree.children[0]
            assert isinstance(rebuilt_body, Tree)
            first_block = rebuilt_body.children[0]
            assert isinstance(first_block, Tree)
            return first_block

    updater = BlockUpdater()
    return updater.transform(tree)
32
+
33
+
34
+ class AttributeChange(NamedTuple):
35
+ attribute_name: str
36
+ old_value: str | None
37
+ new_value: str
38
+
39
+
40
def attribute_transfomer(attr_name: str, obj_key: str, new_value: str) -> tuple[DictTransformer, list[AttributeChange]]:
    """Build a transformer that sets ``obj_key`` to ``new_value`` inside the
    dict-valued attribute ``attr_name``.

    Returns:
        The transformer and a list that collects an ``AttributeChange`` for
        every attribute whose value was actually changed by it.

    Raises:
        ValueError: (from the transformer) when the attribute value is not a dict.
    """
    recorded_changes: list[AttributeChange] = []

    class AttributeTransformer(DictTransformer):
        def attribute(self, args: list) -> Attribute:
            parsed = super().attribute(args)
            if parsed.key != attr_name:
                return parsed
            attribute_value = parsed.value
            if not isinstance(attribute_value, dict):
                raise ValueError(f"Expected a dict for attribute {attr_name}, but got {type(attribute_value)}")
            previous = attribute_value.get(obj_key)
            if previous == new_value:
                # Already up to date; nothing to record.
                return parsed
            recorded_changes.append(AttributeChange(attr_name, previous, new_value))
            return Attribute(attr_name, attribute_value | {obj_key: new_value})

    transformer = AttributeTransformer(with_meta=True)
    return transformer, recorded_changes
58
+
59
+
60
+ def _identifier_name(tree: Tree) -> str | None:
61
+ with suppress(Exception):
62
+ identifier_tree = tree.children[0]
63
+ assert identifier_tree.data == "identifier"
64
+ name_token = identifier_tree.children[0]
65
+ assert isinstance(name_token, Token)
66
+ if name_token.type == "NAME":
67
+ return name_token.value
68
+
69
+
70
def write_tree(tree: Tree) -> str:
    """Serialize ``tree`` to text using hcl2's ``writes``."""
    rendered = writes(tree)
    return rendered
72
+
73
+
74
def print_tree(path: Path) -> None:
    """Parse the file at ``path`` and print its tree between two log banners.

    Does nothing when the file cannot be parsed.
    """
    parsed = safe_parse(path)
    if parsed is None:
        return
    banner = "=" * 10
    location = f"{path.parent.name}/{path.name}"
    logger.info(banner + f"tree START of {location}" + banner)
    rich.print(parsed)
    logger.info(banner + f"tree END of {location}" + banner)
81
+
82
+
83
def safe_parse(path: Path) -> Tree | None:
    """Parse the file at ``path``; log a warning and return None when the
    parser raises ``UnexpectedToken``."""
    text = path.read_text()
    try:
        parsed = parses(text)  # type: ignore
    except UnexpectedToken as e:
        logger.warning(f"failed to parse {path}: {e}")
        return None
    return parsed
@@ -154,7 +154,7 @@ def is_cache_up_to_date(cache_path: Path, cache_ttl: int) -> bool:
154
154
  return False
155
155
 
156
156
 
157
- def resolve_admin_api_path(sdk_repo_path_str: str, sdk_branch: str, admin_api_path: str) -> Path:
157
+ def resolve_admin_api_path(sdk_repo_path_str: str = "", sdk_branch: str = "main", admin_api_path: str = "") -> Path:
158
158
  if admin_api_path:
159
159
  resolved_admin_api_path = Path(admin_api_path)
160
160
  if not resolved_admin_api_path.exists():