magic-pocket-cli 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pocket_cli-0.2.0.dist-info/METADATA +14 -0
- magic_pocket_cli-0.2.0.dist-info/RECORD +65 -0
- magic_pocket_cli-0.2.0.dist-info/WHEEL +4 -0
- magic_pocket_cli-0.2.0.dist-info/entry_points.txt +2 -0
- pocket_cli/__init__.py +0 -0
- pocket_cli/cli/__init__.py +0 -0
- pocket_cli/cli/aws_auth.py +48 -0
- pocket_cli/cli/awscontainer_cli.py +328 -0
- pocket_cli/cli/cloudfront_cli.py +116 -0
- pocket_cli/cli/cloudfront_keys_cli.py +68 -0
- pocket_cli/cli/cloudfront_waf_cli.py +68 -0
- pocket_cli/cli/deploy_cli.py +274 -0
- pocket_cli/cli/destroy_cli.py +358 -0
- pocket_cli/cli/dsql_cli.py +60 -0
- pocket_cli/cli/main_cli.py +91 -0
- pocket_cli/cli/migrate_cli.py +148 -0
- pocket_cli/cli/neon_cli.py +97 -0
- pocket_cli/cli/permissions_cli.py +46 -0
- pocket_cli/cli/rds_cli.py +63 -0
- pocket_cli/cli/runtime_config_cli.py +185 -0
- pocket_cli/cli/s3_cli.py +69 -0
- pocket_cli/cli/status_cli.py +56 -0
- pocket_cli/cli/tidb_cli.py +73 -0
- pocket_cli/cli/vpc_cli.py +92 -0
- pocket_cli/cli/waf_cli.py +182 -0
- pocket_cli/django_cli.py +412 -0
- pocket_cli/mediator.py +220 -0
- pocket_cli/resources/__init__.py +0 -0
- pocket_cli/resources/aws/__init__.py +0 -0
- pocket_cli/resources/aws/builders/__init__.py +57 -0
- pocket_cli/resources/aws/builders/codebuild.py +363 -0
- pocket_cli/resources/aws/builders/depot.py +84 -0
- pocket_cli/resources/aws/builders/docker.py +34 -0
- pocket_cli/resources/aws/builders/dockerignore.py +44 -0
- pocket_cli/resources/aws/cloudformation.py +790 -0
- pocket_cli/resources/aws/ecr.py +145 -0
- pocket_cli/resources/aws/efs.py +138 -0
- pocket_cli/resources/aws/lambdahandler.py +182 -0
- pocket_cli/resources/aws/s3_utils.py +58 -0
- pocket_cli/resources/aws/state.py +74 -0
- pocket_cli/resources/awscontainer.py +265 -0
- pocket_cli/resources/cloudfront.py +491 -0
- pocket_cli/resources/cloudfront_acm.py +55 -0
- pocket_cli/resources/cloudfront_keys.py +81 -0
- pocket_cli/resources/cloudfront_waf.py +67 -0
- pocket_cli/resources/dsql.py +142 -0
- pocket_cli/resources/neon.py +353 -0
- pocket_cli/resources/rds.py +680 -0
- pocket_cli/resources/s3.py +307 -0
- pocket_cli/resources/tidb.py +298 -0
- pocket_cli/resources/upstash.py +152 -0
- pocket_cli/resources/vpc.py +67 -0
- pocket_cli/templates/cloudformation/awscontainer.yaml +516 -0
- pocket_cli/templates/cloudformation/cf_function_api_host.js +5 -0
- pocket_cli/templates/cloudformation/cf_function_spa_auth.js +28 -0
- pocket_cli/templates/cloudformation/cf_function_spa_fallback.js +8 -0
- pocket_cli/templates/cloudformation/cloudfront.yaml +309 -0
- pocket_cli/templates/cloudformation/cloudfront_acm.yaml +43 -0
- pocket_cli/templates/cloudformation/cloudfront_keys.yaml +32 -0
- pocket_cli/templates/cloudformation/cloudfront_waf.yaml +97 -0
- pocket_cli/templates/cloudformation/vpc.yaml +213 -0
- pocket_cli/templates/init/django-dotenv.env +3 -0
- pocket_cli/templates/init/django-settings.py +140 -0
- pocket_cli/templates/init/pocket.Dockerfile +26 -0
- pocket_cli/templates/init/pocket_simple.toml +31 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import boto3
|
|
7
|
+
from botocore.exceptions import ClientError
|
|
8
|
+
|
|
9
|
+
from pocket.resources.base import ResourceStatus
|
|
10
|
+
from pocket.utils import echo
|
|
11
|
+
from pocket_cli.resources.aws.s3_utils import (
|
|
12
|
+
bucket_exists,
|
|
13
|
+
create_bucket,
|
|
14
|
+
delete_bucket_with_contents,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from pocket.context import CloudFrontContext, S3Context
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class S3:
    """Declaratively manage one S3 bucket.

    The desired configuration (public access block, CORS, versioning,
    lifecycle rules) is derived from ``S3Context``. Each ``*_require_update``
    property compares it with the live bucket; each ``_ensure_*`` method only
    calls the S3 API when the two differ.
    """

    context: S3Context
    _cloudfront_contexts: dict[str, CloudFrontContext]

    def __init__(
        self,
        context: S3Context,
        cloudfront_contexts: dict[str, CloudFrontContext] | None = None,
    ) -> None:
        self.context = context
        # CloudFront contexts keyed by name; used to resolve CORS origins.
        self._cloudfront_contexts = cloudfront_contexts or {}
        self.client = boto3.client("s3", region_name=context.region)

    @property
    def description(self):
        return "Create bucket: %s" % self.context.bucket_name

    def state_info(self):
        return {"s3": {"bucket_name": self.context.bucket_name}}

    def deploy_init(self):
        pass

    def create(self):
        """Create the bucket, then apply every managed setting to it."""
        create_bucket(self.client, self.context.bucket_name, self.context.region)
        self.update()

    def ensure_exists(self):
        """Create the bucket when missing; otherwise reconcile its settings."""
        if not self.exists():
            self.create()
            return
        self.update()

    def delete(self):
        delete_bucket_with_contents(self.client, self.context.bucket_name)

    def update(self):
        """Apply public access block, CORS, versioning and lifecycle settings."""
        self.ensure_public_access_block()
        self._ensure_cors()
        self._ensure_versioning()
        self._ensure_lifecycle()

    def exists(self):
        """Return whether the bucket exists.

        Raises:
            Exception: when the check fails with a ``ClientError`` (the
                original error is chained as ``__cause__``).
        """
        try:
            return bucket_exists(self.client, self.context.bucket_name)
        except ClientError as err:
            raise Exception(
                "Bucket might be already used by other account."
                " Try another bucket_prefix."
            ) from err

    @property
    def status(self) -> ResourceStatus:
        if not self.exists():
            return "NOEXIST"
        # ``or`` short-circuits, so later checks (and their API calls) are
        # skipped as soon as one setting is found to have drifted.
        drifted = (
            self.public_access_block_require_update
            or self.cors_require_update
            or self.versioning_require_update
            or self.lifecycle_require_update
        )
        return "REQUIRE_UPDATE" if drifted else "COMPLETED"

    @cached_property
    def public_access_block(self) -> dict | None:
        """Live public access block configuration, or None when unset."""
        try:
            response = self.client.get_public_access_block(
                Bucket=self.context.bucket_name,
            )
        except ClientError as err:
            if err.response["Error"]["Code"] != "NoSuchPublicAccessBlockConfiguration":
                raise
            return None
        return response["PublicAccessBlockConfiguration"]

    @property
    def public_access_block_should_be(self):
        """Desired configuration: every public-access switch turned on."""
        return dict.fromkeys(
            (
                "BlockPublicAcls",
                "IgnorePublicAcls",
                "BlockPublicPolicy",
                "RestrictPublicBuckets",
            ),
            True,
        )

    @property
    def public_access_block_require_update(self):
        desired = self.public_access_block_should_be
        return desired != self.public_access_block

    def ensure_public_access_block(self):
        """Put the desired public access block when the live one differs."""
        if not self.public_access_block_require_update:
            echo.info("Public access block is already configured properly.")
            return
        desired = self.public_access_block_should_be
        echo.info("Update public access block configuration")
        echo.info("Current configuration: %s" % self.public_access_block)
        self.client.put_public_access_block(
            Bucket=self.context.bucket_name,
            PublicAccessBlockConfiguration=desired,
        )
        # Drop the cached live value so the next read hits the API again.
        del self.public_access_block
        echo.info("Updated to: %s" % desired)

    def _resolve_cors_origins(self) -> list[str]:
        """Resolve CORS AllowedOrigins from the referenced CloudFront contexts.

        A CloudFront context with a custom domain contributes
        ``https://<domain>``; one without a domain contributes the wildcard
        ``https://*.cloudfront.net``. Unknown names are warned about and skipped.
        """
        cors = self.context.cors
        if not cors:
            return []
        resolved: list[str] = []
        for name in cors.cloudfront_names:
            cf = self._cloudfront_contexts.get(name)
            if not cf:
                echo.warning("CORS: cloudfront '%s' が見つかりません" % name)
                continue
            host = cf.domain if cf.domain else "*.cloudfront.net"
            resolved.append("https://%s" % host)
        return resolved

    def _desired_cors_rules(self) -> list[dict] | None:
        """Desired CORS rules, or None when CORS is not declared/resolvable."""
        if not self.context.cors:
            return None
        origins = self._resolve_cors_origins()
        if not origins:
            return None
        rule = {
            "AllowedOrigins": origins,
            "AllowedMethods": self.context.cors.methods,
            "AllowedHeaders": ["*"],
            "MaxAgeSeconds": 3600,
        }
        return [rule]

    @cached_property
    def current_cors_rules(self) -> list[dict] | None:
        """Live CORS rules of the bucket, or None when none are configured."""
        try:
            response = self.client.get_bucket_cors(Bucket=self.context.bucket_name)
        except ClientError as err:
            if err.response["Error"]["Code"] != "NoSuchCORSConfiguration":
                raise
            return None
        return response.get("CORSRules")

    @property
    def cors_require_update(self) -> bool:
        # Declarative: any difference between desired and live is drift.
        desired = self._desired_cors_rules()
        return desired != self.current_cors_rules

    def _ensure_cors(self):
        """Apply the bucket CORS configuration declaratively.

        - CORS declared: replace the rules with PutBucketCors.
        - not declared but rules exist: DeleteBucketCors.
        - not declared and no rules: nothing to do.
        """
        desired = self._desired_cors_rules()
        if desired == self.current_cors_rules:
            return
        bucket = self.context.bucket_name
        if desired is None:
            self.client.delete_bucket_cors(Bucket=bucket)
            message = "CORS 設定を削除しました"
        else:
            self.client.put_bucket_cors(
                Bucket=bucket,
                CORSConfiguration={"CORSRules": desired},
            )
            message = "CORS 設定を適用しました: %s" % desired[0]["AllowedOrigins"]
        self.__dict__.pop("current_cors_rules", None)
        echo.info(message)

    @cached_property
    def current_versioning_status(self) -> str | None:
        """Live versioning status: "Enabled", "Suspended" or None (never set)."""
        response = self.client.get_bucket_versioning(Bucket=self.context.bucket_name)
        return response.get("Status")

    @property
    def versioning_require_update(self) -> bool:
        """Declarative drift check for versioning.

        None and "Suspended" are both treated as "not enabled", so drift exists
        exactly when "is enabled" disagrees with ``context.versioning``.
        """
        enabled = self.current_versioning_status == "Enabled"
        return enabled != bool(self.context.versioning)

    def _ensure_versioning(self):
        """Apply bucket versioning declaratively.

        - versioning requested, bucket not Enabled: put "Enabled".
        - versioning not requested, bucket Enabled: put "Suspended".
        - otherwise: no-op (None and "Suspended" count as "not enabled").

        Per S3 semantics a bucket that was once versioned can only be
        suspended, not returned to unversioned; existing versions are kept.
        """
        want_enabled = bool(self.context.versioning)
        if (self.current_versioning_status == "Enabled") == want_enabled:
            return
        target = "Enabled" if want_enabled else "Suspended"
        self.client.put_bucket_versioning(
            Bucket=self.context.bucket_name,
            VersioningConfiguration={"Status": target},
        )
        self.__dict__.pop("current_versioning_status", None)
        if want_enabled:
            echo.info("Versioning を有効化しました")
        else:
            echo.info("Versioning を Suspended にしました")

    def _desired_lifecycle_rules(self) -> list[dict] | None:
        """Desired lifecycle rules, or None when none are declared."""
        declared = self.context.lifecycle_rules
        if not declared:
            return None
        return [
            {
                "ID": item.id,
                "Status": "Enabled",
                "Filter": {"Prefix": item.prefix},
                "NoncurrentVersionExpiration": {
                    "NoncurrentDays": item.noncurrent_version_expiration_days,
                },
            }
            for item in declared
        ]

    @cached_property
    def current_lifecycle_rules(self) -> list[dict] | None:
        """Live lifecycle rules of the bucket, or None when none are set."""
        try:
            response = self.client.get_bucket_lifecycle_configuration(
                Bucket=self.context.bucket_name,
            )
        except ClientError as err:
            if err.response["Error"]["Code"] != "NoSuchLifecycleConfiguration":
                raise
            return None
        return response.get("Rules")

    @property
    def lifecycle_require_update(self) -> bool:
        # Declarative: any difference between desired and live is drift.
        desired = self._desired_lifecycle_rules()
        return desired != self.current_lifecycle_rules

    def _ensure_lifecycle(self):
        """Apply bucket lifecycle rules declaratively.

        - rules declared: replace with PutBucketLifecycleConfiguration.
        - no rules declared but rules exist: DeleteBucketLifecycle.
        - no rules declared and none exist: nothing to do.
        """
        desired = self._desired_lifecycle_rules()
        if desired == self.current_lifecycle_rules:
            return
        bucket = self.context.bucket_name
        if desired is None:
            self.client.delete_bucket_lifecycle(Bucket=bucket)
            message = "Lifecycle 設定を削除しました"
        else:
            self.client.put_bucket_lifecycle_configuration(
                Bucket=bucket,
                LifecycleConfiguration={"Rules": desired},
            )
            message = "Lifecycle ルールを適用しました: %d 件" % len(desired)
        self.__dict__.pop("current_lifecycle_rules", None)
        echo.info(message)
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import secrets
|
|
7
|
+
import ssl
|
|
8
|
+
import string
|
|
9
|
+
import time
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
import requests
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
from requests.auth import HTTPDigestAuth
|
|
16
|
+
|
|
17
|
+
from pocket.resources.base import ResourceStatus
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from pocket.context import TiDbContext
|
|
21
|
+
|
|
22
|
+
# Install a default root handler so records from this module are emitted.
# The module logger's level comes from POCKET_LOGGER_LEVEL (default WARNING).
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("POCKET_LOGGER_LEVEL", "WARNING").upper())
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TiDbResourceIsNotReady(Exception):
    """Raised when a TiDB prerequisite (API keys, project, cluster) is unavailable."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Project(BaseModel):
    """A TiDB Cloud project, as resolved by ``TiDb.project``."""

    # Project ID from the IAM API, stored as a string.
    id: str
    # Human-readable project name (matched against ``[tidb] project``).
    name: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Cluster(BaseModel):
    """A TiDB Cloud Serverless cluster, as built by ``TiDb.cluster``."""

    # "clusterId" from the API, stored as a string.
    id: str
    # "displayName" from the API; matched against ``context.cluster_name``.
    name: str
    # "state" from the API ("UNKNOWN" when absent); "ACTIVE" means usable.
    status: str
    # Public endpoint host ("" when the API reports none).
    host: str
    # Public endpoint port (4000 when the API reports none).
    port: int
    # Root login name, built as "<userPrefix>.root".
    user: str
|
|
42
|
+
user: str
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TiDbApi:
    """Thin client for the TiDB Cloud Serverless and IAM REST APIs (v1beta1).

    Every request is authenticated with HTTP Digest auth using the API key
    pair passed to the constructor.
    """

    serverless_endpoint = "https://serverless.tidbapi.com/v1beta1/"
    iam_endpoint = "https://iam.tidbapi.com/v1beta1/"
    # Seconds to wait for the server before giving up. Without a timeout,
    # ``requests`` can block forever on an unresponsive endpoint.
    request_timeout = 30
    # Request-body keys whose values must never be written to the logs.
    _secret_keys = frozenset({"password", "rootPassword"})

    def __init__(self, public_key: str, private_key: str) -> None:
        self.auth = HTTPDigestAuth(public_key, private_key)

    @classmethod
    def _redact(cls, value):
        """Return a copy of *value* with secret fields masked, for logging."""
        if isinstance(value, dict):
            return {
                key: "***" if key in cls._secret_keys else cls._redact(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [cls._redact(item) for item in value]
        return value

    @staticmethod
    def _format_body(res) -> str:
        """Pretty-print a JSON response body, falling back to the raw text."""
        try:
            return json.dumps(res.json(), indent=2)
        except ValueError:
            # Non-JSON bodies (HTML error pages, empty responses, ...) must not
            # hide the real HTTP status behind a decode error.
            return res.text

    def _request(self, method: str, url: str, data=None):
        """Send *data* as JSON to *url* and return the response on 2xx.

        After a successful non-GET request this sleeps 2 seconds before
        returning.

        Raises:
            RuntimeError: for any non-2xx status, with the status and body.
        """
        logger.info("%s %s", method, url)
        debug = logger.isEnabledFor(logging.DEBUG)
        if data and debug:
            # Passwords (change_password, create_cluster) are masked.
            logger.debug(json.dumps(self._redact(data), indent=2))
        res = requests.request(
            method,
            url,
            auth=self.auth,
            json=data,
            timeout=self.request_timeout,
        )
        if debug:
            logger.debug(res.status_code)
            logger.debug(self._format_body(res))
        if 200 <= res.status_code < 300:
            if method != "GET":
                # NOTE(review): presumably lets the API settle after a
                # mutation before the next call — confirm.
                time.sleep(2)
            return res
        raise RuntimeError("%s: %s" % (res.status_code, res.text))

    def iam_get(self, path: str):
        return self._request("GET", self.iam_endpoint + path)

    def serverless_get(self, path: str):
        return self._request("GET", self.serverless_endpoint + path)

    def serverless_post(self, path: str, data=None):
        return self._request("POST", self.serverless_endpoint + path, data)

    def serverless_put(self, path: str, data=None):
        return self._request("PUT", self.serverless_endpoint + path, data)

    def serverless_delete(self, path: str):
        return self._request("DELETE", self.serverless_endpoint + path)

    def list_projects(self) -> list[dict]:
        """Return the raw project dicts visible to the API key."""
        return self.iam_get("projects").json().get("projects", [])

    def list_clusters(self, project_id: str) -> list[dict]:
        """Return the raw cluster dicts of *project_id*.

        NOTE(review): only the first response page is read; confirm whether
        the API paginates for projects with many clusters.
        """
        return (
            self.serverless_get("clusters?filter=projectId=%s" % project_id)
            .json()
            .get("clusters", [])
        )

    def create_cluster(self, data: dict) -> dict:
        return self.serverless_post("clusters", data).json()

    def get_cluster(self, cluster_id: str) -> dict:
        return self.serverless_get("clusters/%s" % cluster_id).json()

    def delete_cluster(self, cluster_id: str):
        return self.serverless_delete("clusters/%s" % cluster_id)

    def change_password(self, cluster_id: str, password: str):
        """Set the cluster root password to *password*."""
        return self.serverless_put(
            "clusters/%s/password" % cluster_id,
            {"password": password},
        )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class TiDb:
    """Provision a TiDB Cloud Serverless cluster plus its application database.

    Cluster lookup and creation go through the TiDB Cloud REST API
    (``TiDbApi``). Database DDL runs over a MySQL-protocol connection opened
    with ``pymysql``.
    """

    context: TiDbContext
    # Root password known to this process: set when the cluster is created
    # here or after ``_reset_password``; None otherwise.
    _root_password: str | None

    def __init__(self, context: TiDbContext) -> None:
        self.context = context
        self._root_password = None

    @cached_property
    def api(self) -> TiDbApi:
        """REST client built from the configured API key pair.

        Raises:
            TiDbResourceIsNotReady: if either API key is missing.
        """
        if not self.context.public_key or not self.context.private_key:
            raise TiDbResourceIsNotReady(
                "TiDB API keys not configured. "
                "Set tidb_public_key and tidb_private_key in .env"
            )
        return TiDbApi(self.context.public_key, self.context.private_key)

    @cached_property
    def project(self) -> Project | None:
        """Resolve the TiDB Cloud project to deploy into.

        - ``context.tidb_project`` set: the project with that name, else None.
        - not set and exactly one project exists: that project.
        - not set and no project exists: None.

        Raises:
            TiDbResourceIsNotReady: if several projects exist and none is
                selected in the configuration.
        """
        projects = self.api.list_projects()
        if self.context.tidb_project:
            for p in projects:
                if p["name"] == self.context.tidb_project:
                    return Project(id=str(p["id"]), name=p["name"])
            return None
        if len(projects) == 1:
            p = projects[0]
            return Project(id=str(p["id"]), name=p["name"])
        if len(projects) > 1:
            names = [p["name"] for p in projects]
            raise TiDbResourceIsNotReady(
                "複数の TiDB Cloud プロジェクトが見つかりました: %s\n"
                "pocket.toml の [tidb] で project を指定してください。\n"
                '例: project = "%s"' % (names, names[0])
            )
        return None

    @cached_property
    def cluster(self) -> Cluster | None:
        """The project's cluster whose displayName is ``context.cluster_name``."""
        if not self.project:
            return None
        for c in self.api.list_clusters(self.project.id):
            if c.get("displayName") == self.context.cluster_name:
                endpoints = c.get("endpoints", {}).get("public", {})
                user_prefix = c.get("userPrefix", "")
                return Cluster(
                    id=str(c["clusterId"]),
                    name=c["displayName"],
                    status=c.get("state", "UNKNOWN"),
                    host=endpoints.get("host", ""),
                    port=int(endpoints.get("port", 4000)),
                    user="%s.root" % user_prefix,
                )
        return None

    @property
    def status(self) -> ResourceStatus:
        if self.context.skip_check_existing:
            # Skip the TiDB API existence check and always report COMPLETED
            # (see the settings for skip_check_existing).
            return "COMPLETED"
        if not self.context.public_key or not self.context.private_key:
            return "NOEXIST"
        if self.cluster and self.cluster.status == "ACTIVE":
            return "COMPLETED"
        return "NOEXIST"

    @property
    def description(self) -> str:
        return "Create TiDB cluster and database"

    @property
    def database_url(self) -> str:
        """``mysql://`` URL for the configured database, as the root user.

        If this process does not know the root password yet, the password is
        rotated via the API (see ``_reset_password``).

        Raises:
            TiDbResourceIsNotReady: if the cluster is missing or not ACTIVE.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        if self.cluster.status != "ACTIVE":
            raise TiDbResourceIsNotReady("Cluster is not available")
        password = self._root_password or self._reset_password()
        return "mysql://%s:%s@%s:%d/%s" % (
            self.cluster.user,
            password,
            self.cluster.host,
            self.cluster.port,
            self.context.database_name,
        )

    def _generate_password(self) -> str:
        """Return a random 24-character ASCII alphanumeric password."""
        chars = string.ascii_letters + string.digits
        return "".join(secrets.choice(chars) for _ in range(24))

    def _reset_password(self) -> str:
        """Set a freshly generated root password via the API and return it.

        Each call replaces the root password, so credentials obtained earlier
        (e.g. an older ``database_url``) stop working.

        Raises:
            TiDbResourceIsNotReady: if the cluster does not exist.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        new_password = self._generate_password()
        self.api.change_password(self.cluster.id, new_password)
        self._root_password = new_password
        return new_password

    @staticmethod
    def _quote_identifier(name) -> str:
        """Return *name* as a backtick-quoted MySQL identifier."""
        # Backticks inside an identifier are escaped by doubling them.
        return "`%s`" % str(name).replace("`", "``")

    def deploy_init(self):
        pass

    def create(self):
        self._ensure_cluster()
        self._ensure_database()

    def _ensure_cluster(self):
        """Make sure the cluster exists and is ACTIVE, creating it if needed.

        Raises:
            TiDbResourceIsNotReady: if no project is available, or the cluster
                does not become ACTIVE in time.
        """
        cluster = self.cluster
        if cluster:
            # An existing cluster may still be provisioning (``status`` reports
            # NOEXIST until it is ACTIVE); wait so the database step can
            # connect instead of failing against an unavailable endpoint.
            if cluster.status != "ACTIVE":
                self._wait_for_cluster()
            return
        if not self.project:
            raise TiDbResourceIsNotReady(
                "No TiDB Cloud project found. Create one at https://tidbcloud.com/"
            )
        password = self._generate_password()
        data = {
            "displayName": self.context.cluster_name,
            "region": {"name": "regions/aws-%s" % self.context.region},
            "rootPassword": password,
            "labels": {"tidb.cloud/project": self.project.id},
            "endpoints": {
                "public": {
                    # The public endpoint accepts any IPv4 source address;
                    # access control relies on the root password.
                    "authorizedNetworks": [
                        {
                            "startIpAddress": "0.0.0.0",
                            "endIpAddress": "255.255.255.255",
                            "displayName": "Allow all",
                        }
                    ]
                }
            },
        }
        self.api.create_cluster(data)
        self._root_password = password
        self._wait_for_cluster()

    def _wait_for_cluster(self, attempts: int = 60, interval: float = 5):
        """Poll until the cluster is ACTIVE.

        Args:
            attempts: maximum number of polls (default 60).
            interval: seconds to sleep between polls (default 5).

        Raises:
            TiDbResourceIsNotReady: if the cluster is not ACTIVE after
                ``attempts`` polls.
        """
        logger.info("Waiting for cluster to become available...")
        for _ in range(attempts):
            # pop() instead of ``del`` so this also works when the cached
            # value has never been computed.
            self.__dict__.pop("cluster", None)
            cluster = self.cluster
            if cluster and cluster.status == "ACTIVE":
                return
            time.sleep(interval)
        self.__dict__.pop("cluster", None)
        raise TiDbResourceIsNotReady("Cluster did not become available in time")

    def _get_sql_connection(self):
        """Open a TLS connection to the cluster as root (requires pymysql).

        Raises:
            ModuleNotFoundError: if pymysql is not installed.
            TiDbResourceIsNotReady: if the cluster does not exist.
        """
        try:
            import pymysql  # noqa: PLC0415
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "pymysql is required for TiDB database management.\n"
                "Install it with: uv add pymysql"
            ) from None
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        password = self._root_password or self._reset_password()
        # Default context: certificate and hostname verification enabled.
        ssl_ctx = ssl.create_default_context()
        return pymysql.connect(
            host=self.cluster.host,
            port=self.cluster.port,
            user=self.cluster.user,
            password=password,
            ssl=ssl_ctx,
        )

    def _ensure_database(self):
        """Create ``context.database_name`` if it does not exist yet."""
        conn = self._get_sql_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "CREATE DATABASE IF NOT EXISTS %s"
                    % self._quote_identifier(self.context.database_name)
                )
        finally:
            conn.close()

    def delete_cluster(self):
        """Delete the cluster and drop the cached lookup.

        Raises:
            TiDbResourceIsNotReady: if the cluster does not exist.
        """
        if not self.cluster:
            raise TiDbResourceIsNotReady("Cluster not found")
        self.api.delete_cluster(self.cluster.id)
        self.__dict__.pop("cluster", None)

    def reset_database(self):
        """Drop and recreate ``context.database_name`` (all data is lost)."""
        quoted = self._quote_identifier(self.context.database_name)
        conn = self._get_sql_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute("DROP DATABASE IF EXISTS %s" % quoted)
                cursor.execute("CREATE DATABASE %s" % quoted)
        finally:
            conn.close()
|