atlas-init 0.1.1-py3-none-any.whl → 0.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. atlas_init/__init__.py +3 -3
  2. atlas_init/atlas_init.yaml +18 -1
  3. atlas_init/cli.py +62 -70
  4. atlas_init/cli_cfn/app.py +40 -117
  5. atlas_init/cli_cfn/{cfn.py → aws.py} +129 -14
  6. atlas_init/cli_cfn/cfn_parameter_finder.py +89 -6
  7. atlas_init/cli_cfn/example.py +203 -0
  8. atlas_init/cli_cfn/files.py +63 -0
  9. atlas_init/cli_helper/run.py +18 -2
  10. atlas_init/cli_helper/tf_runner.py +4 -6
  11. atlas_init/cli_root/__init__.py +0 -0
  12. atlas_init/cli_root/trigger.py +153 -0
  13. atlas_init/cli_tf/app.py +211 -4
  14. atlas_init/cli_tf/changelog.py +103 -0
  15. atlas_init/cli_tf/debug_logs.py +221 -0
  16. atlas_init/cli_tf/debug_logs_test_data.py +253 -0
  17. atlas_init/cli_tf/github_logs.py +229 -0
  18. atlas_init/cli_tf/go_test_run.py +194 -0
  19. atlas_init/cli_tf/go_test_run_format.py +31 -0
  20. atlas_init/cli_tf/go_test_summary.py +144 -0
  21. atlas_init/cli_tf/hcl/__init__.py +0 -0
  22. atlas_init/cli_tf/hcl/cli.py +161 -0
  23. atlas_init/cli_tf/hcl/cluster_mig.py +348 -0
  24. atlas_init/cli_tf/hcl/parser.py +140 -0
  25. atlas_init/cli_tf/schema.py +222 -18
  26. atlas_init/cli_tf/schema_go_parser.py +236 -0
  27. atlas_init/cli_tf/schema_table.py +150 -0
  28. atlas_init/cli_tf/schema_table_models.py +155 -0
  29. atlas_init/cli_tf/schema_v2.py +599 -0
  30. atlas_init/cli_tf/schema_v2_api_parsing.py +298 -0
  31. atlas_init/cli_tf/schema_v2_sdk.py +361 -0
  32. atlas_init/cli_tf/schema_v3.py +222 -0
  33. atlas_init/cli_tf/schema_v3_sdk.py +279 -0
  34. atlas_init/cli_tf/schema_v3_sdk_base.py +68 -0
  35. atlas_init/cli_tf/schema_v3_sdk_create.py +216 -0
  36. atlas_init/humps.py +253 -0
  37. atlas_init/repos/cfn.py +6 -1
  38. atlas_init/repos/path.py +3 -3
  39. atlas_init/settings/config.py +14 -4
  40. atlas_init/settings/env_vars.py +16 -1
  41. atlas_init/settings/path.py +12 -1
  42. atlas_init/settings/rich_utils.py +2 -0
  43. atlas_init/terraform.yaml +77 -1
  44. atlas_init/tf/.terraform.lock.hcl +59 -83
  45. atlas_init/tf/always.tf +7 -0
  46. atlas_init/tf/main.tf +3 -0
  47. atlas_init/tf/modules/aws_s3/provider.tf +1 -1
  48. atlas_init/tf/modules/aws_vars/aws_vars.tf +2 -0
  49. atlas_init/tf/modules/aws_vpc/provider.tf +4 -1
  50. atlas_init/tf/modules/cfn/cfn.tf +47 -33
  51. atlas_init/tf/modules/cfn/kms.tf +54 -0
  52. atlas_init/tf/modules/cfn/resource_actions.yaml +1 -0
  53. atlas_init/tf/modules/cfn/variables.tf +31 -0
  54. atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -0
  55. atlas_init/tf/modules/cloud_provider/provider.tf +1 -1
  56. atlas_init/tf/modules/cluster/cluster.tf +34 -24
  57. atlas_init/tf/modules/cluster/provider.tf +1 -1
  58. atlas_init/tf/modules/federated_vars/federated_vars.tf +3 -0
  59. atlas_init/tf/modules/federated_vars/provider.tf +1 -1
  60. atlas_init/tf/modules/project_extra/project_extra.tf +15 -1
  61. atlas_init/tf/modules/stream_instance/stream_instance.tf +1 -1
  62. atlas_init/tf/modules/vpc_peering/vpc_peering.tf +1 -1
  63. atlas_init/tf/modules/vpc_privatelink/versions.tf +1 -1
  64. atlas_init/tf/outputs.tf +11 -3
  65. atlas_init/tf/providers.tf +2 -1
  66. atlas_init/tf/variables.tf +12 -0
  67. atlas_init/typer_app.py +76 -0
  68. {atlas_init-0.1.1.dist-info → atlas_init-0.1.4.dist-info}/METADATA +36 -18
  69. atlas_init-0.1.4.dist-info/RECORD +91 -0
  70. {atlas_init-0.1.1.dist-info → atlas_init-0.1.4.dist-info}/WHEEL +1 -1
  71. atlas_init-0.1.1.dist-info/RECORD +0 -62
  72. /atlas_init/tf/modules/aws_vpc/{aws-vpc.tf → aws_vpc.tf} +0 -0
  73. {atlas_init-0.1.1.dist-info → atlas_init-0.1.4.dist-info}/entry_points.txt +0 -0
atlas_init/cli_tf/debug_logs.py
@@ -0,0 +1,221 @@
+ import json
+ import logging
+ import re
+ from contextlib import suppress
+ from typing import Any, NamedTuple, Self
+
+ from model_lib import Entity
+ from pydantic import ValidationError, model_validator
+
+ logger = logging.getLogger(__name__)
+
+
+ def parsed(payload: str) -> tuple[dict[str, Any], list]:
+     with suppress(ValueError):
+         resp = json.loads(payload)
+         if isinstance(resp, dict):
+             return resp, []
+         if isinstance(resp, list):
+             return {}, resp
+     raise ValueError(f"Could not parse payload: {payload}")
+
+
+ class PathHeadersPayload(Entity):
+     method: str
+     path: str
+     http_protocol: str
+     headers: dict[str, str]
+     text: str
+
+     @property
+     def expect_list_response(self) -> bool:
+         return self.method == "GET" and self.path.endswith("s") and all(not c.isdigit() for c in self.path)
+
+
+ def parse_request(request_lines: list[str]) -> PathHeadersPayload:
+     path_line, *header_lines_payload = request_lines
+     headers_end = header_lines_payload.index("")
+     header_lines = header_lines_payload[:headers_end]
+     payload_lines = header_lines_payload[headers_end + 1 :]
+     payload_end = payload_lines.index("")
+     payload_lines = payload_lines[:payload_end]
+     method, path, http_protocol = path_line.split(" ")
+     return PathHeadersPayload(
+         method=method,
+         http_protocol=http_protocol,
+         path=path,
+         headers=dict(header_line.split(": ", 1) for header_line in header_lines),
+         text="\n".join(payload_lines),
+     )
+
+
+ class StatusHeadersResponse(Entity):
+     http_protocol: str
+     status: int
+     status_text: str
+     headers: dict[str, str]
+     text: str
+
+
+ def parse_response(response_lines: list[str]) -> StatusHeadersResponse:
+     http_protocol_status, *header_lines_response = response_lines
+     http_protocol, status, status_text = http_protocol_status.split(" ", 2)
+     headers_end = header_lines_response.index("")
+     header_lines = header_lines_response[:headers_end]
+     response = header_lines_response[headers_end + 1 :]
+     return StatusHeadersResponse(
+         http_protocol=http_protocol,
+         status=status,  # type: ignore
+         status_text=status_text,
+         headers=dict(header_line.split(": ", 1) for header_line in header_lines),
+         text="\n".join(response),
+     )
+
+
+ # example Content-Type: application/vnd.atlas.2024-08-05+json;charset=utf-8
+
+ _version_date_pattern = re.compile(r"(\d{4}-\d{2}-\d{2})")
+
+
+ def extract_version(content_type: str) -> str:
+     if match := _version_date_pattern.search(content_type):
+         return match.group(1)
+     raise ValueError(f"Could not extract version from {content_type} header")
+
+
+ class SDKRoundtrip(Entity):
+     request: PathHeadersPayload
+     response: StatusHeadersResponse
+     resp_index: int
+     step_number: int
+
+     @property
+     def id(self) -> str:
+         return f"{self.request.method}_{self.request.path}_{self.version}"
+
+     @property
+     def version(self) -> str:
+         content_type = self.response.headers.get("Content-Type", "v1")
+         try:
+             return extract_version(content_type)
+         except ValueError:
+             logger.warning(f"failed to extract version from response header ({content_type}), trying request")
+             content_type = self.request.headers.get("Accept", "v1")
+             return extract_version(content_type)
+
+     @model_validator(mode="after")
+     def ensure_match(self) -> Self:
+         req = self.request
+         resp = self.response
+         _, resp_payload_list = parsed(resp.text)
+         if req.expect_list_response and not resp_payload_list:
+             raise ValueError(f"Expected list response but got dict: {resp.text}")
+         return self
+
+
+ MARKER_END = "-----------------------------------"
+ MARKER_REQUEST_START = "---[ REQUEST ]"
+ MARKER_RESPONSE_START = "---[ RESPONSE ]----"
+ MARKER_START_STEP = "Starting TestStep: "
+ MARKER_TEST = "Starting TestSteps: "
+
+
+ class FileRef(NamedTuple):
+     index: int
+     line_start: int
+     line_end: int
+
+
+ def parse_http_requests(logs: str) -> list[SDKRoundtrip]:
+     """
+     Problem: requests are sent in parallel, so responses can arrive out of order.
+     An alternative is to run with parallelism 1, but that is slow.
+     Methods (rejected):
+     1. Look for a match from the `path` to something in the payload
+     2. Use the X-Java-Method header to match the response with the path
+     3. Use X-Envoy-Upstream-Service-Time to match them
+
+     Method (accepted):
+     Expect the payload to be either a list or a dict; a path ending with an identifier most likely returns a dict.
+     """
+     test_count = logs.count(MARKER_TEST)
+     assert test_count == 1, f"Only one test is supported, found {test_count}"
+     requests, responses = parse_raw_req_responses(logs)
+     tf_step_starts = [i for i, line in enumerate(logs.splitlines()) if MARKER_START_STEP in line]
+     used_responses: set[int] = set()
+     responses_list: list[StatusHeadersResponse] = list(responses.values())
+     sdk_roundtrips = []
+     for ref, request in requests.items():
+         roundtrip = match_request(used_responses, responses_list, ref, request, tf_step_starts)
+         sdk_roundtrips.append(roundtrip)
+         used_responses.add(roundtrip.resp_index)
+     return sdk_roundtrips
+
+
+ def find_step_number(ref: FileRef, step_starts: list[int]) -> int:
+     for i, step_start in enumerate(reversed(step_starts)):
+         if step_start < ref.line_start:
+             return len(step_starts) - i
+     logger.warning(f"Could not find step start for {ref}")
+     return 0
+
+
+ def match_request(
+     used_responses: set[int],
+     responses_list: list[StatusHeadersResponse],
+     ref: FileRef,
+     request: PathHeadersPayload,
+     step_starts: list[int],
+ ) -> SDKRoundtrip:
+     for i, response in enumerate(responses_list):
+         if i in used_responses:
+             continue
+         with suppress(ValidationError):
+             step_number = find_step_number(ref, step_starts)
+             return SDKRoundtrip(request=request, response=response, resp_index=i, step_number=step_number)
+     remaining_responses = [resp for i, resp in enumerate(responses_list) if i not in used_responses]
+     err_msg = f"Could not match request {ref} with any response\n\n{request}\n\n\nThere are #{len(remaining_responses)} responses left that don't match\n{'-'*80}\n{'\n'.join(r.text for r in remaining_responses)}"
+     raise ValueError(err_msg)
+
+
+ def parse_raw_req_responses(
+     logs: str,
+ ) -> tuple[dict[FileRef, PathHeadersPayload], dict[FileRef, StatusHeadersResponse]]:
+     # sourcery skip: dict-comprehension
+     request_count = 0
+     response_count = 0
+     in_request = False
+     in_response = False
+     current_start = 0
+     requests: dict[FileRef, list[str]] = {}
+     responses: dict[FileRef, list[str]] = {}
+     log_lines = logs.splitlines()
+     for i, line in enumerate(log_lines):
+         if line.startswith(MARKER_REQUEST_START):
+             in_request = True
+             current_start = i + 1
+         elif line.startswith(MARKER_RESPONSE_START):
+             in_response = True
+             current_start = i + 1
+         if in_request and line.startswith(MARKER_END):
+             key = FileRef(index=request_count, line_start=current_start, line_end=i)
+             requests[key] = log_lines[current_start:i]
+             request_count += 1
+             in_request = False
+         if in_response and line.startswith(MARKER_END):
+             key = FileRef(index=response_count, line_start=current_start, line_end=i)
+             responses[key] = log_lines[current_start:i]
+             response_count += 1
+             in_response = False
+     assert not in_request, "Request not closed"
+     assert not in_response, "Response not closed"
+     assert (
+         request_count == response_count
+     ), f"Mismatch in request and response count: {request_count} != {response_count}"
+     parsed_requests = {}
+     for ref, request_lines in requests.items():
+         parsed_requests[ref] = parse_request(request_lines)
+     parsed_responses = {}
+     for ref, response_lines in responses.items():
+         parsed_responses[ref] = parse_response(response_lines)
+     return parsed_requests, parsed_responses
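
For orientation (not part of the diff): a minimal sketch of driving the new debug_logs.py module. The SAMPLE log below is invented, but its markers and blank-line layout follow what parse_http_requests expects; note that each REQUEST/RESPONSE block must end with the 35-dash marker line and each payload must be followed by a blank line, or the parser's index("") calls fail.

# sketch only: SAMPLE is a hand-written log in the expected debug format
from atlas_init.cli_tf.debug_logs import parse_http_requests

SAMPLE = """\
Starting TestSteps: TestAccExample
Starting TestStep: 1
---[ REQUEST ]----
GET /api/atlas/v2/groups/abc123 HTTP/1.1
Accept: application/vnd.atlas.2024-08-05+json

{}

-----------------------------------
---[ RESPONSE ]----
HTTP/2.0 200 OK
Content-Type: application/vnd.atlas.2024-08-05+json;charset=utf-8

{"id": "abc123"}

-----------------------------------
"""

roundtrips = parse_http_requests(SAMPLE)
for rt in roundtrips:
    print(rt.id, rt.step_number)  # GET_/api/atlas/v2/groups/abc123_2024-08-05 1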
atlas_init/cli_tf/debug_logs_test_data.py
@@ -0,0 +1,253 @@
+ import json
+ import logging
+ from collections.abc import Callable
+ from typing import NamedTuple
+
+ from model_lib import Entity
+ from pydantic import Field, model_validator
+
+ from atlas_init.cli_tf.debug_logs import SDKRoundtrip
+
+ logger = logging.getLogger(__name__)
+
+
+ class StatusText(Entity):
+     status: int
+     text: str
+
+     @property
+     def id(self):
+         return f"{self.status}_{self.text}"
+
+
+ class RequestInfo(Entity):
+     version: str
+     method: str
+     path: str
+     text: str
+     responses: list[StatusText] = Field(default_factory=list)
+
+     @property
+     def id(self):
+         return "__".join(  # noqa: FLY002
+             [
+                 self.method,
+                 self.path,
+                 self.version,
+                 self.text,
+             ]  # need to include the text to differentiate between requests
+         )
+
+
+ class StepRequests(Entity):
+     diff_requests: list[RequestInfo] = Field(default_factory=list)
+     request_responses: list[RequestInfo] = Field(default_factory=list)
+
+     def existing_request(self, info: RequestInfo) -> RequestInfo | None:
+         return next((r for r in self.request_responses if r.id == info.id), None)
+
+     def add_request(
+         self,
+         path: str,
+         method: str,
+         version: str,
+         status: int,
+         text: str,
+         text_response: str,
+         is_diff: bool,
+     ):
+         status_text = StatusText(status=status, text=text_response)
+         info = RequestInfo(
+             path=path,
+             method=method,
+             version=version,
+             text=text,
+             responses=[status_text],
+         )
+         if is_diff:
+             self.diff_requests.append(info)
+         if existing := self.existing_request(info):
+             existing.responses.append(status_text)
+         else:
+             self.request_responses.append(info)
+
+
+ class RTModifier(Entity):
+     version: str
+     method: str
+     path: str
+     modification: Callable[[SDKRoundtrip], None]
+
+     def match(self, rt: SDKRoundtrip, normalized_path: str) -> bool:
+         return rt.request.method == self.method and normalized_path == self.path and rt.version == self.version
+
+
+ class MockRequestData(Entity):
+     step_count: int
+     steps: list[StepRequests] = Field(default_factory=list, init=False)
+     variables: dict[str, str] = Field(default_factory=dict)
+
+     @model_validator(mode="after")
+     def set_steps(self):
+         self.steps = [StepRequests() for _ in range(self.step_count)]
+         return self
+
+     def add_roundtrip(
+         self,
+         rt: SDKRoundtrip,
+         normalized_path: str,
+         normalized_text: str,
+         normalized_response_text: str,
+         is_diff: bool,
+     ):
+         step = self.steps[rt.step_number - 1]
+         if rt.request.method == "PATCH":
+             logger.info(f"PATCH: {rt.request.path}")
+         step.add_request(
+             normalized_path,
+             rt.request.method,
+             rt.version,
+             rt.response.status,
+             normalized_text,
+             normalized_response_text,
+             is_diff,
+         )
+
+     def update_variables(self, variables: dict[str, str]) -> None:
+         if missing_values := sorted(name for name, value in variables.items() if not value):
+             err_msg = f"Missing values for variables: {missing_values}"
+             raise ValueError(err_msg)
+         changes: list[VariableChange] = []
+         for name, value in variables.items():
+             old_value = self.variables.get(name)
+             if old_value and old_value != value:
+                 for suffix in range(2, 10):
+                     new_name = f"{name}{suffix}"
+                     old_value2 = self.variables.get(new_name, "")
+                     if old_value2 and old_value2 != value:
+                         continue
+                     if not old_value2:
+                         logger.warning(f"adding variable {new_name}={value} to avoid a conflict on {name}")
+                     change = VariableChange(name, new_name, old_value, value)
+                     changes.append(change)
+                     self.variables[new_name] = value
+                     break
+                 else:
+                     raise ValueError(f"Too many variables with the same name and different values: {name}")
+             else:
+                 self.variables[name] = value
+         if changes:
+             raise VariablesChangedError(changes)
+
+     def prune_duplicate_responses(self):
+         for step in self.steps:
+             for request in step.request_responses:
+                 pruned_responses = []
+                 seen_response_ids = set()
+                 before_len = len(request.responses)
+                 for response in request.responses:
+                     if response.id in seen_response_ids:
+                         continue
+                     seen_response_ids.add(response.id)
+                     pruned_responses.append(response)
+                 request.responses = pruned_responses
+                 after_len = len(request.responses)
+                 if before_len != after_len:
+                     logger.info(f"Pruned {before_len - after_len} duplicate responses from {request.id}")
+
+
+ class ApiSpecPath(Entity):
+     path: str
+
+     def variables(self, path: str) -> dict[str, str]:
+         return {
+             var[1:-1]: default
+             for var, default in zip(self.path.split("/"), path.split("/"), strict=False)
+             if var.startswith("{") and var.endswith("}")
+         }
+
+     def match(self, path: str) -> bool:
+         parts_expected = self.path.split("/")
+         parts_actual = path.split("/")
+         if len(parts_expected) != len(parts_actual):
+             return False
+         for expected, actual in zip(parts_expected, parts_actual, strict=False):
+             if expected == actual:
+                 continue
+             if expected.startswith("{") and expected.endswith("}"):
+                 continue
+             return False
+         return True
+
+
+ def find_normalized_path(path: str, api_spec_paths: list[ApiSpecPath]) -> ApiSpecPath:
+     if "?" in path:
+         path = path.split("?")[0]
+     path = path.rstrip("/")  # remove trailing slash
+     for api_spec_path in api_spec_paths:
+         if api_spec_path.match(path):
+             return api_spec_path
+     raise ValueError(f"Could not find path: {path}")
+
+
+ def normalize_text(text: str, variables: dict[str, str]) -> str:
+     for var, value in variables.items():
+         text = text.replace(value, f"{{{var}}}")
+     if not text:
+         return text
+     try:
+         parsed_text = json.loads(text)
+         return json.dumps(parsed_text, indent=1, sort_keys=True)
+     except json.JSONDecodeError:
+         logger.warning(f"Could not parse text: {text}")
+         return text
+
+
+ def default_is_diff(rt: SDKRoundtrip) -> bool:
+     return rt.request.method not in {"DELETE", "GET"}
+
+
+ class VariableChange(NamedTuple):
+     var_name: str
+     new_var_name: str
+     old: str
+     new: str
+
+
+ class VariablesChangedError(Exception):
+     def __init__(self, changes: list[VariableChange]) -> None:
+         super().__init__(f"Variables changed: {changes}")
+         self.changes = changes
+
+
+ def create_mock_data(
+     roundtrips: list[SDKRoundtrip],
+     api_spec_paths: dict[str, list[ApiSpecPath]],
+     is_diff: Callable[[SDKRoundtrip], bool] | None = None,
+     modifiers: list[RTModifier] | None = None,
+ ) -> MockRequestData:
+     steps = max(rt.step_number for rt in roundtrips)
+     mock_data = MockRequestData(step_count=steps)
+     is_diff = is_diff or default_is_diff
+     modifiers = modifiers or []
+     for rt in roundtrips:
+         request_path = rt.request.path
+         method = rt.request.method
+         spec_path = find_normalized_path(request_path, api_spec_paths[method])
+         rt_variables = spec_path.variables(request_path)
+         normalized_path = spec_path.path
+         try:
+             mock_data.update_variables(rt_variables)
+         except VariablesChangedError as e:
+             for change in e.changes:
+                 rt_variables.pop(change.var_name)
+                 rt_variables[change.new_var_name] = change.new
+             normalized_path = normalize_text(request_path, rt_variables)
+         for modifier in modifiers:
+             if modifier.match(rt, normalized_path):
+                 modifier.modification(rt)
+         normalized_text = normalize_text(rt.request.text, rt_variables)
+         normalized_response_text = normalize_text(rt.response.text, rt_variables)
+         mock_data.add_roundtrip(rt, normalized_path, normalized_text, normalized_response_text, is_diff(rt))
+     # mock_data.prune_duplicate_responses() is skipped on purpose: keeping duplicates stays KISS
+     return mock_data
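
Continuing the sketch above (again, not part of the diff): debug_logs_test_data.py turns those roundtrips into per-step mock data. The {groupId} spec path below is an assumption standing in for a real list of API-spec paths, which would normally come from a parsed OpenAPI spec.

# sketch only: spec_paths is a hand-built stand-in for parsed spec paths
from atlas_init.cli_tf.debug_logs_test_data import ApiSpecPath, create_mock_data

spec_paths = {"GET": [ApiSpecPath(path="/api/atlas/v2/groups/{groupId}")]}
mock_data = create_mock_data(roundtrips, spec_paths)
print(mock_data.variables)  # {'groupId': 'abc123'}
print(mock_data.steps[0].request_responses[0].path)  # /api/atlas/v2/groups/{groupId}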
atlas_init/cli_tf/github_logs.py
@@ -0,0 +1,229 @@
+ import logging
+ import os
+ from collections import defaultdict
+ from collections.abc import Callable
+ from concurrent.futures import Future, ThreadPoolExecutor, wait
+ from datetime import datetime
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import NamedTuple
+
+ import requests
+ from github import Auth, Github
+ from github.Repository import Repository
+ from github.WorkflowJob import WorkflowJob
+ from github.WorkflowRun import WorkflowRun
+ from github.WorkflowStep import WorkflowStep
+ from zero_3rdparty import datetime_utils, file_utils
+
+ from atlas_init.cli_tf.go_test_run import GoTestRun, parse
+ from atlas_init.repos.path import (
+     GH_OWNER_TERRAFORM_PROVIDER_MONGODBATLAS,
+ )
+ from atlas_init.settings.path import (
+     DEFAULT_GITHUB_CI_RUN_LOGS,
+     DEFAULT_GITHUB_SUMMARY_DIR,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ GH_TOKEN_ENV_NAME = "GH_TOKEN"  # noqa: S105
+ GITHUB_CI_RUN_LOGS_ENV_NAME = "GITHUB_CI_RUN_LOGS"
+ GITHUB_CI_SUMMARY_DIR_ENV_NAME = "GITHUB_CI_SUMMARY_DIR_ENV_NAME"
+ REQUIRED_GH_ENV_VARS = [GH_TOKEN_ENV_NAME, GITHUB_CI_RUN_LOGS_ENV_NAME]
+ MAX_DOWNLOADS = 5
+
+
+ @lru_cache
+ def get_auth() -> Auth.Auth:
+     token = os.environ[GH_TOKEN_ENV_NAME]
+     return Auth.Token(token)
+
+
+ @lru_cache
+ def get_repo(repo_id: str) -> Repository:
+     auth = get_auth()
+     g = Github(auth=auth)
+     logger.info(f"logged in as: {g.get_user().login}")
+     return g.get_repo(repo_id)
+
+
+ _DEFAULT_FILESTEMS = {
+     "test-suite",
+     "terraform-compatibility-matrix",
+     # "acceptance-tests",
+ }
+
+
+ def include_filestems(stems: set[str]) -> Callable[[WorkflowRun], bool]:
+     def inner(run: WorkflowRun) -> bool:
+         workflow_stem = stem_name(run.path)
+         return workflow_stem in stems
+
+     return inner
+
+
+ def stem_name(workflow_path: str) -> str:
+     return Path(workflow_path).stem
+
+
+ def tf_repo() -> Repository:
+     return get_repo(GH_OWNER_TERRAFORM_PROVIDER_MONGODBATLAS)
+
+
+ class WorkflowJobId(NamedTuple):
+     workflow_id: int
+     job_id: int
+
+
+ def find_test_runs(
+     since: datetime,
+     include_workflow: Callable[[WorkflowRun], bool] | None = None,
+     include_job: Callable[[WorkflowJob], bool] | None = None,
+     branch: str = "master",
+ ) -> dict[WorkflowJobId, list[GoTestRun]]:
+     include_workflow = include_workflow or include_filestems(_DEFAULT_FILESTEMS)
+     include_job = include_job or include_test_jobs()
+     jobs_found = defaultdict(list)
+     repository = tf_repo()
+     for workflow in repository.get_workflow_runs(
+         created=f">{since.strftime('%Y-%m-%d')}",
+         branch=branch,
+         exclude_pull_requests=True,  # type: ignore
+     ):
+         if not include_workflow(workflow):
+             continue
+         workflow_dir = workflow_logs_dir(workflow)
+         paginated_jobs = workflow.jobs("all")
+         worker_count = min(paginated_jobs.totalCount, 10) or 1
+         with ThreadPoolExecutor(max_workers=worker_count) as pool:
+             futures: dict[Future[list[GoTestRun]], WorkflowJob] = {}
+             for job in paginated_jobs:
+                 if not include_job(job):
+                     continue
+                 future = pool.submit(find_job_test_runs, workflow_dir, job)
+                 futures[future] = job
+             done, not_done = wait(futures.keys(), timeout=300)
+             for f in not_done:
+                 logger.warning(f"timeout to find go tests for job = {futures[f].html_url}")
+             workflow_id = workflow.id
+             for f in done:
+                 job = futures[f]
+                 try:
+                     go_test_runs: list[GoTestRun] = f.result()
+                 except Exception:
+                     job_log_path = logs_file(workflow_dir, job)
+                     logger.exception(
+                         f"failed to find go tests for job: {job.html_url}, error 👆, local_path: {job_log_path}"
+                     )
+                     continue
+                 jobs_found[WorkflowJobId(workflow_id, job.id)].extend(go_test_runs)
+     return jobs_found
+
+
+ def find_job_test_runs(workflow_dir: Path, job: WorkflowJob) -> list[GoTestRun]:
+     jobs_log_path = download_job_safely(workflow_dir, job)
+     return [] if jobs_log_path is None else parse_job_logs(job, jobs_log_path)
+
+
+ def parse_job_logs(job: WorkflowJob, logs_path: Path) -> list[GoTestRun]:
+     step, logs_lines = select_step_and_log_content(job, logs_path)
+     return list(parse(logs_lines, job, step))
+
+
+ def download_job_safely(workflow_dir: Path, job: WorkflowJob) -> Path | None:
+     path = logs_file(workflow_dir, job)
+     job_summary = f"found test job: {job.name}, attempt {job.run_attempt}, {job.created_at}, url: {job.html_url}"
+     if path.exists():
+         logger.info(f"{job_summary} exists @ {path}")
+         return path
+     logger.info(f"{job_summary}\n\t\t downloading to {path}")
+     try:
+         logs_response = requests.get(job.logs_url(), timeout=60)
+         logs_response.raise_for_status()
+     except Exception as e:  # noqa: BLE001
+         logger.warning(f"failed to download logs for {job.html_url}, e={e!r}")
+         return None
+     file_utils.ensure_parents_write_text(path, logs_response.text)
+     return path
+
+
+ def logs_dir() -> Path:
+     logs_dir_str = os.environ.get(GITHUB_CI_RUN_LOGS_ENV_NAME)
+     if not logs_dir_str:
+         logger.warning(f"using {DEFAULT_GITHUB_CI_RUN_LOGS} to store github ci logs!")
+         return DEFAULT_GITHUB_CI_RUN_LOGS
+     return Path(logs_dir_str)
+
+
+ def summary_dir(summary_name: str) -> Path:
+     summary_dir_str = os.environ.get(GITHUB_CI_SUMMARY_DIR_ENV_NAME)
+     if not summary_dir_str:
+         logger.warning(f"using {DEFAULT_GITHUB_SUMMARY_DIR / summary_name} to store summaries")
+         return DEFAULT_GITHUB_SUMMARY_DIR / summary_name
+     return Path(summary_dir_str) / summary_name
+
+
+ def workflow_logs_dir(workflow: WorkflowRun) -> Path:
+     dt = workflow.created_at
+     date_str = datetime_utils.get_date_as_rfc3339_without_time(dt)
+     workflow_name = stem_name(workflow.path)
+     return logs_dir() / f"{date_str}/{workflow.id}_{workflow_name}"
+
+
+ def logs_file(workflow_dir: Path, job: WorkflowJob) -> Path:
+     if job.run_attempt != 1:
+         workflow_dir = workflow_dir.with_name(f"{workflow_dir.name}_attempt{job.run_attempt}")
+     filename = f"{job.id}_" + job.name.replace(" ", "").replace("/", "_").replace("__", "_") + ".txt"
+     return workflow_dir / filename
+
+
+ def as_test_group(job_name: str) -> str:
+     """e.g. 'tests-1.8.x-latest / tests-1.8.x-latest-dev / config' -> 'config'"""
+     return "" if "/" not in job_name else job_name.split("/")[-1].strip()
+
+
+ def include_test_jobs(test_group: str = "") -> Callable[[WorkflowJob], bool]:
+     def inner(job: WorkflowJob) -> bool:
+         job_name = job.name
+         if test_group:
+             return is_test_job(job_name) and as_test_group(job_name) == test_group
+         return is_test_job(job_name)
+
+     return inner
+
+
+ def is_test_job(job_name: str) -> bool:
+     """
+     >>> is_test_job("tests-1.8.x-latest / tests-1.8.x-latest-dev / config")
+     True
+     """
+     if "-before" in job_name or "-after" in job_name:
+         return False
+     return "tests-" in job_name and not job_name.endswith(("get-provider-version", "change-detection"))
+
+
+ def select_step_and_log_content(job: WorkflowJob, logs_path: Path) -> tuple[int, list[str]]:
+     full_text = logs_path.read_text()
+     step = test_step(job.steps)
+     last_step_start = current_step_start = 1
+     # there is always an extra setup job step, so starting at 1
+     current_step = 1
+     lines = full_text.splitlines()
+     for line_index, line in enumerate(lines):
+         if "##[group]Run " in line:
+             current_step += 1
+             last_step_start, current_step_start = current_step_start, line_index
+             if current_step == step + 1:
+                 return step, lines[last_step_start:current_step_start]
+     assert step == current_step, f"didn't find enough steps in logs for {job.html_url}"
+     return step, lines[current_step_start:]
+
+
+ def test_step(steps: list[WorkflowStep]) -> int:
+     for i, step in enumerate(steps, 1):
+         if "test" in step.name.lower():
+             return i
+     last_step = len(steps)
+     logger.warning(f"using {last_step} as final step, unable to find 'test' in {steps}")
+     return last_step
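
A hedged sketch of invoking the new github_logs.py module: it requires a real GH_TOKEN in the environment and network access to the live GitHub API, and the "config" test group name below is only an illustrative filter, not something the package prescribes.

# sketch only: export GH_TOKEN before running; hits the live GitHub API
from datetime import datetime, timedelta

from atlas_init.cli_tf.github_logs import find_test_runs, include_test_jobs

since = datetime.now() - timedelta(days=7)
runs = find_test_runs(since, include_job=include_test_jobs("config"))
for (workflow_id, job_id), go_tests in runs.items():
    print(workflow_id, job_id, len(go_tests))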