magic-pocket-cli 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. magic_pocket_cli-0.2.0.dist-info/METADATA +14 -0
  2. magic_pocket_cli-0.2.0.dist-info/RECORD +65 -0
  3. magic_pocket_cli-0.2.0.dist-info/WHEEL +4 -0
  4. magic_pocket_cli-0.2.0.dist-info/entry_points.txt +2 -0
  5. pocket_cli/__init__.py +0 -0
  6. pocket_cli/cli/__init__.py +0 -0
  7. pocket_cli/cli/aws_auth.py +48 -0
  8. pocket_cli/cli/awscontainer_cli.py +328 -0
  9. pocket_cli/cli/cloudfront_cli.py +116 -0
  10. pocket_cli/cli/cloudfront_keys_cli.py +68 -0
  11. pocket_cli/cli/cloudfront_waf_cli.py +68 -0
  12. pocket_cli/cli/deploy_cli.py +274 -0
  13. pocket_cli/cli/destroy_cli.py +358 -0
  14. pocket_cli/cli/dsql_cli.py +60 -0
  15. pocket_cli/cli/main_cli.py +91 -0
  16. pocket_cli/cli/migrate_cli.py +148 -0
  17. pocket_cli/cli/neon_cli.py +97 -0
  18. pocket_cli/cli/permissions_cli.py +46 -0
  19. pocket_cli/cli/rds_cli.py +63 -0
  20. pocket_cli/cli/runtime_config_cli.py +185 -0
  21. pocket_cli/cli/s3_cli.py +69 -0
  22. pocket_cli/cli/status_cli.py +56 -0
  23. pocket_cli/cli/tidb_cli.py +73 -0
  24. pocket_cli/cli/vpc_cli.py +92 -0
  25. pocket_cli/cli/waf_cli.py +182 -0
  26. pocket_cli/django_cli.py +412 -0
  27. pocket_cli/mediator.py +220 -0
  28. pocket_cli/resources/__init__.py +0 -0
  29. pocket_cli/resources/aws/__init__.py +0 -0
  30. pocket_cli/resources/aws/builders/__init__.py +57 -0
  31. pocket_cli/resources/aws/builders/codebuild.py +363 -0
  32. pocket_cli/resources/aws/builders/depot.py +84 -0
  33. pocket_cli/resources/aws/builders/docker.py +34 -0
  34. pocket_cli/resources/aws/builders/dockerignore.py +44 -0
  35. pocket_cli/resources/aws/cloudformation.py +790 -0
  36. pocket_cli/resources/aws/ecr.py +145 -0
  37. pocket_cli/resources/aws/efs.py +138 -0
  38. pocket_cli/resources/aws/lambdahandler.py +182 -0
  39. pocket_cli/resources/aws/s3_utils.py +58 -0
  40. pocket_cli/resources/aws/state.py +74 -0
  41. pocket_cli/resources/awscontainer.py +265 -0
  42. pocket_cli/resources/cloudfront.py +491 -0
  43. pocket_cli/resources/cloudfront_acm.py +55 -0
  44. pocket_cli/resources/cloudfront_keys.py +81 -0
  45. pocket_cli/resources/cloudfront_waf.py +67 -0
  46. pocket_cli/resources/dsql.py +142 -0
  47. pocket_cli/resources/neon.py +353 -0
  48. pocket_cli/resources/rds.py +680 -0
  49. pocket_cli/resources/s3.py +307 -0
  50. pocket_cli/resources/tidb.py +298 -0
  51. pocket_cli/resources/upstash.py +152 -0
  52. pocket_cli/resources/vpc.py +67 -0
  53. pocket_cli/templates/cloudformation/awscontainer.yaml +516 -0
  54. pocket_cli/templates/cloudformation/cf_function_api_host.js +5 -0
  55. pocket_cli/templates/cloudformation/cf_function_spa_auth.js +28 -0
  56. pocket_cli/templates/cloudformation/cf_function_spa_fallback.js +8 -0
  57. pocket_cli/templates/cloudformation/cloudfront.yaml +309 -0
  58. pocket_cli/templates/cloudformation/cloudfront_acm.yaml +43 -0
  59. pocket_cli/templates/cloudformation/cloudfront_keys.yaml +32 -0
  60. pocket_cli/templates/cloudformation/cloudfront_waf.yaml +97 -0
  61. pocket_cli/templates/cloudformation/vpc.yaml +213 -0
  62. pocket_cli/templates/init/django-dotenv.env +3 -0
  63. pocket_cli/templates/init/django-settings.py +140 -0
  64. pocket_cli/templates/init/pocket.Dockerfile +26 -0
  65. pocket_cli/templates/init/pocket_simple.toml +31 -0
