mlflow-tcdeploy-plugin 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlflow_tcdeploy_plugin-1.0.0/LICENSE.txt +21 -0
- mlflow_tcdeploy_plugin-1.0.0/PKG-INFO +96 -0
- mlflow_tcdeploy_plugin-1.0.0/README.md +65 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin/__init__.py +3 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin/api_client.py +129 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin/client.py +181 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin/config.py +45 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin/models.py +89 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/PKG-INFO +96 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/SOURCES.txt +19 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/dependency_links.txt +1 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/entry_points.txt +2 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/requires.txt +6 -0
- mlflow_tcdeploy_plugin-1.0.0/mlflow_tcdeploy_plugin.egg-info/top_level.txt +2 -0
- mlflow_tcdeploy_plugin-1.0.0/setup.cfg +4 -0
- mlflow_tcdeploy_plugin-1.0.0/setup.py +42 -0
- mlflow_tcdeploy_plugin-1.0.0/tests/__init__.py +0 -0
- mlflow_tcdeploy_plugin-1.0.0/tests/test_api_client.py +183 -0
- mlflow_tcdeploy_plugin-1.0.0/tests/test_client.py +306 -0
- mlflow_tcdeploy_plugin-1.0.0/tests/test_config.py +66 -0
- mlflow_tcdeploy_plugin-1.0.0/tests/test_models.py +83 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Tencent
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlflow-tcdeploy-plugin
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Tencent Cloud deployment plugin for MLflow
|
|
5
|
+
Home-page: https://git.woa.com/WeDataOS/wedata3-monorepo
|
|
6
|
+
Author: Tencent WeData Team
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Requires-Python: >=3.10
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE.txt
|
|
16
|
+
Requires-Dist: mlflow<3.11.0,>=3.10.0
|
|
17
|
+
Requires-Dist: tencentcloud-sdk-python>=3.0.1478
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: pytest; extra == "dev"
|
|
20
|
+
Requires-Dist: python-dotenv; extra == "dev"
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: classifier
|
|
23
|
+
Dynamic: description
|
|
24
|
+
Dynamic: description-content-type
|
|
25
|
+
Dynamic: home-page
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
Dynamic: provides-extra
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# mlflow-tcdeploy-plugin
|
|
33
|
+
|
|
34
|
+
MLflow deployment plugin,将腾讯云模型服务对接为 MLflow Deployment 后端。
|
|
35
|
+
|
|
36
|
+
与 `mlflow-tclake-plugin`(Model Registry)配套使用:
|
|
37
|
+
- `tclake`:负责模型存储 → TCLake
|
|
38
|
+
- `tcdeploy`:负责模型部署 → 腾讯云在线推理服务
|
|
39
|
+
|
|
40
|
+
## 安装
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
pip install mlflow-tcdeploy-plugin
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
## 环境变量(来自 wedata-pre-execute)
|
|
47
|
+
|
|
48
|
+
| 变量 | 说明 | 必填 |
|
|
49
|
+
|------|------|------|
|
|
50
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_ID` | 腾讯云 SecretId | 是 |
|
|
51
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY` | 腾讯云 SecretKey | 是 |
|
|
52
|
+
| `KERNEL_WEDATA_REGION` | 地域,如 `ap-guangzhou` | 否(默认 ap-guangzhou) |
|
|
53
|
+
| `WEDATA_WORKSPACE_ID` | WeData 工作空间 ID | 是 |
|
|
54
|
+
| `TENCENTCLOUD_ENDPOINT` | WeData API endpoint | 否(默认 wedata.internal.tencentcloudapi.com) |
|
|
55
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN` | 临时凭证 Token | 否 |
|
|
56
|
+
|
|
57
|
+
## 使用
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
from mlflow.deployments import get_deploy_client
|
|
61
|
+
|
|
62
|
+
client = get_deploy_client("tcdeploy")
|
|
63
|
+
|
|
64
|
+
# 查询可用规格
|
|
65
|
+
specs = client.list_instance_types(spec_type="CPU")
|
|
66
|
+
available = [s for s in specs if s["available"]]
|
|
67
|
+
print(available[0]["spec_name"]) # e.g. "TI.SA5.2XLARGE32.POST"
|
|
68
|
+
|
|
69
|
+
# 创建部署
|
|
70
|
+
client.create_deployment(
|
|
71
|
+
name="my-service",
|
|
72
|
+
model_uri="models:/MyModel/1",
|
|
73
|
+
config={"instance_type": "TI.SA5.2XLARGE32.POST"},
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
# 获取详情
|
|
77
|
+
info = client.get_deployment(name="svc-id-xxxx")
|
|
78
|
+
|
|
79
|
+
# 更新
|
|
80
|
+
client.update_deployment(name="svc-id-xxxx", config={"replicas": 2})
|
|
81
|
+
|
|
82
|
+
# 删除(幂等)
|
|
83
|
+
client.delete_deployment(name="svc-id-xxxx")
|
|
84
|
+
|
|
85
|
+
# 列举服务组
|
|
86
|
+
groups = client.list_deployments()
|
|
87
|
+
# 按服务组过滤
|
|
88
|
+
groups = client.list_deployments(endpoint="grp-001")
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## 测试
|
|
92
|
+
|
|
93
|
+
```bash
|
|
94
|
+
pytest tests/ # 单元测试
|
|
95
|
+
pytest tests/integration/ # 集成测试(需要真实环境变量)
|
|
96
|
+
```
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# mlflow-tcdeploy-plugin
|
|
2
|
+
|
|
3
|
+
MLflow deployment plugin,将腾讯云模型服务对接为 MLflow Deployment 后端。
|
|
4
|
+
|
|
5
|
+
与 `mlflow-tclake-plugin`(Model Registry)配套使用:
|
|
6
|
+
- `tclake`:负责模型存储 → TCLake
|
|
7
|
+
- `tcdeploy`:负责模型部署 → 腾讯云在线推理服务
|
|
8
|
+
|
|
9
|
+
## 安装
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
pip install mlflow-tcdeploy-plugin
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
## 环境变量(来自 wedata-pre-execute)
|
|
16
|
+
|
|
17
|
+
| 变量 | 说明 | 必填 |
|
|
18
|
+
|------|------|------|
|
|
19
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_ID` | 腾讯云 SecretId | 是 |
|
|
20
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY` | 腾讯云 SecretKey | 是 |
|
|
21
|
+
| `KERNEL_WEDATA_REGION` | 地域,如 `ap-guangzhou` | 否(默认 ap-guangzhou) |
|
|
22
|
+
| `WEDATA_WORKSPACE_ID` | WeData 工作空间 ID | 是 |
|
|
23
|
+
| `TENCENTCLOUD_ENDPOINT` | WeData API endpoint | 否(默认 wedata.internal.tencentcloudapi.com) |
|
|
24
|
+
| `KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN` | 临时凭证 Token | 否 |
|
|
25
|
+
|
|
26
|
+
## 使用
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
from mlflow.deployments import get_deploy_client
|
|
30
|
+
|
|
31
|
+
client = get_deploy_client("tcdeploy")
|
|
32
|
+
|
|
33
|
+
# 查询可用规格
|
|
34
|
+
specs = client.list_instance_types(spec_type="CPU")
|
|
35
|
+
available = [s for s in specs if s["available"]]
|
|
36
|
+
print(available[0]["spec_name"]) # e.g. "TI.SA5.2XLARGE32.POST"
|
|
37
|
+
|
|
38
|
+
# 创建部署
|
|
39
|
+
client.create_deployment(
|
|
40
|
+
name="my-service",
|
|
41
|
+
model_uri="models:/MyModel/1",
|
|
42
|
+
config={"instance_type": "TI.SA5.2XLARGE32.POST"},
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# 获取详情
|
|
46
|
+
info = client.get_deployment(name="svc-id-xxxx")
|
|
47
|
+
|
|
48
|
+
# 更新
|
|
49
|
+
client.update_deployment(name="svc-id-xxxx", config={"replicas": 2})
|
|
50
|
+
|
|
51
|
+
# 删除(幂等)
|
|
52
|
+
client.delete_deployment(name="svc-id-xxxx")
|
|
53
|
+
|
|
54
|
+
# 列举服务组
|
|
55
|
+
groups = client.list_deployments()
|
|
56
|
+
# 按服务组过滤
|
|
57
|
+
groups = client.list_deployments(endpoint="grp-001")
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## 测试
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pytest tests/ # 单元测试
|
|
64
|
+
pytest tests/integration/ # 集成测试(需要真实环境变量)
|
|
65
|
+
```
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from tencentcloud.common import credential
|
|
4
|
+
from tencentcloud.common.abstract_client import AbstractClient
|
|
5
|
+
from tencentcloud.common.profile.client_profile import ClientProfile
|
|
6
|
+
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
|
7
|
+
from mlflow.exceptions import MlflowException
|
|
8
|
+
|
|
9
|
+
from mlflow_tcdeploy_plugin.config import TcDeployConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class WedataDeployClient(AbstractClient):
    """Minimal tencentcloud client used only for TC3 request signing.

    Nothing but the service metadata differs from the SDK base class:
    the WeData service name, its internal endpoint, and the API version.
    """

    # NOTE: attribute names follow the tencentcloud SDK convention, not PEP 8.
    _service = 'wedata'
    _endpoint = 'wedata.internal.tencentcloudapi.com'
    _apiVersion = '2025-10-10'
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ModelServiceApiClient:
    """Thin wrapper around the WeData model-service cloud API.

    Every request is signed via WedataDeployClient and carries the
    workspace id from the supplied TcDeployConfig.  SDK-level failures
    are re-raised as MlflowException.
    """

    def __init__(self, config: TcDeployConfig):
        self._config = config
        cred = credential.Credential(
            config.secret_id,
            config.secret_key,
            config.secret_token,
        )
        profile = ClientProfile()
        profile.httpProfile.endpoint = config.endpoint

        self._wedata_client = WedataDeployClient(cred, config.region, profile)

    def _call(self, action: str, body: dict) -> dict:
        """Invoke *action* with *body*; translate SDK errors to MlflowException."""
        try:
            return self._wedata_client.call_json(action, body)
        except TencentCloudSDKException as exc:
            raise MlflowException(f"{exc.code}: {exc.message}") from exc

    def _base_body(self) -> dict:
        """Common fields required in all requests."""
        return {"WorkspaceId": self._config.workspace_id}

    def create_deployment(self, body: dict) -> dict:
        """Create a model service; *body* extends (and may override) base fields."""
        request = dict(self._base_body())
        request.update(body)
        return self._call("CreateMLModelServices", request)

    def update_deployment(self, service_id: str, body: dict) -> None:
        """Update service *service_id* with the fields in *body*."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        request.update(body)
        self._call("UpdateMLModelService", request)

    def delete_deployment(self, service_id: str) -> None:
        """Delete service *service_id*; a vanished service counts as success."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        try:
            self._call("DeleteMLModelService", request)
        except MlflowException as exc:
            # Idempotent delete: swallow only the not-found error.
            if "ResourceNotFound" not in str(exc):
                raise

    def get_deployment(self, service_id: str) -> dict:
        """Fetch details for service *service_id*."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        return self._call("GetMLModelService", request)

    def list_deployments(self, endpoint: Optional[str] = None) -> list:
        """List service groups, optionally filtered by group id *endpoint*."""
        request = dict(self._base_body())
        if endpoint:
            request["ServiceGroupId"] = endpoint
        return self._call("ListMLModelServiceGroups", request).get("Groups", [])

    def predict(
        self,
        service_group_id: str,
        curl_data: str,
        relative_url=None,
        auth_token_value=None,
    ) -> dict:
        """Run a test call against a service group with raw *curl_data*."""
        request = dict(self._base_body())
        request["ServiceGroupId"] = service_group_id
        request["CurlData"] = curl_data
        optional = (
            ("RelativeUrl", relative_url),
            ("AuthTokenValue", auth_token_value),
        )
        for key, value in optional:
            if value is not None:
                request[key] = value
        return self._call("ModelServiceInterfaceCallTest", request)

    def debug_pod_shell(self, service_id: str, pod_name: str) -> dict:
        """Request a debug-shell URL for one pod of a service."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        request["PodName"] = pod_name
        return self._call("CreateModelServicePodUrl", request)

    def restart_pod(self, service_id: str, pod_name: str) -> dict:
        """Rebuild (restart) one pod of a service."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        request["PodName"] = pod_name
        return self._call("RebuildModelServicePod", request)

    def get_pod_logs(
        self,
        service_id: str,
        pod_name=None,
        limit=None,
        start_time=None,
        end_time=None,
        context=None,
    ) -> dict:
        """Fetch service logs; all filter arguments are optional."""
        request = dict(self._base_body())
        request["ServiceId"] = service_id
        optional = (
            ("PodName", pod_name),
            ("Limit", limit),
            ("StartTime", start_time),
            ("EndTime", end_time),
            ("Context", context),
        )
        for key, value in optional:
            if value is not None:
                request[key] = value
        return self._call("ListMLServiceLogs", request)
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from mlflow.deployments import BaseDeploymentClient
|
|
5
|
+
from tencentcloud.common import credential
|
|
6
|
+
from tencentcloud.tione.v20211111 import tione_client
|
|
7
|
+
|
|
8
|
+
from mlflow_tcdeploy_plugin.api_client import ModelServiceApiClient
|
|
9
|
+
from mlflow_tcdeploy_plugin.config import TcDeployConfig
|
|
10
|
+
from mlflow_tcdeploy_plugin.models import build_create_request, snake_to_camel_dict, to_snake_case_dict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TcDeploymentClient(BaseDeploymentClient):
    """
    MLflow deployment client backed by Tencent Cloud model services.

    Usage:
        from mlflow.deployments import get_deploy_client
        client = get_deploy_client("tcdeploy")

    Required environment variables (from wedata-pre-execute):
        KERNEL_WEDATA_CLOUD_SDK_SECRET_ID, KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY,
        KERNEL_WEDATA_REGION, WEDATA_WORKSPACE_ID
    """

    def __init__(self, target_uri: str):
        super().__init__(target_uri)
        # Config resolves from environment variables and fails fast when a
        # mandatory one is missing.
        self._config = TcDeployConfig()
        self._api = ModelServiceApiClient(self._config)

    def create_deployment(
        self,
        name: str,
        model_uri: str,
        flavor=None,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> dict:
        """
        Create a model service.

        Args:
            name: service name
            model_uri: MLflow model URI, e.g. "models:/MyModel/1"
            flavor: unused; accepted for MLflow interface compatibility
            config: deployment config, must include instance_type.
                Use list_instance_types() to see available options.
            endpoint: service group ID (serviceGroupId), optional
        """
        request_body = build_create_request(
            name=name,
            model_uri=model_uri,
            config=config or {},
            endpoint=endpoint,
        )
        return to_snake_case_dict(self._api.create_deployment(request_body))

    def update_deployment(
        self,
        name: str,
        model_uri: Optional[str] = None,
        flavor=None,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> None:
        """
        Update a model service. name is serviceId. Only pass fields to change.
        """
        # Only config fields are forwarded; model_uri/flavor/endpoint exist
        # for MLflow interface compatibility.
        self._api.update_deployment(name, snake_to_camel_dict(config or {}))

    def delete_deployment(
        self,
        name: str,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> None:
        """
        Delete a model service. name is serviceId. Idempotent.
        """
        self._api.delete_deployment(name)

    def get_deployment(self, name: str, endpoint: Optional[str] = None) -> dict:
        """
        Get model service details. name is serviceId. Returns snake_case dict.
        """
        return to_snake_case_dict(self._api.get_deployment(name))

    def list_deployments(self, endpoint: Optional[str] = None) -> List[dict]:
        """
        List service groups. endpoint is an optional serviceGroupId filter.
        Each item carries a "name" field (serviceGroupId) per the MLflow spec.
        """
        def with_name(entry: dict) -> dict:
            converted = to_snake_case_dict(entry)
            # MLflow spec requires each item to have a "name" field.
            converted["name"] = entry.get("ServiceGroupId", "")
            return converted

        return [with_name(entry) for entry in self._api.list_deployments(endpoint=endpoint)]

    def list_instance_types(self, spec_type: Optional[str] = None) -> List[dict]:
        """
        Query available TIone specs. Calls the TIone SDK directly.

        Args:
            spec_type: "CPU" or "GPU"; all specs are returned when omitted

        Returns:
            list of dicts with spec_name, spec_alias, spec_type, available,
            available_region, gpu_type
        """
        cred = credential.Credential(
            self._config.secret_id,
            self._config.secret_key,
            self._config.secret_token,
        )
        tione = tione_client.TioneClient(cred, self._config.region)

        # Access via module attribute so that unit-test patches on
        # tione_client are effective.
        request = tione_client.models.DescribeInferenceSpecsRequest()
        response = tione.DescribeInferenceSpecs(request)

        return [
            {
                "spec_name": info.SpecName,
                "spec_alias": info.SpecAlias,
                "spec_type": info.SpecType,
                "available": info.Available,
                "available_region": list(info.AvailableRegion or []),
                "gpu_type": info.GpuType or "",
            }
            for info in (response.SpecInfos or [])
            if not (spec_type and info.SpecType != spec_type)
        ]

    def predict(
        self,
        deployment_name=None,
        inputs=None,
        endpoint=None,
        config=None,
    ):
        """
        Send inputs to a deployed service group and return the parsed reply.
        Falls back to {"raw": <text>} when the response body is not JSON.
        """
        try:
            payload = json.dumps(inputs)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"inputs must be JSON-serializable, got {type(inputs).__name__}: {exc}") from exc

        response = self._api.predict(
            service_group_id=deployment_name,
            curl_data=payload,
            relative_url=endpoint or "/predict",
            auth_token_value=(config or {}).get("auth_token"),
        )

        body = response.get("CurlResponseRaw", "")
        try:
            return json.loads(body)
        except (json.JSONDecodeError, TypeError):
            return {"raw": body}

    def debug_pod_shell(self, service_id, pod_name):
        """Request a web-shell URL for a service pod; returns snake_case dict."""
        return to_snake_case_dict(self._api.debug_pod_shell(service_id=service_id, pod_name=pod_name))

    def restart_pod(self, service_id, pod_name):
        """Rebuild a service pod; returns the backend rebuild request id."""
        response = self._api.restart_pod(service_id=service_id, pod_name=pod_name)
        return {"request_id": response.get("TiOneRebuildServiceRequestId")}

    def get_pod_logs(self, service_id, pod_name=None, limit=None, start_time=None, end_time=None, context=None):
        """Fetch pod logs; entries are snake_cased, context enables paging."""
        response = self._api.get_pod_logs(
            service_id=service_id,
            pod_name=pod_name,
            limit=limit,
            start_time=start_time,
            end_time=end_time,
            context=context,
        )
        entries = [to_snake_case_dict(item) for item in (response.get("Content") or [])]
        return {"logs": entries, "context": response.get("Context")}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _require_env(key: str, override: Optional[str]) -> str:
    """Resolve *key* from *override* (if truthy) or the environment.

    Raises:
        ValueError: when neither source yields a non-empty value.
    """
    resolved = override if override else os.getenv(key)
    if resolved:
        return resolved
    raise ValueError(
        f"Missing required configuration: '{key}'. "
        f"Set the environment variable or pass it as a parameter."
    )
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class TcDeployConfig:
    """Deployment configuration resolved from constructor args or environment.

    Secret id/key and workspace id are mandatory (ValueError when absent);
    the remaining fields fall back to defaults.  Every field is assigned by
    the explicit __init__ below, hence init=False on all of them.
    """

    secret_id: str = field(init=False)
    secret_key: str = field(init=False)
    secret_token: Optional[str] = field(init=False)
    region: str = field(init=False)
    workspace_id: str = field(init=False)
    endpoint: str = field(init=False)
    sub_account_uin: str = field(init=False)
    owner_uin: str = field(init=False)

    def __init__(
        self,
        secret_id: Optional[str] = None,
        secret_key: Optional[str] = None,
        secret_token: Optional[str] = None,
        region: Optional[str] = None,
        workspace_id: Optional[str] = None,
        endpoint: Optional[str] = None,
    ):
        # Mandatory credentials — raise early when absent.
        self.secret_id = _require_env("KERNEL_WEDATA_CLOUD_SDK_SECRET_ID", secret_id)
        self.secret_key = _require_env("KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY", secret_key)
        # Optional temporary-credential token.
        self.secret_token = secret_token or os.getenv("KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN")
        self.region = region or os.getenv("KERNEL_WEDATA_REGION", "ap-guangzhou")
        self.workspace_id = _require_env("WEDATA_WORKSPACE_ID", workspace_id)
        self.endpoint = endpoint or os.getenv(
            "TENCENTCLOUD_ENDPOINT", "wedata.internal.tencentcloudapi.com"
        )
        # Account identity — best-effort, empty string when unavailable.
        self.sub_account_uin = os.getenv("KERNEL_LOGIN_UIN", "")
        self.owner_uin = os.getenv("QCLOUD_UIN", "")
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
# Defaults applied to every create request; explicit config values win.
# NOTE: key order matters — it feeds {**DEFAULT_CONFIG, ...} merges and so
# determines the field order of the serialized request body.
DEFAULT_CONFIG = {
    "scaleMode": "MANUAL",
    "replicas": 1,
    "chargeType": "POSTPAID_BY_HOUR",
    "authorizationEnable": False,
    "logEnable": False,
}

# Maps snake_case config keys (user-facing) to the camelCase keys the cloud
# API expects.  Keys not listed here pass through unchanged.
_SNAKE_TO_CAMEL = {
    "instance_type": "instanceType",
    "scale_mode": "scaleMode",
    "replicas": "replicas",
    "charge_type": "chargeType",
    "authorization_enable": "authorizationEnable",
    "log_enable": "logEnable",
    "service_group_id": "serviceGroupId",
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _camel_to_snake(name: str) -> str:
|
|
25
|
+
s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
|
|
26
|
+
return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def to_snake_case_dict(data: dict) -> dict:
|
|
30
|
+
"""Recursively convert dict keys from camelCase to snake_case."""
|
|
31
|
+
result = {}
|
|
32
|
+
for k, v in data.items():
|
|
33
|
+
snake_key = _camel_to_snake(k)
|
|
34
|
+
if isinstance(v, dict):
|
|
35
|
+
result[snake_key] = to_snake_case_dict(v)
|
|
36
|
+
else:
|
|
37
|
+
result[snake_key] = v
|
|
38
|
+
return result
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def parse_model_uri(model_uri: str) -> dict:
    """
    Parse an MLflow model URI into mlModelInfo fields.

    Supports:
        - models:/ModelName/Version -> name + version (modelPath is None)
        - runs:/run_id/path         -> modelPath (name/version are None)

    Raises:
        ValueError: for a malformed models: URI or an unknown scheme.
    """
    if model_uri.startswith("runs:/"):
        return {"name": None, "version": None, "modelPath": model_uri}
    if not model_uri.startswith("models:/"):
        raise ValueError(f"Unsupported model_uri scheme: {model_uri}")
    segments = model_uri.removeprefix("models:/").split("/")
    if len(segments) < 2:
        raise ValueError(f"Invalid models URI: {model_uri}")
    return {"name": segments[0], "version": segments[1], "modelPath": None}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def snake_to_camel_dict(config: dict) -> dict:
    """Translate snake_case config keys to the camelCase keys the API expects.

    Keys absent from the _SNAKE_TO_CAMEL mapping pass through untouched.
    """
    translated = {}
    for key, value in config.items():
        translated[_SNAKE_TO_CAMEL.get(key, key)] = value
    return translated
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def build_create_request(
    name: str,
    model_uri: str,
    config: dict,
    endpoint: Optional[str] = None,
) -> dict:
    """Assemble create_deployment arguments into a cloud API request body.

    Note: WorkspaceId is injected by ModelServiceApiClient._base_body(),
    not here.

    Raises:
        ValueError: when config lacks the mandatory 'instance_type'.
    """
    if not config.get("instance_type"):
        raise ValueError(
            "config must include 'instance_type'. "
            "Use list_instance_types() to see available options."
        )

    # Defaults first, then user config (camelCased) overrides them.
    request = dict(DEFAULT_CONFIG)
    request.update(snake_to_camel_dict(config))
    request["serviceName"] = name
    request["mlModelInfo"] = parse_model_uri(model_uri)

    if endpoint:
        request["serviceGroupId"] = endpoint

    return request
|