atlas-init 0.4.4__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_init/__init__.py +1 -1
- atlas_init/cli.py +2 -0
- atlas_init/cli_cfn/app.py +3 -4
- atlas_init/cli_cfn/cfn_parameter_finder.py +61 -53
- atlas_init/cli_cfn/contract.py +4 -7
- atlas_init/cli_cfn/example.py +8 -18
- atlas_init/cli_helper/go.py +7 -11
- atlas_init/cli_root/mms_released.py +46 -0
- atlas_init/cli_root/trigger.py +6 -6
- atlas_init/cli_tf/app.py +3 -84
- atlas_init/cli_tf/ci_tests.py +493 -0
- atlas_init/cli_tf/codegen/__init__.py +0 -0
- atlas_init/cli_tf/codegen/models.py +97 -0
- atlas_init/cli_tf/codegen/openapi_minimal.py +74 -0
- atlas_init/cli_tf/github_logs.py +7 -94
- atlas_init/cli_tf/go_test_run.py +385 -132
- atlas_init/cli_tf/go_test_summary.py +331 -4
- atlas_init/cli_tf/go_test_tf_error.py +380 -0
- atlas_init/cli_tf/hcl/modifier.py +14 -12
- atlas_init/cli_tf/hcl/modifier2.py +87 -0
- atlas_init/cli_tf/mock_tf_log.py +1 -1
- atlas_init/cli_tf/{schema_v2_api_parsing.py → openapi.py} +95 -17
- atlas_init/cli_tf/schema_v2.py +43 -1
- atlas_init/crud/__init__.py +0 -0
- atlas_init/crud/mongo_client.py +115 -0
- atlas_init/crud/mongo_dao.py +296 -0
- atlas_init/crud/mongo_utils.py +239 -0
- atlas_init/repos/go_sdk.py +12 -3
- atlas_init/repos/path.py +110 -7
- atlas_init/settings/config.py +3 -6
- atlas_init/settings/env_vars.py +22 -31
- atlas_init/settings/interactive2.py +134 -0
- atlas_init/tf/.terraform.lock.hcl +59 -59
- atlas_init/tf/always.tf +5 -5
- atlas_init/tf/main.tf +3 -3
- atlas_init/tf/modules/aws_kms/aws_kms.tf +1 -1
- atlas_init/tf/modules/aws_s3/provider.tf +2 -1
- atlas_init/tf/modules/aws_vpc/provider.tf +2 -1
- atlas_init/tf/modules/cfn/cfn.tf +0 -8
- atlas_init/tf/modules/cfn/kms.tf +5 -5
- atlas_init/tf/modules/cfn/provider.tf +7 -0
- atlas_init/tf/modules/cfn/variables.tf +1 -1
- atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -1
- atlas_init/tf/modules/cloud_provider/provider.tf +2 -1
- atlas_init/tf/modules/cluster/cluster.tf +31 -31
- atlas_init/tf/modules/cluster/provider.tf +2 -1
- atlas_init/tf/modules/encryption_at_rest/provider.tf +2 -1
- atlas_init/tf/modules/federated_vars/federated_vars.tf +1 -1
- atlas_init/tf/modules/federated_vars/provider.tf +2 -1
- atlas_init/tf/modules/project_extra/project_extra.tf +1 -10
- atlas_init/tf/modules/project_extra/provider.tf +8 -0
- atlas_init/tf/modules/stream_instance/provider.tf +8 -0
- atlas_init/tf/modules/stream_instance/stream_instance.tf +0 -9
- atlas_init/tf/modules/vpc_peering/provider.tf +10 -0
- atlas_init/tf/modules/vpc_peering/vpc_peering.tf +0 -10
- atlas_init/tf/modules/vpc_privatelink/versions.tf +2 -1
- atlas_init/tf/outputs.tf +1 -0
- atlas_init/tf/providers.tf +1 -1
- atlas_init/tf/variables.tf +7 -7
- atlas_init/typer_app.py +4 -8
- {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/METADATA +7 -4
- atlas_init-0.6.0.dist-info/RECORD +121 -0
- atlas_init-0.4.4.dist-info/RECORD +0 -105
- {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/WHEEL +0 -0
- {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/entry_points.txt +0 -0
- {atlas_init-0.4.4.dist-info → atlas_init-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,296 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
from datetime import datetime
|
7
|
+
from functools import cached_property
|
8
|
+
from pathlib import Path
|
9
|
+
from typing import ClassVar, Self
|
10
|
+
|
11
|
+
from model_lib import Entity, dump, field_names, parse_model
|
12
|
+
from motor.motor_asyncio import AsyncIOMotorCollection
|
13
|
+
from pydantic import model_validator
|
14
|
+
from zero_3rdparty.file_utils import ensure_parents_write_text
|
15
|
+
from zero_3rdparty.iter_utils import ignore_falsy
|
16
|
+
|
17
|
+
from atlas_init.cli_tf.go_test_run import GoTestRun
|
18
|
+
from atlas_init.cli_tf.go_test_tf_error import (
|
19
|
+
ErrorClassAuthor,
|
20
|
+
ErrorDetailsT,
|
21
|
+
GoTestAPIError,
|
22
|
+
GoTestError,
|
23
|
+
GoTestErrorClass,
|
24
|
+
GoTestErrorClassification,
|
25
|
+
GoTestResourceCheckError,
|
26
|
+
)
|
27
|
+
from atlas_init.crud.mongo_client import get_collection, init_mongo
|
28
|
+
from atlas_init.crud.mongo_utils import MongoQueryOperation, create_or_replace, dump_with_id
|
29
|
+
from atlas_init.repos.path import TFResoure, terraform_resources
|
30
|
+
from atlas_init.settings.env_vars import AtlasInitSettings
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
def crud_dir(settings: AtlasInitSettings) -> Path:
    """Return the directory used for file-based CRUD storage (`<static_root>/crud`)."""
    static_root: Path = settings.static_root
    return static_root.joinpath("crud")
|
37
|
+
|
38
|
+
|
39
|
+
@dataclass
class TFResources:
    """Collection of known terraform resources, searchable by a test's package url."""

    resources: list[TFResoure] = field(default_factory=list)

    def find_test_resources(self, test: GoTestRun) -> list[str]:
        """Names of resources whose package path is a suffix of the test's package url."""
        pkg_url = test.package_url
        if not pkg_url:
            return []
        return [resource.name for resource in self.resources if pkg_url.endswith(resource.package_rel_path)]
|
50
|
+
|
51
|
+
|
52
|
+
def read_tf_resources(settings: AtlasInitSettings, repo_path: Path, branch: str) -> TFResources:
    """Collect terraform resources found in `repo_path`.

    `settings` and `branch` are currently unused but kept for a stable call signature.
    """
    found = terraform_resources(repo_path)
    return TFResources(resources=found)
|
54
|
+
|
55
|
+
|
56
|
+
class TFErrors(Entity):
    """Persisted collection of Go test errors with helper lookups."""

    # NOTE(review): `field(...)` here is dataclasses.field used as a default on a
    # model_lib Entity — presumably Entity supports this; confirm each instance
    # gets its own empty list.
    errors: list[GoTestError] = field(default_factory=list)

    @model_validator(mode="after")
    def sort_errors(self) -> TFErrors:
        """Keep errors deterministically sorted after model construction/parsing."""
        self.errors.sort()
        return self

    def look_for_existing_classifications(self, error: GoTestError) -> tuple[GoTestErrorClass, GoTestErrorClass] | None:
        """Classifications of the first stored error matching `error`; None when no match."""
        for candidate in self.errors:
            # walrus: only errors that already carry classifications count as a hit
            if error.match(candidate) and (classifications := candidate.classifications):
                logger.info(f"found existing classification for {error.run.name}: {classifications}")
                return classifications

    def classified_errors(self) -> list[GoTestError]:
        """Only the errors that already have classifications attached."""
        return [error for error in self.errors if error.classifications is not None]
|
72
|
+
|
73
|
+
|
74
|
+
def read_tf_errors(settings: AtlasInitSettings) -> TFErrors:
    """Load persisted test errors from the crud dir; empty collection when no file exists."""
    errors_file = crud_dir(settings) / "tf_errors.yaml"
    if not errors_file.exists():
        return TFErrors()
    return parse_model(errors_file, TFErrors)
|
77
|
+
|
78
|
+
|
79
|
+
def read_tf_errors_for_day(settings: AtlasInitSettings, branch: str, date: datetime) -> list[GoTestError]:
    """Not implemented for file-based storage; the mongo-backed DAO covers per-day queries."""
    raise NotImplementedError
|
81
|
+
|
82
|
+
|
83
|
+
def store_or_update_tf_errors(settings: AtlasInitSettings, errors: list[GoTestError]) -> None:
    """Merge `errors` into the stored collection, replacing entries with the same run id."""
    stored = read_tf_errors(settings)
    replaced_ids = {error.run.id for error in errors}
    kept = [error for error in stored.errors if error.run.id not in replaced_ids]
    merged_yaml = dump(TFErrors(errors=kept + errors), "yaml")
    ensure_parents_write_text(crud_dir(settings) / "tf_errors.yaml", merged_yaml)
|
90
|
+
|
91
|
+
|
92
|
+
def read_tf_error_by_run(settings: AtlasInitSettings, run: GoTestRun) -> GoTestError | None:
    """First stored error whose run id matches `run`, or None when absent."""
    for candidate in read_tf_errors(settings).errors:
        if candidate.run.id == run.id:
            return candidate
    return None
|
95
|
+
|
96
|
+
|
97
|
+
class TFTestRuns(Entity):
    """Persisted collection of Go test runs (yaml-backed)."""

    # NOTE(review): dataclasses.field as a default on a model_lib Entity —
    # presumably supported; confirm it behaves like a per-instance default_factory.
    test_runs: list[GoTestRun] = field(default_factory=list)

    @model_validator(mode="after")
    def sort_test_runs(self) -> TFTestRuns:
        """Keep runs deterministically sorted after model construction/parsing."""
        self.test_runs.sort()
        return self
|
104
|
+
|
105
|
+
|
106
|
+
def read_tf_test_runs(settings: AtlasInitSettings) -> list[GoTestRun]:
    """Load persisted test runs from the crud dir; empty list when the file is missing."""
    runs_file = crud_dir(settings) / "tf_test_runs.yaml"
    if not runs_file.exists():
        return []
    return parse_model(runs_file, TFTestRuns).test_runs
|
109
|
+
|
110
|
+
|
111
|
+
def read_tf_tests_for_day(settings: AtlasInitSettings, branch: str, date: datetime) -> list[GoTestRun]:
    """Test runs on `branch` within the calendar day containing `date`."""
    day_start = date.replace(hour=0, minute=0, second=0, microsecond=0)
    day_end = day_start.replace(hour=23, minute=59, second=59, microsecond=999999)
    return read_tf_tests(settings, branch, day_start, day_end)
|
115
|
+
|
116
|
+
|
117
|
+
def read_tf_tests(
    settings: AtlasInitSettings, branch: str, start_date: datetime, end_date: datetime | None = None
) -> list[GoTestRun]:
    """Not implemented for file-based storage; the mongo-backed DAO covers range queries."""
    raise NotImplementedError
|
121
|
+
|
122
|
+
|
123
|
+
async def init_mongo_dao(settings: AtlasInitSettings) -> MongoDao:
    """Create a MongoDao and establish its mongo connection before returning it."""
    return await MongoDao(settings=settings).connect()
|
126
|
+
|
127
|
+
|
128
|
+
class GoTestRunNotFound(Exception):
    """Raised when no stored GoTestRun exists for the given run id."""

    def __init__(self, run_id: str) -> None:
        super().__init__(run_id)
        self.run_id = run_id
|
132
|
+
|
133
|
+
|
134
|
+
@dataclass
|
135
|
+
class MongoDao:
|
136
|
+
settings: AtlasInitSettings
|
137
|
+
property_keys_run: ClassVar[list[str]] = ["group_name"]
|
138
|
+
|
139
|
+
@cached_property
|
140
|
+
def runs(self) -> AsyncIOMotorCollection:
|
141
|
+
return get_collection(GoTestRun)
|
142
|
+
|
143
|
+
@cached_property
|
144
|
+
def classifications(self) -> AsyncIOMotorCollection:
|
145
|
+
return get_collection(GoTestErrorClassification)
|
146
|
+
|
147
|
+
@cached_property
|
148
|
+
def _field_names_runs(self) -> set[str]:
|
149
|
+
return set(field_names(GoTestRun)) | set(self.property_keys_run)
|
150
|
+
|
151
|
+
async def connect(self) -> Self:
|
152
|
+
await init_mongo(
|
153
|
+
mongo_url=self.settings.mongo_url,
|
154
|
+
db_name=self.settings.mongo_database,
|
155
|
+
)
|
156
|
+
return self
|
157
|
+
|
158
|
+
async def store_tf_test_runs(self, test_runs: list[GoTestRun]) -> list[GoTestRun]:
|
159
|
+
if not test_runs:
|
160
|
+
return []
|
161
|
+
col = self.runs
|
162
|
+
tasks = []
|
163
|
+
loop = asyncio.get_event_loop()
|
164
|
+
for run in test_runs:
|
165
|
+
dumped = dump_with_id(run, id=run.id, dt_keys=["ts", "finish_ts"], property_keys=self.property_keys_run)
|
166
|
+
tasks.append(loop.create_task(create_or_replace(col, dumped)))
|
167
|
+
await asyncio.gather(*tasks)
|
168
|
+
return test_runs
|
169
|
+
|
170
|
+
async def read_tf_tests_for_day(self, branch: str, date: datetime) -> list[GoTestRun]:
|
171
|
+
start_date = date.replace(hour=0, minute=0, second=0, microsecond=0)
|
172
|
+
end_date = start_date.replace(hour=23, minute=59, second=59, microsecond=999999)
|
173
|
+
query = {
|
174
|
+
"branch": branch,
|
175
|
+
"ts": {MongoQueryOperation.gte: start_date, MongoQueryOperation.lte: end_date},
|
176
|
+
}
|
177
|
+
return await self._find_runs(query)
|
178
|
+
|
179
|
+
async def _find_runs(self, query: dict) -> list[GoTestRun]:
|
180
|
+
runs = []
|
181
|
+
async for raw_run in self.runs.find(query):
|
182
|
+
runs.append(self._parse_run(raw_run))
|
183
|
+
return runs
|
184
|
+
|
185
|
+
async def read_error_classifications(
|
186
|
+
self, run_ids: list[str] | None = None
|
187
|
+
) -> dict[str, GoTestErrorClassification]:
|
188
|
+
run_ids = run_ids or []
|
189
|
+
if not run_ids:
|
190
|
+
return {}
|
191
|
+
query = {"_id": {MongoQueryOperation.in_: run_ids}}
|
192
|
+
return await self._find_classifications(query)
|
193
|
+
|
194
|
+
async def _find_classifications(self, query: dict) -> dict[str, GoTestErrorClassification]:
|
195
|
+
classifications: dict[str, GoTestErrorClassification] = {}
|
196
|
+
async for raw_error in self.classifications.find(query):
|
197
|
+
run_id = raw_error.pop("_id", None)
|
198
|
+
classification = parse_model(raw_error, t=GoTestErrorClassification)
|
199
|
+
classifications[run_id] = classification
|
200
|
+
return classifications
|
201
|
+
|
202
|
+
async def read_similar_error_classifications(
|
203
|
+
self, details: ErrorDetailsT, *, author_filter: ErrorClassAuthor | None = None
|
204
|
+
) -> dict[str, GoTestErrorClassification]:
|
205
|
+
query = {}
|
206
|
+
if author_filter:
|
207
|
+
query["author"] = {MongoQueryOperation.eq: author_filter}
|
208
|
+
match details:
|
209
|
+
case GoTestAPIError(
|
210
|
+
api_error_code_str=api_error_code_str,
|
211
|
+
api_method=api_method,
|
212
|
+
api_response_code=api_response_code,
|
213
|
+
api_path_normalized=api_path_normalized,
|
214
|
+
) if api_path_normalized:
|
215
|
+
query |= {
|
216
|
+
"details.api_error_code_str": {MongoQueryOperation.eq: api_error_code_str},
|
217
|
+
"details.api_method": {MongoQueryOperation.eq: api_method},
|
218
|
+
"details.api_response_code": {MongoQueryOperation.eq: api_response_code},
|
219
|
+
"details.api_path_normalized": {MongoQueryOperation.eq: api_path_normalized},
|
220
|
+
}
|
221
|
+
case GoTestResourceCheckError(
|
222
|
+
tf_resource_name=tf_resource_name,
|
223
|
+
tf_resource_type=tf_resource_type,
|
224
|
+
step_nr=step_nr,
|
225
|
+
check_errors=check_errors,
|
226
|
+
test_name=test_name,
|
227
|
+
):
|
228
|
+
query |= {
|
229
|
+
"details.tf_resource_name": {MongoQueryOperation.eq: tf_resource_name},
|
230
|
+
"details.tf_resource_type": {MongoQueryOperation.eq: tf_resource_type},
|
231
|
+
"details.step_nr": {MongoQueryOperation.eq: step_nr},
|
232
|
+
"test_name": {MongoQueryOperation.eq: test_name},
|
233
|
+
}
|
234
|
+
classifications = await self._find_classifications(query)
|
235
|
+
return {
|
236
|
+
run_id: classification
|
237
|
+
for run_id, classification in classifications.items()
|
238
|
+
if isinstance(classification.details, GoTestResourceCheckError)
|
239
|
+
and classification.details.check_errors_match(check_errors)
|
240
|
+
}
|
241
|
+
case _:
|
242
|
+
return {} # todo: vector search to match on error output
|
243
|
+
return await self._find_classifications(query)
|
244
|
+
|
245
|
+
async def add_classification(self, classification: GoTestErrorClassification) -> bool:
|
246
|
+
"""Returns is_new"""
|
247
|
+
raw = dump_with_id(classification, id=classification.run_id, dt_keys=["ts"])
|
248
|
+
return await create_or_replace(self.classifications, raw)
|
249
|
+
|
250
|
+
async def read_tf_test_run(self, run_id: str) -> GoTestRun:
|
251
|
+
raw = await self.runs.find_one({"_id": run_id})
|
252
|
+
if raw is None:
|
253
|
+
raise GoTestRunNotFound(run_id)
|
254
|
+
return self._parse_run(raw)
|
255
|
+
|
256
|
+
def _parse_run(self, raw: dict) -> GoTestRun:
|
257
|
+
raw.pop("_id")
|
258
|
+
for key in self.property_keys_run:
|
259
|
+
raw.pop(key, None) # Remove properties that are not part of the model
|
260
|
+
return parse_model(raw, t=GoTestRun)
|
261
|
+
|
262
|
+
async def read_run_history(
|
263
|
+
self,
|
264
|
+
test_name: str,
|
265
|
+
branches: list[str] | None = None,
|
266
|
+
package_url: str | None = None,
|
267
|
+
group_name: str | None = None,
|
268
|
+
start_date: datetime | None = None,
|
269
|
+
end_date: datetime | None = None,
|
270
|
+
envs: list[str] | None = None,
|
271
|
+
) -> list[GoTestRun]:
|
272
|
+
eq = MongoQueryOperation.eq
|
273
|
+
query = {
|
274
|
+
"name": {eq: test_name},
|
275
|
+
}
|
276
|
+
eq_parts = {
|
277
|
+
"package_url": {eq: package_url} if package_url else None,
|
278
|
+
"group_name": {eq: group_name} if group_name else None,
|
279
|
+
}
|
280
|
+
in_op = MongoQueryOperation.in_
|
281
|
+
in_parts = {
|
282
|
+
"branch": {in_op: branches} if branches else None,
|
283
|
+
"env": {in_op: envs} if envs else None,
|
284
|
+
}
|
285
|
+
date_parts = {
|
286
|
+
"ts": ignore_falsy(
|
287
|
+
**{
|
288
|
+
MongoQueryOperation.lte: end_date or None,
|
289
|
+
MongoQueryOperation.gte: start_date or None,
|
290
|
+
}
|
291
|
+
)
|
292
|
+
}
|
293
|
+
query |= ignore_falsy(**eq_parts, **in_parts, **date_parts)
|
294
|
+
if invalid_fields := set(query) - self._field_names_runs:
|
295
|
+
raise ValueError(f"Invalid fields in query: {invalid_fields}")
|
296
|
+
return await self._find_runs(query)
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import logging
|
2
|
+
import re
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from datetime import datetime
|
5
|
+
from functools import wraps
|
6
|
+
from typing import Any, AsyncIterable, Iterable, List, Optional, Type, TypeVar
|
7
|
+
|
8
|
+
from model_lib import dump_as_dict
|
9
|
+
from motor.core import AgnosticCollection
|
10
|
+
from pydantic import BaseModel
|
11
|
+
from pymongo import ASCENDING, DESCENDING, IndexModel, ReturnDocument
|
12
|
+
from pymongo.errors import DuplicateKeyError, PyMongoError
|
13
|
+
from pymongo.results import DeleteResult
|
14
|
+
from zero_3rdparty.enum_utils import StrEnum
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
# Any pydantic model type; used by the generic serialization helpers below.
ModelT = TypeVar("ModelT", bound=BaseModel)
|
18
|
+
|
19
|
+
|
20
|
+
class MongoUpdateOperation(StrEnum):
    """Mongo update-operator names (`$set`, `$push`, ...), usable directly as dict keys.

    References:
        https://docs.mongodb.com/manual/reference/operator/update-array/
    """

    # array operators
    slice = "$slice"
    pop = "$pop"
    pull = "$pull"
    unset = "$unset"
    push = "$push"
    each = "$each"
    # field operators
    set = "$set"
    set_on_insert = "$setOnInsert"
    inc = "$inc"
|
35
|
+
|
36
|
+
|
37
|
+
UPDATE_OPERATIONS = set(MongoUpdateOperation)


def ensure_mongo_operation(updates: dict):
    """Wrap a plain field->value mapping in `$set`; pass through operator-keyed dicts.

    >>> ensure_mongo_operation({"field1": 2})
    {'$set': {'field1': 2}}
    >>> ensure_mongo_operation({MongoUpdateOperation.set: {"field1": 2}})
    {'$set': {'field1': 2}}
    >>> ensure_mongo_operation({MongoUpdateOperation.push: {"field1": 2}})
    {'$push': {'field1': 2}}
    """
    non_operator_keys = updates.keys() - UPDATE_OPERATIONS
    if non_operator_keys:
        return {MongoUpdateOperation.set: updates}
    return updates
|
52
|
+
|
53
|
+
|
54
|
+
class MongoQueryOperation(StrEnum):
    """Mongo query-operator names, plus helpers that build optional filter fragments."""

    eq = "$eq"
    # `ne: True` must be used when checking if a boolean field is false, since the
    # field may also be missing entirely:
    # https://stackoverflow.com/questions/18837486/query-for-boolean-field-as-not-true-e-g-either-false-or-non-existent
    ne = "$ne"
    in_ = "$in"
    # https://www.mongodb.com/docs/manual/reference/operator/query/nin/#mongodb-query-op.-nin
    nin = "$nin"
    gt = "$gt"
    gte = "$gte"
    lt = "$lt"
    lte = "$lte"
    slice = "$slice"

    @classmethod
    def boolean_or_none(cls, bool_value: bool | None) -> dict | None:
        """Filter fragment for a boolean field; None (no filter) when value is None."""
        if bool_value is None:
            return None
        return {cls.eq: True} if bool_value else {cls.ne: True}

    @classmethod
    def in_or_none(cls, options: Iterable[Any] | None) -> dict | None:
        """`$in` fragment for `options`; None (no filter) when options is None."""
        return None if options is None else {cls.in_: list(options)}

    @classmethod
    def nin_or_none(cls, options: Iterable[Any] | None) -> dict | None:
        """`$nin` fragment for `options`; None (no filter) when options is None."""
        return None if options is None else {cls.nin: list(options)}
|
81
|
+
|
82
|
+
|
83
|
+
# Parses mongo duplicate-key error text into collection path, index name and the
# duplicated value. The named groups must match the MongoConstraintDetails fields,
# since matches are expanded with **m.groupdict().
duplicate_key_regex = re.compile(
    r".*error collection:"
    r"\s(?P<collection_path>[-\w\d\\.]+)"
    r"\sindex:\s"
    r"(?P<index_name>[\w_\\.\d]+)"
    r"\sdup key.*?"
    r'(?P<dup_key_value>("?[\\.\w_\d]+"?)|(null))'
)
|
91
|
+
|
92
|
+
|
93
|
+
@dataclass
class MongoConstraintDetails:
    """Pieces of a mongo duplicate-key error: which index on which collection, and the value."""

    collection_path: str
    index_name: str
    dup_key_value: Optional[str]

    def __post_init__(self):
        # Normalize the regex capture: strip surrounding quotes, map "null" to None.
        value = self.dup_key_value
        if not value:
            return
        value = value.strip('"')
        self.dup_key_value = None if value == "null" else value
|
104
|
+
|
105
|
+
|
106
|
+
def parse_error(error: PyMongoError) -> Optional[MongoConstraintDetails]:
    """Extract duplicate-key constraint details from a pymongo error.

    Returns None (after logging a warning) when the error text does not match
    the known duplicate-key format.

    >>> raw = 'E11000 duplicate key error collection: dev_situation.Robot index: _id_ dup key: { : "mw_wheel_id" }'
    >>> parse_error(raw)
    MongoConstraintDetails(collection_path='dev_situation.Robot', index_name='_id_', dup_key_value='mw_wheel_id')
    """
    error_str = str(error)
    for m in duplicate_key_regex.finditer(error_str):
        constraints = MongoConstraintDetails(**m.groupdict())
        if isinstance(error, DuplicateKeyError):
            # structured details are more reliable than the regex capture
            _, constraints.dup_key_value = error.details["keyValue"].popitem()  # type: ignore
        return constraints
    logger.warning(f"unknown pymongo error:{error}")
|
120
|
+
|
121
|
+
|
122
|
+
class MongoConstraintError(Exception):
    """Raised in place of a raw PyMongoError when constraint details could be parsed."""

    def __init__(self, details: MongoConstraintDetails):
        self.details: MongoConstraintDetails = details
|
125
|
+
|
126
|
+
|
127
|
+
# Generic placeholder for the decorated callable's type.
T = TypeVar("T")

# Constraint-error subclass chosen by the decorator's `cls` argument.
ConstraintSubT = TypeVar("ConstraintSubT", bound=MongoConstraintError)


def raise_mongo_constraint_error(f: T = None, *, cls: Type[ConstraintSubT] = MongoConstraintError) -> T:
    """Decorator for async callables: re-raise parseable PyMongoErrors as `cls`.

    Supports both bare use (`@raise_mongo_constraint_error`) and configured use
    (`@raise_mongo_constraint_error(cls=MyConstraintError)`).
    """

    def decorator(f: T):
        @wraps(f)  # type: ignore
        async def inner(*args, **kwargs):
            try:
                return await f(*args, **kwargs)  # type: ignore
            except PyMongoError as e:
                # only wrap errors we can decode into constraint details
                if details := parse_error(e):
                    raise cls(details) from e
                raise e

        return inner

    # bare decorator usage passes `f` directly; configured usage returns `decorator`
    return decorator(f) if f else decorator  # type: ignore
|
146
|
+
|
147
|
+
|
148
|
+
def dump_with_id(
    model: BaseModel,
    id: str = "",
    dt_keys: Optional[List[str]] = None,
    property_keys: Optional[List[str]] = None,
    exclude: Optional[set[str]] = None,
) -> dict:
    """Serialize `model` to a mongo-ready dict, optionally setting its `_id`.

    Warning:
        If you want to index on datetime, you have to set them afterwards
        as they will be dumped as strings — `dt_keys` copies the raw datetime
        attributes back onto the dump for exactly that reason.
    """
    if exclude is None:
        raw = dump_as_dict(model)
    else:
        raw = dump_as_dict(model.model_dump(exclude=exclude))
    if id:
        raw["_id"] = id
    # re-attach attributes that must keep their Python values (datetimes, computed properties)
    for attr in (dt_keys or []) + (property_keys or []):
        raw[attr] = getattr(model, attr)
    return raw
|
170
|
+
|
171
|
+
|
172
|
+
async def create_or_replace(collection: AgnosticCollection, raw: dict) -> bool:
    """Upsert `raw` into `collection`, keyed on its `_id`.

    Returns:
        is_new: True when a new document was inserted, False when one was replaced.
    """
    outcome = await collection.replace_one({"_id": raw["_id"]}, raw, upsert=True)
    is_new = bool(outcome.upserted_id)
    return is_new
|
179
|
+
|
180
|
+
|
181
|
+
async def find_one_and_update(
    collection: AgnosticCollection,
    id: str,
    updates: dict,
    return_raw_after: bool = True,
    upsert: bool = False,
    **query,
) -> Optional[dict]:
    """Atomically update the document with `_id == id` (plus any extra `query` filters).

    `updates` may be a plain field->value dict (auto-wrapped in $set) or already
    use mongo update operators. Returns the document after (default) or before
    the update; None when nothing matched and `upsert` is False.

    Warning:
        pops the "_id" from serialize_lib
    """
    return_doc = ReturnDocument.AFTER if return_raw_after else ReturnDocument.BEFORE
    updates = ensure_mongo_operation(updates)
    raw = await collection.find_one_and_update({"_id": id, **query}, updates, return_document=return_doc, upsert=upsert)
    if raw:
        raw.pop("_id", None)
    return raw
|
199
|
+
|
200
|
+
|
201
|
+
def microsecond_compare(mongo_dt: datetime, dt: datetime) -> bool:
    """Compare datetimes while tolerating the sub-millisecond precision mongo drops.

    Mongo only stores milliseconds since epoch:
    https://stackoverflow.com/questions/39963143/why-is-there-a-difference-
    between-the-stored-and-queried-time-in-mongo-database.
    """
    aligned = mongo_dt.replace(microsecond=dt.microsecond)
    if aligned != dt:
        return False
    return (mongo_dt - dt).total_seconds() < 0.001
|
207
|
+
|
208
|
+
|
209
|
+
def safe_key(key: str) -> str:
    """Mongo document keys may not contain '.'; encode dots reversibly."""
    return key.replace(".", "_DOT_")


def replace_dot_keys(values: dict) -> dict:
    """avoid InvalidDocument("key 'dev.amironenko' must not contain '.'")"""
    return {safe_key(raw_key): value for raw_key, value in values.items()}
|
216
|
+
|
217
|
+
|
218
|
+
def decode_delete_count(result: DeleteResult) -> int:
    """Unwrap the number of deleted documents from a delete result."""
    deleted: int = result.deleted_count
    return deleted
|
220
|
+
|
221
|
+
|
222
|
+
def push_and_limit_length_update(field_name: str, new_value: Any, max_size: int) -> dict:
    """`$push` update that appends `new_value` and keeps only the last `max_size` entries."""
    push_spec = {
        MongoUpdateOperation.each: [new_value],
        MongoUpdateOperation.slice: -max_size,  # negative slice keeps the tail of the array
    }
    return {MongoUpdateOperation.push: {field_name: push_spec}}
|
231
|
+
|
232
|
+
|
233
|
+
def index_dec(column: str) -> IndexModel:
    """Single-column descending index definition for `column`."""
    return IndexModel([(column, DESCENDING)])
|
235
|
+
|
236
|
+
|
237
|
+
def query_and_sort(collection: AgnosticCollection, query: dict, sort_col: str, desc: bool) -> AsyncIterable[dict]:
    """Find documents matching `query`, sorted on `sort_col` (descending when `desc`)."""
    direction = DESCENDING if desc else ASCENDING
    cursor = collection.find(query)
    return cursor.sort(sort_col, direction)
|
atlas_init/repos/go_sdk.py
CHANGED
@@ -2,11 +2,11 @@ from collections import defaultdict
|
|
2
2
|
from pathlib import Path
|
3
3
|
|
4
4
|
import requests
|
5
|
-
from model_lib import parse_model
|
5
|
+
from model_lib import Entity, parse_model
|
6
6
|
|
7
|
-
from atlas_init.cli_tf.debug_logs_test_data import ApiSpecPath
|
7
|
+
from atlas_init.cli_tf.debug_logs_test_data import ApiSpecPath, find_normalized_path
|
8
8
|
from atlas_init.cli_tf.schema import logger
|
9
|
-
from atlas_init.cli_tf.
|
9
|
+
from atlas_init.cli_tf.openapi import OpenapiSchema
|
10
10
|
|
11
11
|
|
12
12
|
def go_sdk_breaking_changes(repo_path: Path, go_sdk_rel_path: str = "../atlas-sdk-go") -> Path:
|
@@ -21,6 +21,15 @@ def api_spec_path_transformed(sdk_repo_path: Path) -> Path:
|
|
21
21
|
return sdk_repo_path / "openapi/atlas-api-transformed.yaml"
|
22
22
|
|
23
23
|
|
24
|
+
class ApiSpecPaths(Entity):
    """Lookup of API spec paths, used to normalize concrete request paths to spec templates."""

    # spec paths grouped by HTTP method; key casing must match what callers pass in —
    # NOTE(review): not enforced here, confirm against callers
    method_paths: dict[str, list[ApiSpecPath]]

    def normalize_path(self, method: str, path: str) -> str:
        """Spec template path matching `path`, or "" for /api/atlas/v1.0 paths."""
        if path.startswith("/api/atlas/v1.0"):
            return ""  # NOTE(review): v1.0 paths are deliberately skipped here — confirm why
        return find_normalized_path(path, self.method_paths[method]).path
|
32
|
+
|
24
33
|
def parse_api_spec_paths(api_spec_path: Path) -> dict[str, list[ApiSpecPath]]:
|
25
34
|
model = parse_model(api_spec_path, t=OpenapiSchema)
|
26
35
|
paths: dict[str, list[ApiSpecPath]] = defaultdict(list)
|