@@ -0,0 +1,307 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import cached_property
4
+ from typing import TYPE_CHECKING
5
+
6
+ import boto3
7
+ from botocore.exceptions import ClientError
8
+
9
+ from pocket.resources.base import ResourceStatus
10
+ from pocket.utils import echo
11
+ from pocket_cli.resources.aws.s3_utils import (
12
+ bucket_exists,
13
+ create_bucket,
14
+ delete_bucket_with_contents,
15
+ )
16
+
17
+ if TYPE_CHECKING:
18
+ from pocket.context import CloudFrontContext, S3Context
19
+
20
+
21
class S3:
    """Manage a single S3 bucket and reconcile its declared settings.

    Besides creating and deleting the bucket, the class compares four
    bucket-level settings against ``context`` and applies the difference:
    the public access block, CORS rules, versioning and lifecycle rules.
    Values read back from AWS are cached per instance and the cache entry
    is dropped after every write.
    """

    context: S3Context
    _cloudfront_contexts: dict[str, CloudFrontContext]

    def __init__(
        self,
        context: S3Context,
        cloudfront_contexts: dict[str, CloudFrontContext] | None = None,
    ) -> None:
        """Store the contexts and build a boto3 S3 client for ``context.region``.

        Args:
            context: Bucket settings (name, region, cors, versioning,
                lifecycle rules).
            cloudfront_contexts: CloudFront contexts keyed by name, used to
                resolve CORS origins. ``None`` is treated as an empty mapping.
        """
        self.context = context
        self._cloudfront_contexts = cloudfront_contexts or {}
        self.client = boto3.client("s3", region_name=context.region)

    @property
    def description(self):
        """Short human-readable label for the create action."""
        return "Create bucket: %s" % self.context.bucket_name

    def state_info(self):
        """Return the state fragment describing this bucket."""
        return {"s3": {"bucket_name": self.context.bucket_name}}

    def deploy_init(self):
        """Nothing to prepare before a deploy."""
        return None

    def _reconcile_settings(self):
        """Apply every declared bucket setting, in a fixed order."""
        self.ensure_public_access_block()
        self._ensure_cors()
        self._ensure_versioning()
        self._ensure_lifecycle()

    def create(self):
        """Create the bucket, then apply all declared settings."""
        create_bucket(self.client, self.context.bucket_name, self.context.region)
        self._reconcile_settings()

    def ensure_exists(self):
        """Create the bucket when missing; otherwise only reconcile settings."""
        if not self.exists():
            self.create()
            return
        self._reconcile_settings()

    def delete(self):
        """Delete the bucket together with its contents."""
        delete_bucket_with_contents(self.client, self.context.bucket_name)

    def update(self):
        """Reconcile all declared settings on the existing bucket."""
        self._reconcile_settings()

    def exists(self):
        """Return what ``bucket_exists`` reports for this bucket.

        Raises:
            Exception: when the lookup fails with a ``ClientError``; the
                original error is chained as the cause.
        """
        try:
            found = bucket_exists(self.client, self.context.bucket_name)
        except ClientError as err:
            raise Exception(
                "Bucket might be already used by other account."
                " Try another bucket_prefix."
            ) from err
        return found

    @property
    def status(self) -> ResourceStatus:
        """Report NOEXIST, REQUIRE_UPDATE or COMPLETED for the bucket."""
        if not self.exists():
            return "NOEXIST"
        # Evaluated lazily and in this order, so the first drift found
        # short-circuits the remaining checks.
        drift_checks = (
            "public_access_block_require_update",
            "cors_require_update",
            "versioning_require_update",
            "lifecycle_require_update",
        )
        if any(getattr(self, check) for check in drift_checks):
            return "REQUIRE_UPDATE"
        return "COMPLETED"

    @cached_property
    def public_access_block(self) -> dict | None:
        """Current public access block configuration, or None when unset."""
        try:
            response = self.client.get_public_access_block(
                Bucket=self.context.bucket_name,
            )
        except ClientError as err:
            code = err.response["Error"]["Code"]
            if code != "NoSuchPublicAccessBlockConfiguration":
                raise
            return None
        return response["PublicAccessBlockConfiguration"]

    @property
    def public_access_block_should_be(self):
        """Desired public access block: every restriction switched on."""
        return dict.fromkeys(
            (
                "BlockPublicAcls",
                "IgnorePublicAcls",
                "BlockPublicPolicy",
                "RestrictPublicBuckets",
            ),
            True,
        )

    @property
    def public_access_block_require_update(self):
        """True when the current configuration differs from the desired one."""
        desired = self.public_access_block_should_be
        return desired != self.public_access_block

    def ensure_public_access_block(self):
        """Put the desired public access block unless it is already in place."""
        if not self.public_access_block_require_update:
            echo.info("Public access block is already configured properly.")
            return
        desired = self.public_access_block_should_be
        echo.info("Update public access block configuration")
        echo.info("Current configuration: %s" % self.public_access_block)
        self.client.put_public_access_block(
            Bucket=self.context.bucket_name,
            PublicAccessBlockConfiguration=desired,
        )
        # Drop the cached value so the next read reflects the new state.
        del self.public_access_block
        echo.info("Updated to: %s" % desired)

    def _resolve_cors_origins(self) -> list[str]:
        """Resolve the CORS AllowedOrigins from the CloudFront domains."""
        cors = self.context.cors
        if not cors:
            return []
        resolved: list[str] = []
        for name in cors.cloudfront_names:
            cloudfront = self._cloudfront_contexts.get(name)
            if not cloudfront:
                # Unknown CloudFront name: warn and leave it out.
                echo.warning("CORS: cloudfront '%s' が見つかりません" % name)
            elif cloudfront.domain:
                resolved.append("https://%s" % cloudfront.domain)
            else:
                # No custom domain: fall back to the CloudFront wildcard.
                resolved.append("https://*.cloudfront.net")
        return resolved

    def _desired_cors_rules(self) -> list[dict] | None:
        """Return the expected CORS rules, or None when CORS is not configured."""
        cors = self.context.cors
        origins = self._resolve_cors_origins() if cors else []
        if not origins:
            return None
        rule = {
            "AllowedOrigins": origins,
            "AllowedMethods": cors.methods,
            "AllowedHeaders": ["*"],
            "MaxAgeSeconds": 3600,
        }
        return [rule]

    @cached_property
    def current_cors_rules(self) -> list[dict] | None:
        """Return the bucket's current CORS rules, or None when unset."""
        try:
            response = self.client.get_bucket_cors(Bucket=self.context.bucket_name)
        except ClientError as err:
            code = err.response["Error"]["Code"]
            if code != "NoSuchCORSConfiguration":
                raise
            return None
        return response.get("CORSRules")

    @property
    def cors_require_update(self) -> bool:
        """Declarative check: any mismatch between desired and current is drift."""
        wanted = self._desired_cors_rules()
        return wanted != self.current_cors_rules

    def _ensure_cors(self):
        """Apply the bucket's CORS configuration declaratively.

        - cors declared: replace the rules with PutBucketCors
        - cors not declared, rules present: DeleteBucketCors
        - cors not declared, no rules: do nothing
        """
        wanted = self._desired_cors_rules()
        if wanted == self.current_cors_rules:
            return
        bucket = self.context.bucket_name
        if wanted is None:
            self.client.delete_bucket_cors(Bucket=bucket)
            message = "CORS 設定を削除しました"
        else:
            self.client.put_bucket_cors(
                Bucket=bucket,
                CORSConfiguration={"CORSRules": wanted},
            )
            message = "CORS 設定を適用しました: %s" % wanted[0]["AllowedOrigins"]
        self.__dict__.pop("current_cors_rules", None)
        echo.info(message)

    @cached_property
    def current_versioning_status(self) -> str | None:
        """Current bucket versioning status (Enabled / Suspended / None)."""
        response = self.client.get_bucket_versioning(
            Bucket=self.context.bucket_name,
        )
        return response.get("Status")

    @property
    def versioning_require_update(self) -> bool:
        """Declarative drift check for versioning.

        - versioning=True: drift unless the current status is Enabled
        - versioning=False: drift only when the current status is Enabled
          (None and Suspended both count as "not enabled", so no-op)
        """
        enabled_now = self.current_versioning_status == "Enabled"
        return enabled_now != bool(self.context.versioning)

    def _ensure_versioning(self):
        """Apply the bucket's versioning setting declaratively.

        - versioning=True and status != Enabled: PutBucketVersioning(Enabled)
        - versioning=False and status == Enabled: PutBucketVersioning(Suspended)
        - otherwise: no-op (None and Suspended both count as "not enabled")

        Per the S3 specification a bucket that was once Enabled can only be
        moved to Suspended (version data is retained). To be stated in docs.
        """
        enabled_now = self.current_versioning_status == "Enabled"
        wanted = bool(self.context.versioning)
        if wanted == enabled_now:
            return
        if wanted:
            target, message = "Enabled", "Versioning を有効化しました"
        else:
            target, message = "Suspended", "Versioning を Suspended にしました"
        self.client.put_bucket_versioning(
            Bucket=self.context.bucket_name,
            VersioningConfiguration={"Status": target},
        )
        self.__dict__.pop("current_versioning_status", None)
        echo.info(message)

    def _desired_lifecycle_rules(self) -> list[dict] | None:
        """Return the expected lifecycle rules; None when none are declared."""
        declared = self.context.lifecycle_rules
        if not declared:
            return None
        return [
            {
                "ID": rule.id,
                "Status": "Enabled",
                "Filter": {"Prefix": rule.prefix},
                "NoncurrentVersionExpiration": {
                    "NoncurrentDays": rule.noncurrent_version_expiration_days,
                },
            }
            for rule in declared
        ]

    @cached_property
    def current_lifecycle_rules(self) -> list[dict] | None:
        """Current bucket lifecycle rules, or None when unset."""
        try:
            response = self.client.get_bucket_lifecycle_configuration(
                Bucket=self.context.bucket_name,
            )
        except ClientError as err:
            code = err.response["Error"]["Code"]
            if code != "NoSuchLifecycleConfiguration":
                raise
            return None
        return response.get("Rules")

    @property
    def lifecycle_require_update(self) -> bool:
        """Declarative check: any mismatch between desired and current is drift."""
        wanted = self._desired_lifecycle_rules()
        return wanted != self.current_lifecycle_rules

    def _ensure_lifecycle(self):
        """Apply the bucket's lifecycle configuration declaratively.

        - lifecycle_rules declared: replace with PutBucketLifecycleConfiguration
        - lifecycle_rules empty, rules present: DeleteBucketLifecycle
        - lifecycle_rules empty, no rules: do nothing
        """
        wanted = self._desired_lifecycle_rules()
        if wanted == self.current_lifecycle_rules:
            return
        bucket = self.context.bucket_name
        if wanted is None:
            self.client.delete_bucket_lifecycle(Bucket=bucket)
            message = "Lifecycle 設定を削除しました"
        else:
            self.client.put_bucket_lifecycle_configuration(
                Bucket=bucket,
                LifecycleConfiguration={"Rules": wanted},
            )
            message = "Lifecycle ルールを適用しました: %d 件" % len(wanted)
        self.__dict__.pop("current_lifecycle_rules", None)
        echo.info(message)
