atlas-init 0.4.5__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_init/__init__.py +1 -1
- atlas_init/cli.py +2 -0
- atlas_init/cli_args.py +19 -1
- atlas_init/cli_cfn/cfn_parameter_finder.py +59 -51
- atlas_init/cli_cfn/example.py +8 -16
- atlas_init/cli_helper/go.py +6 -10
- atlas_init/cli_root/mms_released.py +46 -0
- atlas_init/cli_tf/app.py +3 -84
- atlas_init/cli_tf/ci_tests.py +585 -0
- atlas_init/cli_tf/codegen/__init__.py +0 -0
- atlas_init/cli_tf/codegen/models.py +97 -0
- atlas_init/cli_tf/codegen/openapi_minimal.py +74 -0
- atlas_init/cli_tf/github_logs.py +7 -94
- atlas_init/cli_tf/go_test_run.py +395 -130
- atlas_init/cli_tf/go_test_summary.py +589 -10
- atlas_init/cli_tf/go_test_tf_error.py +388 -0
- atlas_init/cli_tf/hcl/modifier.py +14 -12
- atlas_init/cli_tf/hcl/modifier2.py +207 -0
- atlas_init/cli_tf/mock_tf_log.py +1 -1
- atlas_init/cli_tf/{schema_v2_api_parsing.py → openapi.py} +101 -19
- atlas_init/cli_tf/schema_v2.py +43 -1
- atlas_init/crud/__init__.py +0 -0
- atlas_init/crud/mongo_client.py +115 -0
- atlas_init/crud/mongo_dao.py +296 -0
- atlas_init/crud/mongo_utils.py +239 -0
- atlas_init/html_out/__init__.py +0 -0
- atlas_init/html_out/md_export.py +143 -0
- atlas_init/repos/go_sdk.py +12 -3
- atlas_init/repos/path.py +110 -7
- atlas_init/sdk_ext/__init__.py +0 -0
- atlas_init/sdk_ext/go.py +102 -0
- atlas_init/sdk_ext/typer_app.py +18 -0
- atlas_init/settings/config.py +3 -6
- atlas_init/settings/env_vars.py +18 -2
- atlas_init/settings/env_vars_generated.py +2 -0
- atlas_init/settings/interactive2.py +134 -0
- atlas_init/tf/.terraform.lock.hcl +59 -59
- atlas_init/tf/always.tf +5 -5
- atlas_init/tf/main.tf +3 -3
- atlas_init/tf/modules/aws_kms/aws_kms.tf +1 -1
- atlas_init/tf/modules/aws_s3/provider.tf +2 -1
- atlas_init/tf/modules/aws_vpc/provider.tf +2 -1
- atlas_init/tf/modules/cfn/cfn.tf +0 -8
- atlas_init/tf/modules/cfn/kms.tf +5 -5
- atlas_init/tf/modules/cfn/provider.tf +7 -0
- atlas_init/tf/modules/cfn/variables.tf +1 -1
- atlas_init/tf/modules/cloud_provider/cloud_provider.tf +1 -1
- atlas_init/tf/modules/cloud_provider/provider.tf +2 -1
- atlas_init/tf/modules/cluster/cluster.tf +31 -31
- atlas_init/tf/modules/cluster/provider.tf +2 -1
- atlas_init/tf/modules/encryption_at_rest/provider.tf +2 -1
- atlas_init/tf/modules/federated_vars/federated_vars.tf +2 -3
- atlas_init/tf/modules/federated_vars/provider.tf +2 -1
- atlas_init/tf/modules/project_extra/project_extra.tf +1 -10
- atlas_init/tf/modules/project_extra/provider.tf +8 -0
- atlas_init/tf/modules/stream_instance/provider.tf +8 -0
- atlas_init/tf/modules/stream_instance/stream_instance.tf +0 -9
- atlas_init/tf/modules/vpc_peering/provider.tf +10 -0
- atlas_init/tf/modules/vpc_peering/vpc_peering.tf +0 -10
- atlas_init/tf/modules/vpc_privatelink/versions.tf +2 -1
- atlas_init/tf/outputs.tf +1 -0
- atlas_init/tf/providers.tf +1 -1
- atlas_init/tf/variables.tf +7 -7
- atlas_init/tf_ext/__init__.py +0 -0
- atlas_init/tf_ext/__main__.py +3 -0
- atlas_init/tf_ext/api_call.py +325 -0
- atlas_init/tf_ext/args.py +17 -0
- atlas_init/tf_ext/constants.py +3 -0
- atlas_init/tf_ext/models.py +106 -0
- atlas_init/tf_ext/paths.py +126 -0
- atlas_init/tf_ext/settings.py +39 -0
- atlas_init/tf_ext/tf_dep.py +324 -0
- atlas_init/tf_ext/tf_modules.py +394 -0
- atlas_init/tf_ext/tf_vars.py +173 -0
- atlas_init/tf_ext/typer_app.py +24 -0
- atlas_init/typer_app.py +4 -8
- {atlas_init-0.4.5.dist-info → atlas_init-0.7.0.dist-info}/METADATA +8 -4
- atlas_init-0.7.0.dist-info/RECORD +138 -0
- atlas_init-0.7.0.dist-info/entry_points.txt +5 -0
- atlas_init-0.4.5.dist-info/RECORD +0 -105
- atlas_init-0.4.5.dist-info/entry_points.txt +0 -2
- {atlas_init-0.4.5.dist-info → atlas_init-0.7.0.dist-info}/WHEEL +0 -0
- {atlas_init-0.4.5.dist-info → atlas_init-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import datetime
|
3
4
|
import logging
|
4
5
|
import re
|
5
6
|
from collections.abc import Iterable
|
6
7
|
from pathlib import Path
|
7
8
|
from queue import Queue
|
8
|
-
from typing import ClassVar
|
9
|
+
from typing import ClassVar, NamedTuple
|
9
10
|
|
10
11
|
from model_lib import Entity, dump
|
11
12
|
from pydantic import Field
|
@@ -48,6 +49,21 @@ def parse_openapi_schema_after_modifications(schema: SchemaV2, api_spec_path: Pa
|
|
48
49
|
return api_spec_text_changes(schema, original)
|
49
50
|
|
50
51
|
|
52
|
+
class PathMethodCode(NamedTuple):
    """Key identifying a single response of an OpenAPI operation.

    Used as the key yielded by ``OpenapiSchema.path_method_api_versions``.
    """

    path: str  # path template, i.e. a key of the spec's ``paths`` map
    method: str  # operation key under the path item (e.g. "get", "post")
    code: str  # response key under ``responses`` (e.g. "200")
|
56
|
+
|
57
|
+
|
58
|
+
def extract_api_version_content_header(header: str) -> str | None:
|
59
|
+
"""
|
60
|
+
Extracts the API version from the content header.
|
61
|
+
The header should be in the format 'application/vnd.atlas.v1+json'.
|
62
|
+
"""
|
63
|
+
match = re.match(r"application/vnd\.atlas\.v?(?P<version>[\d-]+)\+json", header)
|
64
|
+
return match.group("version") if match else None
|
65
|
+
|
66
|
+
|
51
67
|
class OpenapiSchema(Entity):
|
52
68
|
PARAMETERS_PREFIX: ClassVar[str] = "#/components/parameters/"
|
53
69
|
SCHEMAS_PREFIX: ClassVar[str] = "#/components/schemas/"
|
@@ -61,16 +77,38 @@ class OpenapiSchema(Entity):
|
|
61
77
|
def create_method(self, path: str) -> dict | None:
|
62
78
|
return self.paths.get(path, {}).get("post")
|
63
79
|
|
64
|
-
def
|
80
|
+
def get_method(self, path: str) -> dict | None:
|
65
81
|
return self.paths.get(path, {}).get("get")
|
66
82
|
|
83
|
+
def delete_method(self, path: str) -> dict | None:
|
84
|
+
return self.paths.get(path, {}).get("delete")
|
85
|
+
|
86
|
+
def patch_method(self, path: str) -> dict | None:
|
87
|
+
return self.paths.get(path, {}).get("patch")
|
88
|
+
|
89
|
+
def put_method(self, path: str) -> dict | None:
|
90
|
+
return self.paths.get(path, {}).get("patch")
|
91
|
+
|
92
|
+
def methods_with_name(self, path: str) -> Iterable[tuple[str, dict]]:
    """Yield ``(http_method, operation)`` pairs defined for *path*.

    Methods are visited in the fixed order post, get, delete, patch, put;
    missing or empty (falsy) operations are skipped, as are any other keys of
    the path item.
    """
    for name in ("post", "get", "delete", "patch", "put"):
        operation = self.paths.get(path, {}).get(name)
        if not operation:
            continue
        yield name, operation
|
96
|
+
|
97
|
+
def methods(self, path: str) -> Iterable[dict]:
    """Yield only the operation dicts of *path*, in ``methods_with_name`` order."""
    for _name, operation in self.methods_with_name(path):
        yield operation
|
99
|
+
|
67
100
|
def method_refs(self, path: str) -> Iterable[str]:
    """Yield the schema ``$ref``s used by the operations of *path*.

    For each operation, refs from its request body come first, followed by
    refs from its ``200`` response (see ``method_request_body_ref`` and
    ``method_response_ref``). Falsy operations are skipped.
    """
    for operation in self.methods(path):
        if not operation:
            continue
        yield from self.method_request_body_ref(operation)
        yield from self.method_response_ref(operation)
|
105
|
+
|
106
|
+
def parameter_refs(self, path: str) -> Iterable[str]:
    """Yield the ``$ref`` of every referenced parameter across the operations of *path*.

    Inline parameters (those without a truthy ``$ref``) are skipped.
    """
    for operation in self.methods(path):
        for param in operation.get("parameters", []):
            ref = param.get("$ref")
            if not ref:
                continue
            yield ref
|
74
112
|
|
75
113
|
def parameter(self, ref: str) -> dict:
|
76
114
|
assert ref.startswith(OpenapiSchema.PARAMETERS_PREFIX)
|
@@ -91,23 +129,50 @@ class OpenapiSchema(Entity):
|
|
91
129
|
prop["name"] = name
|
92
130
|
yield prop
|
93
131
|
|
94
|
-
def method_request_body_ref(self, method: dict) -> Iterable[str]:
    """Yield schema ``$ref``s from the JSON content of *method*'s ``requestBody`` (empty if none)."""
    yield from self._unpack_schema_ref(method.get("requestBody", {}))
|
97
135
|
|
98
|
-
def method_response_ref(self, method: dict) -> Iterable[str]:
    """Yield schema ``$ref``s from the JSON content of *method*'s ``200`` response.

    Other status codes (e.g. ``201``) are not inspected.
    """
    all_responses = method.get("responses", {})
    yield from self._unpack_schema_ref(all_responses.get("200", {}))
|
102
140
|
|
103
|
-
def _unpack_schema_ref(self, response: dict) -> str
|
141
|
+
def _unpack_schema_ref(self, response: dict) -> Iterable[str]:
|
104
142
|
content = {**response.get("content", {})} # avoid side effects
|
105
143
|
if not content:
|
106
144
|
return None
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
145
|
+
while content:
|
146
|
+
key, value = content.popitem()
|
147
|
+
if not isinstance(key, str) or not key.endswith("json"):
|
148
|
+
continue
|
149
|
+
if ref := value.get("schema", {}).get("$ref"):
|
150
|
+
yield ref
|
151
|
+
|
152
|
+
def _unpack_schema_versions(self, response: dict) -> list[datetime.date]:
|
153
|
+
content: dict[str, dict] = {**response.get("content", {})}
|
154
|
+
versions = []
|
155
|
+
while content:
|
156
|
+
key, value = content.popitem()
|
157
|
+
if not isinstance(value, dict) or not key.endswith("json"):
|
158
|
+
continue
|
159
|
+
if version := value.get("x-xgen-version"):
|
160
|
+
versions.append(version)
|
161
|
+
continue
|
162
|
+
if version := extract_api_version_content_header(key):
|
163
|
+
versions.append(version)
|
164
|
+
return versions
|
165
|
+
|
166
|
+
def path_method_api_versions(self) -> Iterable[tuple[PathMethodCode, list[str | datetime.date]]]:
    """Yield ``(PathMethodCode, versions)`` for every response that advertises API versions.

    Walks ``self.paths`` -> operations -> ``responses`` and applies
    ``_unpack_schema_versions`` to each response; responses without any
    versioned JSON content are skipped. Non-dict path items, non-dict
    operation entries (e.g. a path-level ``parameters`` list) and non-dict
    responses are ignored instead of raising. See ``_unpack_schema_versions``
    for why the versions may be ``str`` or ``datetime.date``.
    """
    for path, path_item in self.paths.items():
        if not isinstance(path_item, dict):
            continue
        for method_name, method_dict in path_item.items():
            if not isinstance(method_dict, dict):
                continue
            # ``or {}`` also covers an explicit ``responses: null``
            responses = method_dict.get("responses") or {}
            for code, response_dict in responses.items():
                if not isinstance(response_dict, dict):
                    continue
                if api_versions := self._unpack_schema_versions(response_dict):
                    yield PathMethodCode(path, method_name, code), api_versions
|
111
176
|
|
112
177
|
def schema_ref_component(self, ref: str, attributes_skip: set[str]) -> SchemaResource:
|
113
178
|
schemas = self.components.get("schemas", {})
|
@@ -128,6 +193,9 @@ class OpenapiSchema(Entity):
|
|
128
193
|
name=ref,
|
129
194
|
description=schema.get("description", ""),
|
130
195
|
attributes_skip=attributes_skip,
|
196
|
+
discriminator=schema.get("discriminator"),
|
197
|
+
one_of=schema.get("oneOf", []),
|
198
|
+
all_of=schema.get("allOf", []),
|
131
199
|
)
|
132
200
|
required_names = schema.get("required", [])
|
133
201
|
for prop in self.schema_properties(ref):
|
@@ -143,6 +211,16 @@ class OpenapiSchema(Entity):
|
|
143
211
|
elif ref.startswith(self.SCHEMAS_PREFIX):
|
144
212
|
prefix = self.SCHEMAS_PREFIX
|
145
213
|
parent_dict = self.components["schemas"]
|
214
|
+
ref_value.pop("name", None)
|
215
|
+
if properties := ref_value.get("properties"):
|
216
|
+
properties_no_name = {
|
217
|
+
k: {nested_k: nested_v for nested_k, nested_v in v.items() if nested_k != "name"}
|
218
|
+
for k, v in properties.items()
|
219
|
+
}
|
220
|
+
if ref.removeprefix(prefix).endswith("DBRoleToExecute"):
|
221
|
+
logger.warning(f"debug me: {properties_no_name}")
|
222
|
+
ref_value["properties"] = properties_no_name
|
223
|
+
|
146
224
|
else:
|
147
225
|
err_msg = f"Unknown schema_ref {ref}"
|
148
226
|
raise ValueError(err_msg)
|
@@ -168,12 +246,14 @@ def parse_api_spec_param(api_spec: OpenapiSchema, param: dict, resource: SchemaR
|
|
168
246
|
case {"$ref": ref, "name": name} if ref.startswith(OpenapiSchema.SCHEMAS_PREFIX):
|
169
247
|
# nested attribute
|
170
248
|
attribute = SchemaAttribute(
|
249
|
+
additional_properties=param.get("additionalProperties", {}),
|
171
250
|
type="object",
|
172
251
|
name=name,
|
173
252
|
schema_ref=ref,
|
174
253
|
)
|
175
254
|
case {"type": "array", "items": {"$ref": ref}, "name": name}:
|
176
255
|
attribute = SchemaAttribute(
|
256
|
+
additional_properties=param.get("additionalProperties", {}),
|
177
257
|
type="array",
|
178
258
|
name=name,
|
179
259
|
schema_ref=ref,
|
@@ -183,6 +263,7 @@ def parse_api_spec_param(api_spec: OpenapiSchema, param: dict, resource: SchemaR
|
|
183
263
|
)
|
184
264
|
case {"name": name, "schema": schema}:
|
185
265
|
attribute = SchemaAttribute(
|
266
|
+
additional_properties=param.get("additionalProperties", {}),
|
186
267
|
type=schema["type"],
|
187
268
|
name=name,
|
188
269
|
description=param.get("description", ""),
|
@@ -196,6 +277,7 @@ def parse_api_spec_param(api_spec: OpenapiSchema, param: dict, resource: SchemaR
|
|
196
277
|
description=param.get("description", ""),
|
197
278
|
is_computed=param.get("readOnly", False),
|
198
279
|
is_required=param.get("required", False),
|
280
|
+
additional_properties=param.get("additionalProperties", {}),
|
199
281
|
)
|
200
282
|
case _:
|
201
283
|
raise NotImplementedError
|
@@ -220,16 +302,16 @@ def add_api_spec_info(schema: SchemaV2, api_spec_path: Path, *, minimal_refs: bo
|
|
220
302
|
continue
|
221
303
|
for param in create_method.get("parameters", []):
|
222
304
|
parse_api_spec_param(api_spec, param, resource)
|
223
|
-
|
305
|
+
for req_ref in api_spec.method_request_body_ref(create_method):
|
224
306
|
for property_dict in api_spec.schema_properties(req_ref):
|
225
307
|
parse_api_spec_param(api_spec, property_dict, resource)
|
226
308
|
for path in resource.paths:
|
227
|
-
read_method = api_spec.
|
309
|
+
read_method = api_spec.get_method(path)
|
228
310
|
if not read_method:
|
229
311
|
continue
|
230
312
|
for param in read_method.get("parameters", []):
|
231
313
|
parse_api_spec_param(api_spec, param, resource)
|
232
|
-
|
314
|
+
for response_ref in api_spec.method_response_ref(read_method):
|
233
315
|
for property_dict in api_spec.schema_properties(response_ref):
|
234
316
|
parse_api_spec_param(api_spec, property_dict, resource)
|
235
317
|
if minimal_refs:
|
atlas_init/cli_tf/schema_v2.py
CHANGED
@@ -8,11 +8,12 @@ from fnmatch import fnmatch
|
|
8
8
|
from pathlib import Path
|
9
9
|
from queue import Queue
|
10
10
|
from tempfile import TemporaryDirectory
|
11
|
-
from typing import Literal, TypeAlias
|
11
|
+
from typing import Any, Literal, TypeAlias
|
12
12
|
|
13
13
|
from model_lib import Entity, copy_and_validate, parse_model
|
14
14
|
from pydantic import ConfigDict, Field, model_validator
|
15
15
|
from zero_3rdparty.enum_utils import StrEnum
|
16
|
+
from zero_3rdparty.iter_utils import flat_map
|
16
17
|
|
17
18
|
from atlas_init.cli_helper.run import run_binary_command_is_ok
|
18
19
|
from atlas_init.humps import decamelize, pascalize
|
@@ -50,6 +51,13 @@ class SchemaAttribute(Entity):
|
|
50
51
|
validators: list[SchemaAttributeValidator] = Field(default_factory=list)
|
51
52
|
# not used during dumping but backtrace which parameters are used in the api spec
|
52
53
|
parameter_ref: str = ""
|
54
|
+
additional_properties: dict[str, Any] = Field(default_factory=dict)
|
55
|
+
|
56
|
+
@property
def additional_properties_ref(self) -> str:
    """``$ref`` of the ``additionalProperties`` schema, or ``""`` when there is none."""
    return (self.additional_properties or {}).get("$ref", "")
|
53
61
|
|
54
62
|
@property
|
55
63
|
def tf_name(self) -> str:
|
@@ -88,6 +96,7 @@ class SchemaAttribute(Entity):
|
|
88
96
|
plan_modifiers=self.plan_modifiers + other.plan_modifiers,
|
89
97
|
validators=self.validators + other.validators,
|
90
98
|
parameter_ref=self.parameter_ref or other.parameter_ref,
|
99
|
+
additional_properties=self.additional_properties | other.additional_properties,
|
91
100
|
)
|
92
101
|
|
93
102
|
def set_attribute_type(
|
@@ -175,6 +184,29 @@ class SDKConversion(Entity):
|
|
175
184
|
return bool(self.sdk_start_refs)
|
176
185
|
|
177
186
|
|
187
|
+
class Discriminator(Entity):
    """OpenAPI ``discriminator`` object of a polymorphic (``oneOf``) schema."""

    # per the OpenAPI spec: discriminator value -> schema name or ``$ref``
    mapping: dict[str, str] = Field(default_factory=dict)
    # spec key is ``propertyName``; required (no default)
    property_name: str = Field(alias="propertyName")


class OneOf(Entity):
    """One entry of a schema's ``oneOf`` list; only the ``$ref`` is kept."""

    ref: str = Field(alias="$ref", default="")  # "" for inline (non-ref) entries


class AllOf(Entity):
    """One entry of a schema's ``allOf`` list: either a ``$ref`` or inline ``properties``."""

    ref: str = Field(alias="$ref", default="")  # "" for inline entries
    properties: dict[str, Any] = Field(default_factory=dict)

    @property
    def nested_refs(self) -> set[str]:
        """``$ref`` targets of the inline ``properties``; non-dict values and empty refs are ignored."""
        return {
            prop_schema["$ref"]
            for prop_schema in self.properties.values()
            if isinstance(prop_schema, dict) and prop_schema.get("$ref")
        }
|
208
|
+
|
209
|
+
|
178
210
|
class SchemaResource(Entity):
|
179
211
|
name: str = "" # populated by the key of the resources dict
|
180
212
|
description: str = ""
|
@@ -183,6 +215,16 @@ class SchemaResource(Entity):
|
|
183
215
|
paths: list[str] = Field(default_factory=list)
|
184
216
|
attribute_type_modifiers: AttributeTypeModifiers = Field(default_factory=AttributeTypeModifiers)
|
185
217
|
conversion: SDKConversion = Field(default_factory=SDKConversion)
|
218
|
+
discriminator: Discriminator | None = None
|
219
|
+
one_of: list[OneOf] = Field(default_factory=list)
|
220
|
+
all_of: list[AllOf] = Field(default_factory=list)
|
221
|
+
|
222
|
+
def extra_refs(self) -> set[str]:
    """Return every non-empty ``$ref`` reachable from this resource's ``oneOf``/``allOf``.

    Includes the direct refs of ``one_of`` and ``all_of`` entries plus the
    refs found inside inline ``allOf`` properties (``AllOf.nested_refs``).
    """
    one_of_refs = {entry.ref for entry in self.one_of if entry.ref}
    all_of_refs = {entry.ref for entry in self.all_of if entry.ref}
    nested_all_of_refs = {ref for ref in flat_map(entry.nested_refs for entry in self.all_of) if ref}
    return one_of_refs | all_of_refs | nested_all_of_refs
|
186
228
|
|
187
229
|
@model_validator(mode="after")
|
188
230
|
def set_attribute_names(self):
|
File without changes
|
@@ -0,0 +1,115 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
import logging
|
5
|
+
from typing import TypeAlias
|
6
|
+
|
7
|
+
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase, AsyncIOMotorCollection
|
8
|
+
from pymongo import IndexModel
|
9
|
+
from pymongo.errors import DuplicateKeyError
|
10
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
11
|
+
|
12
|
+
from atlas_init.cli_tf.go_test_run import GoTestRun
|
13
|
+
from atlas_init.cli_tf.go_test_tf_error import GoTestErrorClassification
|
14
|
+
from atlas_init.crud.mongo_utils import index_dec
|
15
|
+
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)


@dataclass
class CollectionConfig:
    """Collection settings for one document model, as consumed by ``init_mongo``."""

    # empty -> ``init_mongo`` falls back to the model's class name
    name: str = ""
    # passed to ``create_indexes`` by ``ensure_collection_exist``
    indexes: list[IndexModel] = field(default_factory=list)
|
24
|
+
|
25
|
+
|
26
|
+
# document model class -> collection settings used for it
CollectionConfigsT: TypeAlias = dict[type, CollectionConfig]


def default_document_models() -> CollectionConfigsT:
    """Return the collection configuration used by ``init_mongo`` when none is given.

    Both models get an ``index_dec("ts")`` index (presumably descending on
    ``ts`` -- see ``mongo_utils.index_dec``) plus single-field indexes:
    ``error_class``/``test_name`` for ``GoTestErrorClassification`` and
    ``branch``/``status`` for ``GoTestRun``.
    """
    classification_indexes = [index_dec("ts"), IndexModel(["error_class"]), IndexModel(["test_name"])]
    test_run_indexes = [index_dec("ts"), IndexModel(["branch"]), IndexModel(["status"])]
    return {
        GoTestErrorClassification: CollectionConfig(indexes=classification_indexes),
        GoTestRun: CollectionConfig(indexes=test_run_indexes),
    }
|
36
|
+
|
37
|
+
|
38
|
+
# Registry filled by ``init_mongo``: document model class -> its collection.
_collections: dict[type, AsyncIOMotorCollection] = {}


def get_collection(model: type) -> AsyncIOMotorCollection:
    """Return the collection registered for *model*.

    Raises:
        ValueError: if ``init_mongo`` has not registered a collection for *model*.
    """
    collection = _collections.get(model)
    if collection is None:
        # Explicit ``is None`` rather than truthiness: pymongo collection objects reject
        # truth-value testing (assumed to hold for the motor wrapper too).
        raise ValueError(f"Collection for model {model.__name__} is not initialized. Call init_mongo first.")
    return collection
|
46
|
+
|
47
|
+
|
48
|
+
def get_db(mongo_url: str, db_name: str) -> AsyncIOMotorDatabase:
    """Open a Motor client for *mongo_url* and return its *db_name* database.

    NOTE(review): a new ``AsyncIOMotorClient`` is created on every call and is
    never closed here; repeated callers may want to share one client.
    """
    return AsyncIOMotorClient(mongo_url).get_database(db_name)
|
51
|
+
|
52
|
+
|
53
|
+
async def init_mongo(
    mongo_url: str, db_name: str, clean_collections: bool = False, document_models: CollectionConfigsT | None = None
) -> None:
    """Connect to *db_name* and register one collection per document model.

    *document_models* falls back to ``default_document_models()`` when it is
    ``None`` or empty. Each model's collection is named ``cfg.name`` or the
    model's class name, is created/indexed via ``ensure_collection_exist``
    (dropped first when *clean_collections* is set), and is stored so that
    ``get_collection(model)`` can return it.
    """
    db = get_db(mongo_url, db_name)
    configs = document_models or default_document_models()
    for model, cfg in configs.items():
        collection_name = cfg.name or model.__name__
        _collections[model] = await ensure_collection_exist(db, collection_name, cfg.indexes, clean_collections)
    if clean_collections:
        logger.info(f"MongoDB collections in '{db_name}' have been cleaned.")
|
65
|
+
|
66
|
+
|
67
|
+
async def ensure_collection_exist(
    db: AsyncIOMotorDatabase,
    name: str,
    indexes: list[IndexModel] | None = None,
    clean_collection: bool = False,
) -> AsyncIOMotorCollection:
    """Make sure collection *name* exists in *db* and return it.

    With *clean_collection*, an existing collection is dropped and created
    again. When *indexes* is non-empty they are passed to ``create_indexes``
    on every call.
    """
    existing_names = await db.list_collection_names()
    already_there = name in existing_names
    must_drop = clean_collection and already_there
    if must_drop:
        await db.drop_collection(name)
    if must_drop or not already_there:
        await db.create_collection(name)
    if indexes:
        # always (re-)create indexes after new creation or drop
        await db[name].create_indexes(indexes)
    logger.debug(f"mongo collection {name!r} is ready")
    return db[name]
|
87
|
+
|
88
|
+
|
89
|
+
def duplicate_key_pattern(error: DuplicateKeyError) -> str | None:
    """Return the (last) field name of the key pattern that caused *error*.

    Reads ``error.details["keyPattern"]`` (e.g. ``{"name": 1}`` -> ``"name"``).
    Returns ``None`` when the error carries no details or no key pattern.
    Unlike the previous ``popitem()`` implementation, the exception's
    ``details`` dict is left untouched, so later readers still see the full
    key pattern.
    """
    # ``details`` may be None when the server sent no error document
    details: dict = error.details or {}  # type: ignore
    key_pattern = details.get("keyPattern") or {}
    if not key_pattern:
        return None
    # last key, matching the old ``popitem()`` choice for compound indexes
    return next(reversed(key_pattern))
|
96
|
+
|
97
|
+
|
98
|
+
class CollectionNotEmptyError(Exception):
    """Raised when a registered MongoDB collection still contains documents."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        super().__init__(f"Collection '{collection_name}' is not empty.")
|
102
|
+
|
103
|
+
|
104
|
+
@retry(
    stop=stop_after_attempt(10),
    wait=wait_fixed(0.5),
    retry=retry_if_exception_type(CollectionNotEmptyError),
    reraise=True,
)
async def _empty_collections() -> None:
    """Check that every registered collection is empty, retrying while one is not.

    Raises ``CollectionNotEmptyError`` for the first collection with documents.
    The tenacity decorator retries that up to 10 attempts, 0.5s apart, and
    then re-raises the last ``CollectionNotEmptyError``.

    NOTE(review): despite the name, nothing is deleted here -- the function
    only waits for the collections to become empty. Confirm whether a
    ``delete_many({})`` was intended before the count.
    """
    for collection in _collections.values():
        document_count = await collection.count_documents({})
        if document_count > 0:
            raise CollectionNotEmptyError(collection.name)
|