atlas-init 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. atlas_init/__init__.py +1 -1
  2. atlas_init/atlas_init.yaml +3 -0
  3. atlas_init/cli.py +1 -1
  4. atlas_init/cli_cfn/aws.py +2 -2
  5. atlas_init/cli_helper/go.py +104 -58
  6. atlas_init/cli_helper/run.py +3 -3
  7. atlas_init/cli_helper/run_manager.py +3 -3
  8. atlas_init/cli_root/go_test.py +13 -10
  9. atlas_init/cli_tf/app.py +4 -0
  10. atlas_init/cli_tf/debug_logs.py +3 -3
  11. atlas_init/cli_tf/example_update.py +142 -0
  12. atlas_init/cli_tf/example_update_test/test_update_example.tf +23 -0
  13. atlas_init/cli_tf/example_update_test.py +96 -0
  14. atlas_init/cli_tf/github_logs.py +6 -3
  15. atlas_init/cli_tf/go_test_run.py +24 -1
  16. atlas_init/cli_tf/go_test_summary.py +7 -1
  17. atlas_init/cli_tf/hcl/modifier.py +144 -0
  18. atlas_init/cli_tf/hcl/modifier_test/test_process_variables_output_.tf +25 -0
  19. atlas_init/cli_tf/hcl/modifier_test/test_process_variables_variable_.tf +24 -0
  20. atlas_init/cli_tf/hcl/modifier_test.py +95 -0
  21. atlas_init/cli_tf/hcl/parser.py +1 -1
  22. atlas_init/cli_tf/log_clean.py +29 -0
  23. atlas_init/cli_tf/schema_table.py +1 -3
  24. atlas_init/cli_tf/schema_v3.py +1 -1
  25. atlas_init/repos/path.py +14 -0
  26. atlas_init/settings/config.py +24 -13
  27. atlas_init/settings/env_vars.py +1 -1
  28. atlas_init/settings/env_vars_generated.py +1 -1
  29. atlas_init/settings/rich_utils.py +1 -1
  30. atlas_init/tf/.terraform.lock.hcl +16 -16
  31. atlas_init/tf/main.tf +25 -1
  32. atlas_init/tf/modules/aws_kms/aws_kms.tf +100 -0
  33. atlas_init/tf/modules/aws_kms/provider.tf +7 -0
  34. atlas_init/tf/modules/cfn/cfn.tf +1 -1
  35. atlas_init/tf/modules/cloud_provider/cloud_provider.tf +9 -2
  36. atlas_init/tf/modules/encryption_at_rest/main.tf +29 -0
  37. atlas_init/tf/modules/encryption_at_rest/provider.tf +9 -0
  38. atlas_init/tf/variables.tf +5 -0
  39. {atlas_init-0.4.0.dist-info → atlas_init-0.4.2.dist-info}/METADATA +13 -11
  40. {atlas_init-0.4.0.dist-info → atlas_init-0.4.2.dist-info}/RECORD +42 -31
  41. atlas_init/cli_tf/go_test_run_format.py +0 -31
  42. {atlas_init-0.4.0.dist-info → atlas_init-0.4.2.dist-info}/WHEEL +0 -0
  43. {atlas_init-0.4.0.dist-info → atlas_init-0.4.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,96 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+
6
+ from atlas_init.cli_tf.example_update import (
7
+ TFConfigDescriptionChange,
8
+ UpdateExamples,
9
+ update_examples,
10
+ )
11
+ from atlas_init.cli_tf.hcl.modifier import BLOCK_TYPE_VARIABLE, update_descriptions
12
+
13
+
14
def test_description_change(tmp_path):
    """Check `changed` for three before/after pairs.

    A new description replacing an empty one counts as a change; an identical
    description and an empty replacement do not.
    """
    cases = [
        ("", "description of cluster name", True),
        ("description of cluster name", "description of cluster name", False),
        ("description of cluster name", "", False),
    ]
    for before, after, expected in cases:
        change = TFConfigDescriptionChange(
            block_type=BLOCK_TYPE_VARIABLE,
            path=tmp_path,
            name="cluster_name",
            before=before,
            after=after,
        )
        assert bool(change.changed) is expected
36
+
37
+
38
+ example_variables_tf = """variable "cluster_name" {
39
+ type = string
40
+ }
41
+ variable "replication_specs" {
42
+ description = "List of replication specifications in legacy mongodbatlas_cluster format"
43
+ default = []
44
+ type = list(object({
45
+ num_shards = number
46
+ zone_name = string
47
+ regions_config = set(object({
48
+ region_name = string
49
+ electable_nodes = number
50
+ priority = number
51
+ read_only_nodes = optional(number, 0)
52
+ }))
53
+ }))
54
+ }
55
+
56
+ variable "provider_name" {
57
+ type = string
58
+ default = "" # optional in v3
59
+ }
60
+ """
61
+
62
+
63
def test_update_example(tmp_path, file_regression):
    """Run `update_examples` on a temporary example directory.

    Verifies the descriptions reported as existing before the update, the
    per-variable change flags, and the rewritten file (via file regression).
    """
    examples_dir = tmp_path / "example_base"
    examples_dir.mkdir()
    variables_path = examples_dir / "example_variables.tf"
    variables_path.write_text(example_variables_tf)

    request = UpdateExamples(
        examples_base_dir=examples_dir,
        var_descriptions={
            "cluster_name": "description of cluster name",
            "replication_specs": "Updated description",
        },
    )
    output = update_examples(request)

    expected_before = {
        "cluster_name": "",
        "provider_name": "",
        "replication_specs": "List of replication specifications in legacy mongodbatlas_cluster format",
    }
    assert output.before_var_descriptions == expected_before
    assert len(output.changes) == 3  # noqa: PLR2004

    # provider_name has no new description, so it is reported but unchanged
    expected_changes = [
        ("cluster_name", True),
        ("provider_name", False),
        ("replication_specs", True),
    ]
    actual_changes = [(change.name, change.changed) for change in output.changes]
    assert expected_changes == actual_changes
    file_regression.check(variables_path.read_text(), extension=".tf")
89
+
90
+
91
@pytest.mark.skipif(not os.environ.get("TF_FILE"), reason="needs os.environ[TF_FILE]")
def test_parsing_tf_file():
    """Smoke test: the file named by the TF_FILE env var parses and yields non-empty output."""
    tf_file = Path(os.environ["TF_FILE"])
    assert tf_file.exists()
    # no new descriptions are supplied; only the regenerated config text is checked
    new_tf, _existing = update_descriptions(tf_file, {}, block_type=BLOCK_TYPE_VARIABLE)
    assert new_tf
@@ -27,7 +27,7 @@ from atlas_init.settings.path import (
27
27
 
28
28
  logger = logging.getLogger(__name__)
29
29
 
30
- GH_TOKEN_ENV_NAME = "GH_TOKEN" # noqa: S105
30
+ GH_TOKEN_ENV_NAME = "GH_TOKEN" # noqa: S105 #nosec
31
31
  GITHUB_CI_RUN_LOGS_ENV_NAME = "GITHUB_CI_RUN_LOGS"
32
32
  GITHUB_CI_SUMMARY_DIR_ENV_NAME = "GITHUB_CI_SUMMARY_DIR_ENV_NAME"
33
33
  REQUIRED_GH_ENV_VARS = [GH_TOKEN_ENV_NAME, GITHUB_CI_RUN_LOGS_ENV_NAME]
@@ -88,7 +88,7 @@ def find_test_runs(
88
88
  repository = tf_repo()
89
89
  for workflow in repository.get_workflow_runs(
90
90
  created=f">{since.strftime('%Y-%m-%d')}",
91
- branch=branch,
91
+ branch=branch, # type: ignore
92
92
  exclude_pull_requests=True, # type: ignore
93
93
  ):
94
94
  if not include_workflow(workflow):
@@ -130,7 +130,10 @@ def parse_job_logs(job: WorkflowJob, logs_path: Path) -> list[GoTestRun]:
130
130
  if job.conclusion in {"skipped", "cancelled", None}:
131
131
  return []
132
132
  step, logs_lines = select_step_and_log_content(job, logs_path)
133
- return list(parse(logs_lines, job, step))
133
+ test_runs = list(parse(logs_lines, job, step))
134
+ for run in test_runs:
135
+ run.log_path = logs_path
136
+ return test_runs
134
137
 
135
138
 
136
139
  def download_job_safely(workflow_dir: Path, job: WorkflowJob) -> Path | None:
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
18
18
 
19
19
  class GoTestStatus(StrEnum):
20
20
  RUN = "RUN"
21
- PASS = "PASS" # noqa: S105
21
+ PASS = "PASS" # noqa: S105 #nosec
22
22
  FAIL = "FAIL"
23
23
  SKIP = "SKIP"
24
24
 
@@ -55,6 +55,25 @@ class GoTestContext(Entity):
55
55
  # return cls(name=name, steps=steps)
56
56
 
57
57
 
58
+ def extract_group_name(log_path: Path | None) -> str:
59
+ """
60
+ >>> extract_group_name(
61
+ ... Path(
62
+ ... "40216340925_tests-1.11.x-latest_tests-1.11.x-latest-false_search_deployment.txt"
63
+ ... )
64
+ ... )
65
+ 'search_deployment'
66
+ >>> extract_group_name(None)
67
+ ''
68
+ """
69
+ if log_path is None:
70
+ return ""
71
+ if "-" not in log_path.name:
72
+ return ""
73
+ last_part = log_path.stem.split("-")[-1]
74
+ return "_".join(last_part.split("_")[1:]) if "_" in last_part else last_part
75
+
76
+
58
77
  @total_ordering
59
78
  class GoTestRun(Entity):
60
79
  name: str
@@ -112,6 +131,10 @@ class GoTestRun(Entity):
112
131
  def is_pass(self) -> bool:
113
132
  return self.status == GoTestStatus.PASS
114
133
 
134
@property
def group_name(self) -> str:
    """Group name derived from this run's `log_path` by `extract_group_name`."""
    log_path = self.log_path
    return extract_group_name(log_path)
137
+
115
138
  def add_line_match(self, match: LineMatch, line: str, line_number: int) -> None:
116
139
  self.run_seconds = match.run_seconds or self.run_seconds
117
140
  self.finish_line = LineInfo(number=line_number, text=line)
@@ -43,6 +43,10 @@ class GoTestSummary(Entity):
43
43
  def success_rate_human(self) -> str:
44
44
  return f"{self.success_rate:.2%}"
45
45
 
46
@property
def group_name(self) -> str:
    """First non-empty `group_name` among `results`, or "unknown-group" when none has one."""
    for result in self.results:
        if name := result.group_name:
            return name
    return "unknown-group"
49
+
46
50
  def last_pass_human(self) -> str:
47
51
  return next(
48
52
  (f"Passed {test.when}" for test in reversed(self.results) if test.status == GoTestStatus.PASS),
@@ -124,7 +128,9 @@ def create_detailed_summary(
124
128
  test_summary_path = summary_dir_path / f"{summary.success_rate_human}_{summary.name}.md"
125
129
  test_summary_md = summary_str(summary, start_test_date, end_test_date)
126
130
  file_utils.ensure_parents_write_text(test_summary_path, test_summary_md)
127
- top_level_summary.append(f"- {summary.name} ({summary.success_rate_human}) ({summary.last_pass_human()})")
131
+ top_level_summary.append(
132
+ f"- {summary.name} - {summary.group_name} ({summary.success_rate_human}) ({summary.last_pass_human()}) ('{test_summary_path}')"
133
+ )
128
134
  return top_level_summary
129
135
 
130
136
 
@@ -0,0 +1,144 @@
1
+ import logging
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+
6
+ import hcl2
7
+ from lark import Token, Tree, UnexpectedToken
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ BLOCK_TYPE_VARIABLE = "variable"
12
+ BLOCK_TYPE_OUTPUT = "output"
13
+
14
+
15
def process_token(node: Token, indent=0):
    """Return a deep copy of `node`, logging its type and value at debug level.

    Args:
        node: the token to copy.
        indent: nesting depth, used only to indent the debug output.
    """
    # The message needs one placeholder per argument: extra positional args
    # without placeholders make logging fail while formatting the record.
    logger.debug("[%s] (token)\t| %s %s %s", indent, " " * indent, node.type, node.value)
    return deepcopy(node)
18
+
19
+
20
def is_identifier_block_type(tree: Tree | Token, block_type: str) -> bool:
    """Tell whether `tree` is a Tree whose first child has `block_type` as its value.

    Anything that is not a Tree, has no children, or whose first child lacks
    a `value` attribute yields False.
    """
    if isinstance(tree, Tree):
        try:
            first_child = tree.children[0]
            return first_child.value == block_type  # type: ignore
        except (IndexError, AttributeError):
            pass
    return False
27
+
28
+
29
def is_block_type(tree: Tree, block_type: str) -> bool:
    """Tell whether `tree` is a "block" node whose first child identifies `block_type`."""
    try:
        if tree.data != "block":
            return False
        return is_identifier_block_type(tree.children[0], block_type)
    except (IndexError, AttributeError):
        # no `data`/`children` or an empty child list: not a matching block
        return False
34
+
35
+
36
def update_description(tree: Tree, new_descriptions: dict[str, str], existing_names: dict[str, list[str]]) -> Tree:
    """Record a block's current description and apply a new one when available.

    Args:
        tree: a block tree; child 1 is read as the block name and child 2 as its body.
        new_descriptions: block name -> new description text.
        existing_names: block name -> descriptions seen so far; the current
            description ("" when absent) is appended to this mapping in place.

    Returns:
        `tree` itself when there is no non-empty new description for the block,
        otherwise a new Tree whose body carries the new description.
    """
    new_children = tree.children.copy()
    variable_body = new_children[2]
    assert variable_body.data == "body"
    name = token_name(new_children[1])
    old_description = read_description_attribute(variable_body)
    # setdefault works for a plain dict (as annotated) as well as a defaultdict(list)
    existing_names.setdefault(name, []).append(old_description)
    new_description = new_descriptions.get(name, "")
    if not new_description:
        logger.debug(f"no description found for variable {name}")
        return tree
    new_children[2] = update_body_with_description(variable_body, new_description)
    return Tree(tree.data, new_children)
49
+
50
+
51
def token_name(token: Token | Tree) -> str:
    """Return the token's value with leading/trailing double quotes stripped.

    Raises:
        ValueError: if `token` is not a Token.
    """
    if not isinstance(token, Token):
        err_msg = f"unexpected token type {type(token)} for token name"
        raise ValueError(err_msg)
    return token.value.strip('"')
56
+
57
+
58
def has_attribute_description(maybe_attribute: Token | Tree) -> bool:
    """Tell whether the node is an "attribute" tree whose identifier is "description"."""
    if not isinstance(maybe_attribute, Tree) or maybe_attribute.data != "attribute":
        return False
    identifier = maybe_attribute.children[0]
    return identifier.children[0].value == "description"  # type: ignore
62
+
63
+
64
def update_body_with_description(tree: Tree, new_description: str) -> Tree:
    """Return a copy of a body tree whose description attribute is `new_description`.

    Every existing description attribute is replaced; when there is none, a
    newline node and a fresh description attribute are placed at the start
    of the body. Double quotes in the description are backslash-escaped.
    """
    escaped = new_description.replace('"', '\\"')
    body_children = tree.children.copy()
    description_indexes = [
        index for index, child in enumerate(body_children) if has_attribute_description(child)
    ]
    for index in description_indexes:
        body_children[index] = create_description_attribute(escaped)
    if not description_indexes:
        body_children[0:0] = [new_line(), create_description_attribute(escaped)]
    return Tree(tree.data, body_children)
76
+
77
+
78
def new_line() -> Tree:
    """Build a `new_line_or_comment` tree holding a single newline-plus-indent token."""
    rule = Token("RULE", "new_line_or_comment")
    # NOTE(review): the whitespace after the newline sets the indentation of the
    # inserted attribute — confirm its width against the regression fixtures.
    newline_token = Token("NL_OR_COMMENT", "\n ")
    return Tree(rule, [newline_token])
83
+
84
+
85
def read_description_attribute(tree: Tree) -> str:
    """Return the value of the first description attribute in a body tree, or "" if there is none."""
    for child in tree.children:
        if has_attribute_description(child):
            value_term = child.children[-1]
            return token_name(value_term.children[0])
    return ""
94
+
95
+
96
def create_description_attribute(description_value: str) -> Tree:
    """Build an `attribute` tree equivalent to `description = "<description_value>"`.

    The value is wrapped in double quotes as-is; callers are expected to have
    escaped any quotes it contains.
    """
    identifier = Tree(Token("RULE", "identifier"), [Token("NAME", "description")])
    equals = Token("EQ", " =")
    value = Tree(Token("RULE", "expr_term"), [Token("STRING_LIT", f'"{description_value}"')])
    return Tree(Token("RULE", "attribute"), [identifier, equals, value])
103
+
104
+
105
def process_descriptions(
    node: Tree,
    name_updates: dict[str, str],
    existing_names: dict[str, list[str]],
    depth=0,
    *,
    block_type: str,
) -> Tree:
    """Rebuild `node` recursively, updating descriptions of blocks of `block_type`.

    Args:
        node: the tree to walk.
        name_updates: block name -> new description text.
        existing_names: block name -> descriptions found; filled in place by
            `update_description`.
        depth: current recursion depth, used only for debug output.
        block_type: the block type whose descriptions are updated.

    Returns:
        A new Tree with the same `data`, rebuilt children and copied tokens.
    """
    new_children = []
    # One placeholder per argument: logging formats the message with `msg % args`,
    # so extra args without placeholders trigger a formatting error.
    logger.debug("[%s] (tree)\t| %s %s", depth, " " * depth, node.data)
    for child in node.children:
        if isinstance(child, Tree):
            if is_block_type(child, block_type):
                child = update_description(  # noqa: PLW2901
                    child, name_updates, existing_names
                )
            # recurse into the (possibly updated) block as well
            new_children.append(
                process_descriptions(child, name_updates, existing_names, depth + 1, block_type=block_type)
            )
        else:
            new_children.append(process_token(child, depth + 1))

    return Tree(node.data, new_children)
128
+
129
+
130
def update_descriptions(tf_path: Path, new_names: dict[str, str], block_type: str) -> tuple[str, dict[str, list[str]]]:
    """Rewrite the descriptions of `block_type` blocks in a Terraform file's text.

    The file itself is only read; the regenerated configuration is returned.

    Returns:
        `(new_tf, existing)`: the regenerated config text and a mapping of block
        name to the descriptions found before the update. When the file cannot
        be parsed, a warning is logged and `("", {})` is returned.
    """
    try:
        parsed = hcl2.parses(tf_path.read_text())  # type: ignore
    except UnexpectedToken as e:
        logger.warning(f"failed to parse {tf_path}: {e}")
        return "", {}
    found_descriptions = defaultdict(list)
    updated_tree = process_descriptions(parsed, new_names, found_descriptions, block_type=block_type)
    return hcl2.writes(updated_tree), found_descriptions  # type: ignore
@@ -0,0 +1,25 @@
1
+ provider "mongodbatlas" {
2
+ public_key = var.public_key
3
+ private_key = var.private_key
4
+ }
5
+
6
+ module "cluster" {
7
+ source = "../../module_maintainer/v3"
8
+
9
+ cluster_name = var.cluster_name
10
+ cluster_type = var.cluster_type
11
+ mongo_db_major_version = var.mongo_db_major_version
12
+ project_id = var.project_id
13
+ replication_specs_new = var.replication_specs_new
14
+ tags = var.tags
15
+ }
16
+
17
+ output "mongodb_connection_strings" {
18
+ description = "new connection strings desc"
19
+ value = module.cluster.mongodb_connection_strings
20
+ }
21
+
22
+ output "with_desc" {
23
+ value = "with_desc"
24
+ description = "description new"
25
+ }
@@ -0,0 +1,24 @@
1
+ variable "cluster_name" {
2
+ description = "description of \"cluster\" name"
3
+ type = string
4
+ }
5
+ variable "replication_specs" {
6
+ description = "List of replication specifications in legacy mongodbatlas_cluster format"
7
+ default = []
8
+ type = list(object({
9
+ num_shards = number
10
+ zone_name = string
11
+ regions_config = set(object({
12
+ region_name = string
13
+ electable_nodes = number
14
+ priority = number
15
+ read_only_nodes = optional(number, 0)
16
+ }))
17
+ }))
18
+ }
19
+
20
+ variable "provider_name" {
21
+ description = "azure/aws/gcp"
22
+ type = string
23
+ default = ""# optional in v3
24
+ }
@@ -0,0 +1,95 @@
1
+ import pytest
2
+
3
+ from atlas_init.cli_tf.hcl.modifier import BLOCK_TYPE_OUTPUT, BLOCK_TYPE_VARIABLE, update_descriptions
4
+
5
+ example_variables_tf = """variable "cluster_name" {
6
+ type = string
7
+ }
8
+ variable "replication_specs" {
9
+ description = "List of replication specifications in legacy mongodbatlas_cluster format"
10
+ default = []
11
+ type = list(object({
12
+ num_shards = number
13
+ zone_name = string
14
+ regions_config = set(object({
15
+ region_name = string
16
+ electable_nodes = number
17
+ priority = number
18
+ read_only_nodes = optional(number, 0)
19
+ }))
20
+ }))
21
+ }
22
+
23
+ variable "provider_name" {
24
+ type = string
25
+ default = "" # optional in v3
26
+ }
27
+ """
28
+
29
+ _existing_descriptions_variables = {
30
+ "cluster_name": [""],
31
+ "provider_name": [""],
32
+ "replication_specs": ["List of replication specifications in legacy "],
33
+ }
34
+
35
+ example_outputs_tf = """provider "mongodbatlas" {
36
+ public_key = var.public_key
37
+ private_key = var.private_key
38
+ }
39
+
40
+ module "cluster" {
41
+ source = "../../module_maintainer/v3"
42
+
43
+ cluster_name = var.cluster_name
44
+ cluster_type = var.cluster_type
45
+ mongo_db_major_version = var.mongo_db_major_version
46
+ project_id = var.project_id
47
+ replication_specs_new = var.replication_specs_new
48
+ tags = var.tags
49
+ }
50
+
51
+ output "mongodb_connection_strings" {
52
+ value = module.cluster.mongodb_connection_strings
53
+ }
54
+
55
+ output "with_desc" {
56
+ value = "with_desc"
57
+ description = "description old"
58
+ }
59
+ """
60
+ _existing_descriptions_outputs = {
61
+ "mongodb_connection_strings": [""],
62
+ "with_desc": ["description old"],
63
+ }
64
+
65
+
66
@pytest.mark.parametrize(
    ("block_type", "new_names", "existing_descriptions", "tf_config"),
    [
        pytest.param(
            BLOCK_TYPE_VARIABLE,
            {
                "cluster_name": 'description of "cluster" name',
                "provider_name": "azure/aws/gcp",
            },
            _existing_descriptions_variables,
            example_variables_tf,
            id=BLOCK_TYPE_VARIABLE,
        ),
        pytest.param(
            BLOCK_TYPE_OUTPUT,
            {
                "with_desc": "description new",
                "mongodb_connection_strings": "new connection strings desc",
            },
            _existing_descriptions_outputs,
            example_outputs_tf,
            id=BLOCK_TYPE_OUTPUT,
        ),
    ],
)
def test_process_variables(tmp_path, file_regression, block_type, new_names, existing_descriptions, tf_config):
    """Update descriptions in a temporary .tf file and compare the output with the regression file."""
    tf_path = tmp_path / "example.tf"
    tf_path.write_text(tf_config)
    # NOTE(review): the next line rebinds the parametrized `existing_descriptions`,
    # so the final assertion compares the returned mapping with itself — confirm
    # whether the expected fixture was meant to be checked here.
    new_tf, existing_descriptions = update_descriptions(tf_path, new_names, block_type=block_type)
    file_regression.check(new_tf, extension=".tf")
    assert dict(existing_descriptions.items()) == existing_descriptions
@@ -108,7 +108,7 @@ def iter_blocks(block: Block, level: int | None = None) -> Iterable[Block]:
108
108
  hcl="\n".join(block_lines),
109
109
  )
110
110
  if line_level_start_names.get(level) is not None:
111
- raise ValueError(f"Unfinished block @ {line_nr} in {block.name} at level {level}")
111
+ raise ValueError(f"Unfinished block @ {line_nr} in {block.name} at level {level}") # pyright: ignore
112
112
 
113
113
 
114
114
  def hcl_attrs(block: Block) -> dict[str, str]:
@@ -0,0 +1,29 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import typer
5
+
6
+ logger = logging.getLogger(__name__)
7
SPLIT_STR = "mongodbatlas: "


def remove_prefix(line: str) -> str:
    """
    Drop everything up to and including the first ``SPLIT_STR`` in a log line.

    Lines that do not contain ``SPLIT_STR`` are returned unchanged.

    >>> remove_prefix(
    ...     "2025-02-14T15:47:14.157Z [DEBUG] provider.terraform-provider-mongodbatlas: {"
    ... )
    '{'
    >>> remove_prefix(
    ...     '2025-02-14T15:47:14.158Z [DEBUG] provider.terraform-provider-mongodbatlas:  "biConnector": {'
    ... )
    ' "biConnector": {'
    """
    _, separator, remainder = line.partition(SPLIT_STR)
    return remainder if separator else line
22
+
23
+
24
def log_clean(log_path: str = typer.Argument(..., help="Path to the log file")):
    """Strip the provider log prefix from every line of a log file, in place.

    Each line is passed through `remove_prefix` and the file is overwritten
    with the cleaned lines joined by newlines.

    Raises:
        typer.BadParameter: if `log_path` does not exist.
    """
    log_path_parsed = Path(log_path)
    if not log_path_parsed.exists():
        # explicit check: an `assert` would be skipped when Python runs with -O
        raise typer.BadParameter(f"file not found: {log_path}")
    new_lines = [remove_prefix(line) for line in log_path_parsed.read_text().splitlines()]
    log_path_parsed.write_text("\n".join(new_lines))
    logger.info(f"cleaned log file: {log_path}")
@@ -1,6 +1,4 @@
1
- # import typer
2
-
3
-
1
+ # pyright: reportIncompatibleMethodOverride=none
4
2
  from collections import defaultdict
5
3
  from collections.abc import Iterable
6
4
  from functools import total_ordering
@@ -210,7 +210,7 @@ class Schema(BaseModelLocal):
210
210
 
211
211
 
212
212
  class Resource(BaseModelLocal):
213
- schema: Schema
213
+ schema: Schema # pyright: ignore
214
214
  name: SnakeCaseString
215
215
 
216
216
  @property
atlas_init/repos/path.py CHANGED
@@ -16,6 +16,15 @@ _KNOWN_OWNER_PROJECTS = {
16
16
  }
17
17
 
18
18
 
19
def package_glob(package_path: str) -> str:
    """Return the glob pattern for the `.go` files directly inside `package_path`."""
    return "{}/*.go".format(package_path)
21
+
22
+
23
def go_package_prefix(repo_path: Path) -> str:
    """Return "github.com/" followed by the repo's owner/project name from `owner_project_name`."""
    return f"github.com/{owner_project_name(repo_path)}"
26
+
27
+
19
28
  def _owner_project_name(repo_path: Path) -> str:
20
29
  owner_project = owner_project_name(repo_path)
21
30
  if owner_project not in _KNOWN_OWNER_PROJECTS:
@@ -61,6 +70,11 @@ class Repo(StrEnum):
61
70
  TF = "tf"
62
71
 
63
72
 
73
def as_repo_alias(path: Path) -> Repo:
    """Resolve the `Repo` alias for the repository at `path` from its owner/project name."""
    return _owner_lookup(owner_project_name(path))
76
+
77
+
64
78
  _owner_repos = {
65
79
  GH_OWNER_TERRAFORM_PROVIDER_MONGODBATLAS: Repo.TF,
66
80
  GH_OWNER_MONGODBATLAS_CLOUDFORMATION_RESOURCES: Repo.CFN,
@@ -2,22 +2,22 @@ from __future__ import annotations
2
2
 
3
3
  import fnmatch
4
4
  import logging
5
+ from collections import defaultdict
5
6
  from collections.abc import Iterable
6
7
  from functools import total_ordering
7
8
  from os import getenv
8
9
  from pathlib import Path
9
10
  from typing import Any
10
11
 
11
- from model_lib import Entity, dump_ignore_falsy
12
+ from model_lib import Entity, IgnoreFalsy
12
13
  from pydantic import Field, model_validator
13
14
 
14
- from atlas_init.repos.path import owner_project_name
15
+ from atlas_init.repos.path import as_repo_alias, go_package_prefix, owner_project_name, package_glob
15
16
 
16
17
  logger = logging.getLogger(__name__)
17
18
 
18
19
 
19
- @dump_ignore_falsy
20
- class TerraformVars(Entity):
20
+ class TerraformVars(IgnoreFalsy):
21
21
  cluster_info: bool = False
22
22
  cluster_info_m10: bool = False
23
23
  stream_instance: bool = False
@@ -28,6 +28,7 @@ class TerraformVars(Entity):
28
28
  use_aws_vpc: bool = False
29
29
  use_aws_s3: bool = False
30
30
  use_federated_vars: bool = False
31
+ use_encryption_at_rest: bool = False
31
32
 
32
33
  def __add__(self, other: TerraformVars): # type: ignore
33
34
  assert isinstance(other, TerraformVars) # type: ignore
@@ -59,6 +60,8 @@ class TerraformVars(Entity):
59
60
  config["use_project_extra"] = True
60
61
  if self.use_federated_vars:
61
62
  config["use_federated_vars"] = True
63
+ if self.use_encryption_at_rest:
64
+ config["use_encryption_at_rest"] = True
62
65
  if self.stream_instance:
63
66
  # hack until backend bug with stream instance is fixed
64
67
  config["stream_instance_config"] = {"name": getenv("ATLAS_STREAM_INSTANCE_NAME", "atlas-init")}
@@ -70,15 +73,13 @@ class PyHook(Entity):
70
73
  locate: str
71
74
 
72
75
 
73
- @dump_ignore_falsy
74
76
  @total_ordering
75
- class TestSuite(Entity):
77
+ class TestSuite(IgnoreFalsy):
76
78
  __test__ = False
77
79
 
78
80
  name: str
79
81
  sequential_tests: bool = False
80
82
  repo_go_packages: dict[str, list[str]] = Field(default_factory=dict)
81
- repo_globs: dict[str, list[str]] = Field(default_factory=dict)
82
83
  vars: TerraformVars = Field(default_factory=TerraformVars) # type: ignore
83
84
  post_apply_hooks: list[PyHook] = Field(default_factory=list)
84
85
 
@@ -87,13 +88,23 @@ class TestSuite(Entity):
87
88
  raise TypeError
88
89
  return self.name < other.name
89
90
 
90
- def all_globs(self, repo_alias: str) -> list[str]:
91
- go_packages = self.repo_go_packages.get(repo_alias, [])
92
- return self.repo_globs.get(repo_alias, []) + [f"{pkg}/*.go" for pkg in go_packages] + go_packages
91
def package_url_tests(self, repo_path: Path, prefix: str = "") -> dict[str, dict[str, Path]]:
    """Collect the Go functions declared in this suite's packages for a repo.

    Scans the `*.go` files directly inside each package configured for the
    repo's alias and picks up top-level lines starting with `func <prefix>`.

    Args:
        repo_path: root of the checked-out repository.
        prefix: required start of the function name, e.g. "Test"; "" matches all.

    Returns:
        Mapping of full package URL -> {function name -> directory of the file}.
        Methods with a receiver ("func (r *T) Name(") have no plain function
        name before the first "(" and are skipped.
    """
    alias = as_repo_alias(repo_path)
    packages = self.repo_go_packages.get(alias, [])
    names: defaultdict[str, dict[str, Path]] = defaultdict(dict)
    if not packages:
        return names
    url_prefix = go_package_prefix(repo_path)  # same for every package, computed once
    func_start = f"func {prefix}"
    for package in packages:
        pkg_name = f"{url_prefix}/{package}"
        for go_file in repo_path.glob(package_glob(package)):
            with go_file.open() as f:
                for line in f:
                    if not line.startswith(func_start):
                        continue
                    signature = line.split("(")[0].strip()
                    test_name = signature.removeprefix("func ")
                    if test_name == signature:
                        # nothing follows "func " before "(": a receiver method, not a function
                        continue
                    names[pkg_name][test_name] = go_file.parent
    return names
93
104
 
94
105
  def is_active(self, repo_alias: str, change_paths: Iterable[str]) -> bool:
95
106
  """changes paths should be relative to the repo"""
96
- globs = self.all_globs(repo_alias)
107
+ globs = [package_glob(pkg) for pkg in self.repo_go_packages.get(repo_alias, [])]
97
108
  return any(any(fnmatch.fnmatch(path, glob) for glob in globs) for path in change_paths)
98
109
 
99
110
  def cwd_is_repo_go_pkg(self, cwd: Path, repo_alias: str) -> bool:
@@ -145,9 +156,9 @@ class AtlasInitConfig(Entity):
145
156
  @model_validator(mode="after")
146
157
  def ensure_all_repo_aliases_are_found(self):
147
158
  missing_aliases = set()
148
- aliases = set(self.repo_aliases.keys())
159
+ aliases = set(self.repo_aliases.values())
149
160
  for group in self.test_suites:
150
- if more_missing := group.repo_globs.keys() - aliases:
161
+ if more_missing := (group.repo_go_packages.keys() - aliases):
151
162
  logger.warning(f"repo aliases not found for group={group.name}: {more_missing}")
152
163
  missing_aliases |= more_missing
153
164
  if missing_aliases:
@@ -254,7 +254,7 @@ class AtlasInitSettings(AtlasInitPaths, ExternalSettings):
254
254
  return variables
255
255
 
256
256
 
257
- def active_suites(settings: AtlasInitSettings) -> list[TestSuite]:
257
+ def active_suites(settings: AtlasInitSettings) -> list[TestSuite]: # type: ignore
258
258
  repo_path, cwd_rel_path = repo_path_rel_path()
259
259
  return config_active_suites(settings.config, repo_path, cwd_rel_path, settings.test_suites_parsed)
260
260