@@ -0,0 +1,298 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import secrets
7
+ import ssl
8
+ import string
9
+ import time
10
+ from functools import cached_property
11
+ from typing import TYPE_CHECKING
12
+
13
+ import requests
14
+ from pydantic import BaseModel
15
+ from requests.auth import HTTPDigestAuth
16
+
17
+ from pocket.resources.base import ResourceStatus
18
+
19
+ if TYPE_CHECKING:
20
+ from pocket.context import TiDbContext
21
+
22
# Configure the root logger at import time; basicConfig() does nothing when
# the root logger already has handlers.
logging.basicConfig()
logger = logging.getLogger(__name__)
# The level name is read from POCKET_LOGGER_LEVEL (upper-cased), defaulting
# to WARNING.
# NOTE(review): an unrecognised level name makes setLevel() raise ValueError
# while this module is being imported - confirm that fail-fast is intended.
logger.setLevel(level=os.getenv("POCKET_LOGGER_LEVEL", "WARNING").upper())
25
+
26
+
27
class TiDbResourceIsNotReady(Exception):
    """Raised when a TiDB Cloud resource is unconfigured, missing or not yet usable."""
29
+
30
+
31
class Project(BaseModel):
    """A TiDB Cloud project as used by this module: its id and display name."""

    # Project id; TiDb.project stores it as str(p["id"]).
    id: str
    # Project name; matched against context.tidb_project when one is configured.
    name: str
34
+
35
+
36
class Cluster(BaseModel):
    """Connection-relevant view of a TiDB cluster, built from a list_clusters entry."""

    # Taken from the entry's "clusterId".
    id: str
    # Taken from the entry's "displayName".
    name: str
    # Taken from the entry's "state" ("UNKNOWN" when absent); "ACTIVE" means usable.
    status: str
    # Public endpoint host; empty string when the entry has no public endpoint.
    host: str
    # Public endpoint port; 4000 when the entry does not provide one.
    port: int
    # Root user name in the form "<userPrefix>.root".
    user: str
43
+
44
+
45
class TiDbApi:
    """Thin HTTP client for the TiDB Cloud serverless and IAM REST APIs.

    Every request is authenticated with HTTP digest auth built from the API
    key pair given to the constructor.
    """

    serverless_endpoint = "https://serverless.tidbapi.com/v1beta1/"
    iam_endpoint = "https://iam.tidbapi.com/v1beta1/"
    # Seconds to wait for the server. Without a timeout ``requests`` blocks
    # indefinitely on an unresponsive endpoint.
    timeout = 60

    def __init__(self, public_key: str, private_key: str) -> None:
        """Build the digest-auth credentials from the API key pair."""
        self.auth = HTTPDigestAuth(public_key, private_key)

    @staticmethod
    def _redact(data):
        """Return a copy of ``data`` with password values masked for logging.

        Any mapping key whose name contains "password" (case-insensitive)
        has its value replaced by ``"***"``. Nested dicts and lists are
        processed recursively; the input object is not modified.
        """
        if isinstance(data, dict):
            return {
                key: "***" if "password" in str(key).lower() else TiDbApi._redact(value)
                for key, value in data.items()
            }
        if isinstance(data, list):
            return [TiDbApi._redact(item) for item in data]
        return data

    @staticmethod
    def _describe_body(res) -> str:
        """Render a response body for debug logging.

        Returns pretty-printed JSON when the body parses as JSON, otherwise
        the raw response text (e.g. an empty body or an HTML error page).
        """
        try:
            return json.dumps(res.json(), indent=2)
        except ValueError:
            # requests' JSON decode errors derive from ValueError.
            return res.text

    def _request(self, method: str, url: str, data=None):
        """Send one authenticated request and return the response on success.

        Args:
            method: HTTP method name.
            url: Absolute URL to call.
            data: Optional JSON-serialisable request body.

        Returns:
            The ``requests`` response when the status code is 2xx.

        Raises:
            RuntimeError: for any non-2xx status; the message carries the
                status code and the response text.
            requests.exceptions.RequestException: for transport failures,
                including a timeout after ``self.timeout`` seconds.
        """
        logger.info("%s %s", method, url)
        debug = logger.isEnabledFor(logging.DEBUG)
        if data and debug:
            # Payloads can carry root passwords; never log them verbatim.
            logger.debug(json.dumps(self._redact(data), indent=2))
        res = requests.request(
            method, url, auth=self.auth, json=data, timeout=self.timeout
        )
        if debug:
            logger.debug(res.status_code)
            # The body is not guaranteed to be JSON, so logging must not be
            # able to raise before the status check below.
            logger.debug(self._describe_body(res))
        if 200 <= res.status_code < 300:
            if method != "GET":
                # Pause after a mutating call; presumably gives the API time
                # to apply the change before the next request.
                time.sleep(2)
            return res
        raise RuntimeError("%s: %s" % (res.status_code, res.text))

    def iam_get(self, path: str):
        """GET ``path`` relative to the IAM endpoint."""
        return self._request("GET", self.iam_endpoint + path)

    def serverless_get(self, path: str):
        """GET ``path`` relative to the serverless endpoint."""
        return self._request("GET", self.serverless_endpoint + path)

    def serverless_post(self, path: str, data=None):
        """POST ``data`` to ``path`` relative to the serverless endpoint."""
        return self._request("POST", self.serverless_endpoint + path, data)

    def serverless_put(self, path: str, data=None):
        """PUT ``data`` to ``path`` relative to the serverless endpoint."""
        return self._request("PUT", self.serverless_endpoint + path, data)

    def serverless_delete(self, path: str):
        """DELETE ``path`` relative to the serverless endpoint."""
        return self._request("DELETE", self.serverless_endpoint + path)

    def list_projects(self) -> list[dict]:
        """Return the "projects" list of the IAM response (empty when absent)."""
        return self.iam_get("projects").json().get("projects", [])

    def list_clusters(self, project_id: str) -> list[dict]:
        """Return the "clusters" list filtered by project (empty when absent).

        NOTE(review): only the first response is read; confirm whether the
        API paginates this listing.
        """
        response = self.serverless_get("clusters?filter=projectId=%s" % project_id)
        return response.json().get("clusters", [])

    def create_cluster(self, data: dict) -> dict:
        """Create a cluster from ``data`` and return the decoded response."""
        return self.serverless_post("clusters", data).json()

    def get_cluster(self, cluster_id: str) -> dict:
        """Return the decoded description of one cluster."""
        return self.serverless_get("clusters/%s" % cluster_id).json()

    def delete_cluster(self, cluster_id: str):
        """Delete one cluster and return the raw response."""
        return self.serverless_delete("clusters/%s" % cluster_id)

    def change_password(self, cluster_id: str, password: str):
        """Set the cluster password and return the raw response."""
        return self.serverless_put(
            "clusters/%s/password" % cluster_id,
            {"password": password},
        )
104
+
105
+
106
class TiDb:
    """Provision and manage a TiDB Cloud cluster and its application database.

    The project and cluster are looked up through ``TiDbApi`` and cached per
    instance. The root password is only known to this process after a
    cluster was created or the password was reset.
    """

    context: TiDbContext
    _root_password: str | None

    def __init__(self, context: TiDbContext) -> None:
        """Store the context; no API call is made here."""
        self.context = context
        # None until _ensure_cluster() or _reset_password() sets it.
        self._root_password = None

    @cached_property
    def api(self) -> TiDbApi:
        """API client built from the configured key pair.

        Raises:
            TiDbResourceIsNotReady: when either API key is missing.
        """
        if not self.context.public_key or not self.context.private_key:
            raise TiDbResourceIsNotReady(
                "TiDB API keys not configured. "
                "Set tidb_public_key and tidb_private_key in .env"
            )
        return TiDbApi(self.context.public_key, self.context.private_key)

    @cached_property
    def project(self) -> Project | None:
        """Resolve the TiDB Cloud project to work in.

        With ``context.tidb_project`` set, the project of that name is
        returned (None when no such project exists). Otherwise the single
        existing project is used, and None is returned when there is none.

        Raises:
            TiDbResourceIsNotReady: when several projects exist and none
                was selected in the configuration.
        """
        projects = self.api.list_projects()
        if self.context.tidb_project:
            for p in projects:
                if p["name"] == self.context.tidb_project:
                    return Project(id=str(p["id"]), name=p["name"])
            return None
        if len(projects) == 1:
            p = projects[0]
            return Project(id=str(p["id"]), name=p["name"])
        if len(projects) > 1:
            names = [p["name"] for p in projects]
            raise TiDbResourceIsNotReady(
                "複数の TiDB Cloud プロジェクトが見つかりました: %s\n"
                "pocket.toml の [tidb] で project を指定してください。\n"
                '例: project = "%s"' % (names, names[0])
            )
        return None

    @cached_property
    def cluster(self) -> Cluster | None:
        """Find the cluster whose display name equals ``context.cluster_name``.

        Returns None when no project is resolved or no cluster matches.
        """
        if not self.project:
            return None
        for c in self.api.list_clusters(self.project.id):
            if c.get("displayName") != self.context.cluster_name:
                continue
            endpoints = c.get("endpoints", {}).get("public", {})
            user_prefix = c.get("userPrefix", "")
            return Cluster(
                id=str(c["clusterId"]),
                name=c["displayName"],
                status=c.get("state", "UNKNOWN"),
                host=endpoints.get("host", ""),
                port=int(endpoints.get("port", 4000)),
                user="%s.root" % user_prefix,
            )
        return None

    def _forget_cluster(self) -> None:
        """Drop the cached cluster so the next access queries the API again.

        Unlike ``del self.cluster`` this does not raise when nothing has
        been cached yet.
        """
        self.__dict__.pop("cluster", None)

    @property
    def status(self) -> ResourceStatus:
        """Report COMPLETED for an ACTIVE cluster, NOEXIST otherwise."""
        if self.context.skip_check_existing:
            # Skip the TiDB API existence check and always report COMPLETED
            # (see settings).
            return "COMPLETED"
        if not self.context.public_key or not self.context.private_key:
            return "NOEXIST"
        if self.cluster and self.cluster.status == "ACTIVE":
            return "COMPLETED"
        return "NOEXIST"

    @property
    def description(self) -> str:
        """Short human-readable label for the create action."""
        return "Create TiDB cluster and database"

    @property
    def database_url(self) -> str:
        """Build the ``mysql://`` URL for the application database.

        NOTE(review): when this instance does not know the root password,
        reading this property resets it through the API, invalidating any
        previously issued password - confirm callers expect that.

        Raises:
            TiDbResourceIsNotReady: when the cluster is missing or not ACTIVE.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        if self.cluster.status != "ACTIVE":
            raise TiDbResourceIsNotReady("Cluster is not available")
        password = self._root_password or self._reset_password()
        return "mysql://%s:%s@%s:%d/%s" % (
            self.cluster.user,
            password,
            self.cluster.host,
            self.cluster.port,
            self.context.database_name,
        )

    def _generate_password(self) -> str:
        """Return a random 24-character password of ASCII letters and digits."""
        chars = string.ascii_letters + string.digits
        return "".join(secrets.choice(chars) for _ in range(24))

    def _reset_password(self) -> str:
        """Set a freshly generated root password on the cluster and return it.

        Raises:
            TiDbResourceIsNotReady: when the cluster does not exist.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        new_password = self._generate_password()
        self.api.change_password(self.cluster.id, new_password)
        self._root_password = new_password
        return new_password

    def deploy_init(self):
        """Nothing to prepare before a deploy."""
        pass

    def create(self):
        """Ensure the cluster exists, then ensure the database exists."""
        self._ensure_cluster()
        self._ensure_database()

    def _ensure_cluster(self):
        """Create the cluster when it does not exist and wait until ACTIVE.

        NOTE(review): an existing cluster is accepted in any state; one that
        is not ACTIVE is neither waited for nor rejected here.

        Raises:
            TiDbResourceIsNotReady: when no project is available or the new
                cluster does not become ACTIVE in time.
        """
        if self.cluster:
            return
        if not self.project:
            raise TiDbResourceIsNotReady(
                "No TiDB Cloud project found. Create one at https://tidbcloud.com/"
            )
        password = self._generate_password()
        data = {
            "displayName": self.context.cluster_name,
            "region": {"name": "regions/aws-%s" % self.context.region},
            "rootPassword": password,
            "labels": {"tidb.cloud/project": self.project.id},
            "endpoints": {
                "public": {
                    # The public endpoint is opened to the whole IPv4 range;
                    # access then relies on the password and TLS.
                    "authorizedNetworks": [
                        {
                            "startIpAddress": "0.0.0.0",
                            "endIpAddress": "255.255.255.255",
                            "displayName": "Allow all",
                        }
                    ]
                }
            },
        }
        self.api.create_cluster(data)
        self._root_password = password
        self._wait_for_cluster()

    def _wait_for_cluster(self):
        """Poll up to 60 times, 5 seconds apart, until the cluster is ACTIVE.

        Raises:
            TiDbResourceIsNotReady: when the cluster is still not ACTIVE
                after the last poll.
        """
        logger.info("Waiting for cluster to become available...")
        self._forget_cluster()
        for _ in range(60):
            if self.cluster and self.cluster.status == "ACTIVE":
                return
            time.sleep(5)
            self._forget_cluster()
        raise TiDbResourceIsNotReady("Cluster did not become available in time")

    def _get_sql_connection(self):
        """Open a TLS-protected pymysql connection as the root user.

        The root password is reset through the API when this instance does
        not know it. The caller is responsible for closing the connection.

        Raises:
            ModuleNotFoundError: when pymysql is not installed.
            TiDbResourceIsNotReady: when the cluster does not exist.
        """
        try:
            import pymysql  # noqa: PLC0415
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "pymysql is required for TiDB database management.\n"
                "Install it with: uv add pymysql"
            ) from None
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        password = self._root_password or self._reset_password()
        ssl_ctx = ssl.create_default_context()
        return pymysql.connect(
            host=self.cluster.host,
            port=self.cluster.port,
            user=self.cluster.user,
            password=password,
            ssl=ssl_ctx,
        )

    def _quoted_database_name(self) -> str:
        """Return the database name as a backtick-quoted SQL identifier.

        Embedded backticks are doubled so the name cannot terminate the
        quoted identifier and alter the statement.
        """
        return "`%s`" % str(self.context.database_name).replace("`", "``")

    def _execute(self, *statements: str) -> None:
        """Run the statements in order on a new connection, then close it."""
        conn = self._get_sql_connection()
        try:
            with conn.cursor() as cursor:
                for statement in statements:
                    cursor.execute(statement)
        finally:
            conn.close()

    def _ensure_database(self):
        """Create the application database when it does not exist yet."""
        self._execute(
            "CREATE DATABASE IF NOT EXISTS %s" % self._quoted_database_name()
        )

    def delete_cluster(self):
        """Delete the cluster through the API and drop the cached lookup.

        Raises:
            TiDbResourceIsNotReady: when the cluster does not exist.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        self.api.delete_cluster(self.cluster.id)
        self._forget_cluster()

    def reset_database(self):
        """Drop the application database (if present) and create it again."""
        name = self._quoted_database_name()
        self._execute(
            "DROP DATABASE IF EXISTS %s" % name,
            "CREATE DATABASE %s" % name,
        )