mlflow-tcdeploy-plugin 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+  Copyright (c) 2025 Tencent
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,96 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlflow-tcdeploy-plugin
3
+ Version: 1.0.0
4
+ Summary: Tencent Cloud deployment plugin for MLflow
5
+ Home-page: https://git.woa.com/WeDataOS/wedata3-monorepo
6
+ Author: Tencent WeData Team
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.10
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Requires-Python: >=3.10
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE.txt
16
+ Requires-Dist: mlflow<3.11.0,>=3.10.0
17
+ Requires-Dist: tencentcloud-sdk-python>=3.0.1478
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest; extra == "dev"
20
+ Requires-Dist: python-dotenv; extra == "dev"
21
+ Dynamic: author
22
+ Dynamic: classifier
23
+ Dynamic: description
24
+ Dynamic: description-content-type
25
+ Dynamic: home-page
26
+ Dynamic: license-file
27
+ Dynamic: provides-extra
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
31
+
32
+ # mlflow-tcdeploy-plugin
33
+
34
+ MLflow deployment plugin,将腾讯云模型服务对接为 MLflow Deployment 后端。
35
+
36
+ 与 `mlflow-tclake-plugin`(Model Registry)配套使用:
37
+ - `tclake`:负责模型存储 → TCLake
38
+ - `tcdeploy`:负责模型部署 → 腾讯云在线推理服务
39
+
40
+ ## 安装
41
+
42
+ ```bash
43
+ pip install mlflow-tcdeploy-plugin
44
+ ```
45
+
46
+ ## 环境变量(来自 wedata-pre-execute)
47
+
48
+ | 变量 | 说明 | 必填 |
49
+ |------|------|------|
50
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_ID` | 腾讯云 SecretId | 是 |
51
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY` | 腾讯云 SecretKey | 是 |
52
+ | `KERNEL_WEDATA_REGION` | 地域,如 `ap-guangzhou` | 否(默认 ap-guangzhou) |
53
+ | `WEDATA_WORKSPACE_ID` | WeData 工作空间 ID | 是 |
54
+ | `TENCENTCLOUD_ENDPOINT` | WeData API endpoint | 否(默认 wedata.internal.tencentcloudapi.com) |
55
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN` | 临时凭证 Token | 否 |
56
+
57
+ ## 使用
58
+
59
+ ```python
60
+ from mlflow.deployments import get_deploy_client
61
+
62
+ client = get_deploy_client("tcdeploy")
63
+
64
+ # 查询可用规格
65
+ specs = client.list_instance_types(spec_type="CPU")
66
+ available = [s for s in specs if s["available"]]
67
+ print(available[0]["spec_name"]) # e.g. "TI.SA5.2XLARGE32.POST"
68
+
69
+ # 创建部署
70
+ client.create_deployment(
71
+ name="my-service",
72
+ model_uri="models:/MyModel/1",
73
+ config={"instance_type": "TI.SA5.2XLARGE32.POST"},
74
+ )
75
+
76
+ # 获取详情
77
+ info = client.get_deployment(name="svc-id-xxxx")
78
+
79
+ # 更新
80
+ client.update_deployment(name="svc-id-xxxx", config={"replicas": 2})
81
+
82
+ # 删除(幂等)
83
+ client.delete_deployment(name="svc-id-xxxx")
84
+
85
+ # 列举服务组
86
+ groups = client.list_deployments()
87
+ # 按服务组过滤
88
+ groups = client.list_deployments(endpoint="grp-001")
89
+ ```
90
+
91
+ ## 测试
92
+
93
+ ```bash
94
+ pytest tests/ # 单元测试
95
+ pytest tests/integration/ # 集成测试(需要真实环境变量)
96
+ ```
@@ -0,0 +1,65 @@
1
+ # mlflow-tcdeploy-plugin
2
+
3
+ MLflow deployment plugin,将腾讯云模型服务对接为 MLflow Deployment 后端。
4
+
5
+ 与 `mlflow-tclake-plugin`(Model Registry)配套使用:
6
+ - `tclake`:负责模型存储 → TCLake
7
+ - `tcdeploy`:负责模型部署 → 腾讯云在线推理服务
8
+
9
+ ## 安装
10
+
11
+ ```bash
12
+ pip install mlflow-tcdeploy-plugin
13
+ ```
14
+
15
+ ## 环境变量(来自 wedata-pre-execute)
16
+
17
+ | 变量 | 说明 | 必填 |
18
+ |------|------|------|
19
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_ID` | 腾讯云 SecretId | 是 |
20
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY` | 腾讯云 SecretKey | 是 |
21
+ | `KERNEL_WEDATA_REGION` | 地域,如 `ap-guangzhou` | 否(默认 ap-guangzhou) |
22
+ | `WEDATA_WORKSPACE_ID` | WeData 工作空间 ID | 是 |
23
+ | `TENCENTCLOUD_ENDPOINT` | WeData API endpoint | 否(默认 wedata.internal.tencentcloudapi.com) |
24
+ | `KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN` | 临时凭证 Token | 否 |
25
+
26
+ ## 使用
27
+
28
+ ```python
29
+ from mlflow.deployments import get_deploy_client
30
+
31
+ client = get_deploy_client("tcdeploy")
32
+
33
+ # 查询可用规格
34
+ specs = client.list_instance_types(spec_type="CPU")
35
+ available = [s for s in specs if s["available"]]
36
+ print(available[0]["spec_name"]) # e.g. "TI.SA5.2XLARGE32.POST"
37
+
38
+ # 创建部署
39
+ client.create_deployment(
40
+ name="my-service",
41
+ model_uri="models:/MyModel/1",
42
+ config={"instance_type": "TI.SA5.2XLARGE32.POST"},
43
+ )
44
+
45
+ # 获取详情
46
+ info = client.get_deployment(name="svc-id-xxxx")
47
+
48
+ # 更新
49
+ client.update_deployment(name="svc-id-xxxx", config={"replicas": 2})
50
+
51
+ # 删除(幂等)
52
+ client.delete_deployment(name="svc-id-xxxx")
53
+
54
+ # 列举服务组
55
+ groups = client.list_deployments()
56
+ # 按服务组过滤
57
+ groups = client.list_deployments(endpoint="grp-001")
58
+ ```
59
+
60
+ ## 测试
61
+
62
+ ```bash
63
+ pytest tests/ # 单元测试
64
+ pytest tests/integration/ # 集成测试(需要真实环境变量)
65
+ ```
@@ -0,0 +1,3 @@
1
+ # -*-coding:utf-8-*-
2
+
3
+ __version__ = "1.0.0"
@@ -0,0 +1,129 @@
1
+ from typing import Optional
2
+
3
+ from tencentcloud.common import credential
4
+ from tencentcloud.common.abstract_client import AbstractClient
5
+ from tencentcloud.common.profile.client_profile import ClientProfile
6
+ from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
7
+ from mlflow.exceptions import MlflowException
8
+
9
+ from mlflow_tcdeploy_plugin.config import TcDeployConfig
10
+
11
+
12
+ class WedataDeployClient(AbstractClient):
13
+ """WeData deployment API client, inherits AbstractClient for TC3 signing."""
14
+ _apiVersion = '2025-10-10'
15
+ _endpoint = 'wedata.internal.tencentcloudapi.com'
16
+ _service = 'wedata'
17
+
18
+
19
class ModelServiceApiClient:
    """Thin wrapper around the WeData model-service cloud API.

    Builds request bodies (always injecting the workspace id), sends them
    through WedataDeployClient, and converts SDK failures into MlflowException.
    """

    def __init__(self, config: TcDeployConfig):
        self._config = config
        creds = credential.Credential(
            config.secret_id,
            config.secret_key,
            config.secret_token,
        )
        profile = ClientProfile()
        profile.httpProfile.endpoint = config.endpoint
        self._wedata_client = WedataDeployClient(creds, config.region, profile)

    def _call(self, action: str, body: dict) -> dict:
        """Invoke *action* with *body*; re-raise SDK errors as MlflowException."""
        try:
            return self._wedata_client.call_json(action, body)
        except TencentCloudSDKException as e:
            raise MlflowException(f"{e.code}: {e.message}") from e

    def _base_body(self) -> dict:
        """Common fields required in all requests."""
        return {"WorkspaceId": self._config.workspace_id}

    def create_deployment(self, body: dict) -> dict:
        """Create a model service; *body* is merged over the base fields."""
        return self._call("CreateMLModelServices", {**self._base_body(), **body})

    def update_deployment(self, service_id: str, body: dict) -> None:
        """Update the service identified by *service_id* with fields in *body*."""
        request = {**self._base_body(), "ServiceId": service_id, **body}
        self._call("UpdateMLModelService", request)

    def delete_deployment(self, service_id: str) -> None:
        """Delete a service; a missing service is treated as success (idempotent)."""
        request = {**self._base_body(), "ServiceId": service_id}
        try:
            self._call("DeleteMLModelService", request)
        except MlflowException as e:
            if "ResourceNotFound" in str(e):
                return  # idempotent
            raise

    def get_deployment(self, service_id: str) -> dict:
        """Fetch details for one service."""
        request = {**self._base_body(), "ServiceId": service_id}
        return self._call("GetMLModelService", request)

    def list_deployments(self, endpoint: Optional[str] = None) -> list:
        """List service groups, optionally filtered by group id (*endpoint*)."""
        request = self._base_body()
        if endpoint:
            request["ServiceGroupId"] = endpoint
        return self._call("ListMLModelServiceGroups", request).get("Groups", [])

    def predict(
        self,
        service_group_id: str,
        curl_data: str,
        relative_url=None,
        auth_token_value=None,
    ) -> dict:
        """Run an interface-call test against a service group."""
        request = {
            **self._base_body(),
            "ServiceGroupId": service_group_id,
            "CurlData": curl_data,
        }
        if relative_url is not None:
            request["RelativeUrl"] = relative_url
        if auth_token_value is not None:
            request["AuthTokenValue"] = auth_token_value
        return self._call("ModelServiceInterfaceCallTest", request)

    def debug_pod_shell(self, service_id: str, pod_name: str) -> dict:
        """Create a debug web-shell URL for one pod of the given service."""
        request = {
            **self._base_body(),
            "ServiceId": service_id,
            "PodName": pod_name,
        }
        return self._call("CreateModelServicePodUrl", request)

    def restart_pod(self, service_id: str, pod_name: str) -> dict:
        """Rebuild (restart) a single pod of the given service."""
        request = {
            **self._base_body(),
            "ServiceId": service_id,
            "PodName": pod_name,
        }
        return self._call("RebuildModelServicePod", request)

    def get_pod_logs(
        self,
        service_id: str,
        pod_name=None,
        limit=None,
        start_time=None,
        end_time=None,
        context=None,
    ) -> dict:
        """Fetch service logs; optional filters are included only when set."""
        request = {**self._base_body(), "ServiceId": service_id}
        optional_fields = {
            "PodName": pod_name,
            "Limit": limit,
            "StartTime": start_time,
            "EndTime": end_time,
            "Context": context,
        }
        for field_name, field_value in optional_fields.items():
            if field_value is not None:
                request[field_name] = field_value
        return self._call("ListMLServiceLogs", request)
@@ -0,0 +1,181 @@
1
+ import json
2
+ from typing import List, Optional
3
+
4
+ from mlflow.deployments import BaseDeploymentClient
5
+ from tencentcloud.common import credential
6
+ from tencentcloud.tione.v20211111 import tione_client
7
+
8
+ from mlflow_tcdeploy_plugin.api_client import ModelServiceApiClient
9
+ from mlflow_tcdeploy_plugin.config import TcDeployConfig
10
+ from mlflow_tcdeploy_plugin.models import build_create_request, snake_to_camel_dict, to_snake_case_dict
11
+
12
+
13
class TcDeploymentClient(BaseDeploymentClient):
    """
    MLflow deployment client that exposes Tencent Cloud model services as an
    MLflow deployment backend.

    Usage:
        from mlflow.deployments import get_deploy_client
        client = get_deploy_client("tcdeploy")

    Required environment variables (from wedata-pre-execute):
        KERNEL_WEDATA_CLOUD_SDK_SECRET_ID, KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY,
        KERNEL_WEDATA_REGION, WEDATA_WORKSPACE_ID
    """

    def __init__(self, target_uri: str):
        super().__init__(target_uri)
        # Config resolution raises early when mandatory env vars are missing.
        self._config = TcDeployConfig()
        self._api = ModelServiceApiClient(self._config)

    def create_deployment(
        self,
        name: str,
        model_uri: str,
        flavor=None,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> dict:
        """
        Create a model service.

        Args:
            name: service name
            model_uri: MLflow model URI, e.g. "models:/MyModel/1"
            config: deployment config, must include instance_type.
                Use list_instance_types() to see available options.
            endpoint: service group ID (serviceGroupId), optional
        """
        request_body = build_create_request(
            name=name,
            model_uri=model_uri,
            config=config or {},
            endpoint=endpoint,
        )
        response = self._api.create_deployment(request_body)
        return to_snake_case_dict(response)

    def update_deployment(
        self,
        name: str,
        model_uri: Optional[str] = None,
        flavor=None,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> None:
        """
        Update a model service. name is serviceId. Only pass fields to change.
        """
        self._api.update_deployment(name, snake_to_camel_dict(config or {}))

    def delete_deployment(
        self,
        name: str,
        config: Optional[dict] = None,
        endpoint: Optional[str] = None,
    ) -> None:
        """
        Delete a model service. name is serviceId. Idempotent.
        """
        self._api.delete_deployment(name)

    def get_deployment(self, name: str, endpoint: Optional[str] = None) -> dict:
        """
        Get model service details. name is serviceId. Returns snake_case dict.
        """
        return to_snake_case_dict(self._api.get_deployment(name))

    def list_deployments(self, endpoint: Optional[str] = None) -> List[dict]:
        """
        List service groups. endpoint is an optional serviceGroupId filter.
        Each returned item carries a "name" field (the serviceGroupId) as
        required by the MLflow deployments spec.
        """
        entries = []
        for group in self._api.list_deployments(endpoint=endpoint):
            entry = to_snake_case_dict(group)
            # MLflow spec requires each item to have "name" field
            entry["name"] = group.get("ServiceGroupId", "")
            entries.append(entry)
        return entries

    def list_instance_types(self, spec_type: Optional[str] = None) -> List[dict]:
        """
        Query the available TIone spec list by calling the TIone SDK directly,
        bypassing the Java layer.

        Args:
            spec_type: "CPU" or "GPU"; all specs are returned when omitted

        Returns:
            list of dict with spec_name, spec_alias, spec_type, available,
            available_region, gpu_type
        """
        creds = credential.Credential(
            self._config.secret_id,
            self._config.secret_key,
            self._config.secret_token,
        )
        ti_client = tione_client.TioneClient(creds, self._config.region)

        # Access via module attribute so that unit-test patches on tione_client are effective
        request = tione_client.models.DescribeInferenceSpecsRequest()
        response = ti_client.DescribeInferenceSpecs(request)

        matched = []
        for info in response.SpecInfos or []:
            if spec_type and info.SpecType != spec_type:
                continue
            matched.append({
                "spec_name": info.SpecName,
                "spec_alias": info.SpecAlias,
                "spec_type": info.SpecType,
                "available": info.Available,
                "available_region": list(info.AvailableRegion or []),
                "gpu_type": info.GpuType or "",
            })
        return matched

    def predict(
        self,
        deployment_name=None,
        inputs=None,
        endpoint=None,
        config=None,
    ):
        """Send a test inference call to a service group and decode the reply."""
        try:
            payload = json.dumps(inputs)
        except (TypeError, ValueError) as e:
            raise ValueError(f"inputs must be JSON-serializable, got {type(inputs).__name__}: {e}") from e

        response = self._api.predict(
            service_group_id=deployment_name,
            curl_data=payload,
            relative_url=endpoint or "/predict",
            auth_token_value=(config or {}).get("auth_token"),
        )

        body_text = response.get("CurlResponseRaw", "")
        try:
            return json.loads(body_text)
        except (json.JSONDecodeError, TypeError):
            # Non-JSON payloads are returned verbatim under "raw".
            return {"raw": body_text}

    def debug_pod_shell(self, service_id, pod_name):
        """Create a debug web-shell URL for one pod. Returns snake_case dict."""
        response = self._api.debug_pod_shell(service_id=service_id, pod_name=pod_name)
        return to_snake_case_dict(response)

    def restart_pod(self, service_id, pod_name):
        """Rebuild (restart) one pod; returns the backend request id."""
        response = self._api.restart_pod(service_id=service_id, pod_name=pod_name)
        return {"request_id": response.get("TiOneRebuildServiceRequestId")}

    def get_pod_logs(self, service_id, pod_name=None, limit=None, start_time=None, end_time=None, context=None):
        """Fetch pod logs with optional filters; returns {"logs": [...], "context": ...}."""
        response = self._api.get_pod_logs(
            service_id=service_id,
            pod_name=pod_name,
            limit=limit,
            start_time=start_time,
            end_time=end_time,
            context=context,
        )
        converted_logs = [to_snake_case_dict(entry) for entry in (response.get("Content") or [])]
        return {"logs": converted_logs, "context": response.get("Context")}
@@ -0,0 +1,45 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+
6
+ def _require_env(key: str, override: Optional[str]) -> str:
7
+ value = override or os.getenv(key)
8
+ if not value:
9
+ raise ValueError(
10
+ f"Missing required configuration: '{key}'. "
11
+ f"Set the environment variable or pass it as a parameter."
12
+ )
13
+ return value
14
+
15
+
16
@dataclass
class TcDeployConfig:
    """Deployment configuration resolved from explicit args or environment.

    Required (arg or env var): secret_id, secret_key, workspace_id.
    Optional with fallbacks: secret_token, region, endpoint, and the UIN
    fields, which are read from the environment only.
    """

    # All fields are populated by the hand-written __init__ below,
    # not by the dataclass-generated constructor.
    secret_id: str = field(init=False)
    secret_key: str = field(init=False)
    secret_token: Optional[str] = field(init=False)
    region: str = field(init=False)
    workspace_id: str = field(init=False)
    endpoint: str = field(init=False)
    sub_account_uin: str = field(init=False)
    owner_uin: str = field(init=False)

    def __init__(
        self,
        secret_id: Optional[str] = None,
        secret_key: Optional[str] = None,
        secret_token: Optional[str] = None,
        region: Optional[str] = None,
        workspace_id: Optional[str] = None,
        endpoint: Optional[str] = None,
    ):
        # Mandatory values: explicit argument wins, else environment, else error.
        self.secret_id = _require_env("KERNEL_WEDATA_CLOUD_SDK_SECRET_ID", secret_id)
        self.secret_key = _require_env("KERNEL_WEDATA_CLOUD_SDK_SECRET_KEY", secret_key)
        self.workspace_id = _require_env("WEDATA_WORKSPACE_ID", workspace_id)
        # Optional values with sensible defaults.
        self.secret_token = secret_token or os.getenv("KERNEL_WEDATA_CLOUD_SDK_SECRET_TOKEN")
        self.region = region or os.getenv("KERNEL_WEDATA_REGION", "ap-guangzhou")
        default_endpoint = "wedata.internal.tencentcloudapi.com"
        self.endpoint = endpoint or os.getenv("TENCENTCLOUD_ENDPOINT", default_endpoint)
        # Caller identity propagated by the execution environment (may be empty).
        self.sub_account_uin = os.getenv("KERNEL_LOGIN_UIN", "")
        self.owner_uin = os.getenv("QCLOUD_UIN", "")
@@ -0,0 +1,89 @@
1
+ import re
2
+ from typing import Optional
3
+
4
+ DEFAULT_CONFIG = {
5
+ "scaleMode": "MANUAL",
6
+ "replicas": 1,
7
+ "chargeType": "POSTPAID_BY_HOUR",
8
+ "authorizationEnable": False,
9
+ "logEnable": False,
10
+ }
11
+
12
+ # snake_case key → camelCase key mapping (config uses snake_case, request uses camelCase)
13
+ _SNAKE_TO_CAMEL = {
14
+ "instance_type": "instanceType",
15
+ "scale_mode": "scaleMode",
16
+ "replicas": "replicas",
17
+ "charge_type": "chargeType",
18
+ "authorization_enable": "authorizationEnable",
19
+ "log_enable": "logEnable",
20
+ "service_group_id": "serviceGroupId",
21
+ }
22
+
23
+
24
+ def _camel_to_snake(name: str) -> str:
25
+ s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
26
+ return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
27
+
28
+
29
+ def to_snake_case_dict(data: dict) -> dict:
30
+ """Recursively convert dict keys from camelCase to snake_case."""
31
+ result = {}
32
+ for k, v in data.items():
33
+ snake_key = _camel_to_snake(k)
34
+ if isinstance(v, dict):
35
+ result[snake_key] = to_snake_case_dict(v)
36
+ else:
37
+ result[snake_key] = v
38
+ return result
39
+
40
+
41
def parse_model_uri(model_uri: str) -> dict:
    """
    Split an MLflow model URI into mlModelInfo fields.

    Supported schemes:
      - ``models:/Name/Version`` -> registered-model name + version
      - ``runs:/run_id/path``    -> passed through as modelPath

    Raises:
        ValueError: on a malformed ``models:/`` URI or an unknown scheme.
    """
    models_prefix = "models:/"
    if model_uri.startswith(models_prefix):
        segments = model_uri[len(models_prefix):].split("/")
        if len(segments) < 2:
            raise ValueError(f"Invalid models URI: {model_uri}")
        model_name, model_version = segments[0], segments[1]
        return {"name": model_name, "version": model_version, "modelPath": None}
    if model_uri.startswith("runs:/"):
        return {"name": None, "version": None, "modelPath": model_uri}
    raise ValueError(f"Unsupported model_uri scheme: {model_uri}")
57
+
58
+
59
def snake_to_camel_dict(config: dict) -> dict:
    """Map snake_case config keys onto their camelCase request-body names.

    Keys without an entry in _SNAKE_TO_CAMEL are passed through unchanged.
    """
    converted = {}
    for key, value in config.items():
        converted[_SNAKE_TO_CAMEL.get(key, key)] = value
    return converted
62
+
63
+
64
def build_create_request(
    name: str,
    model_uri: str,
    config: dict,
    endpoint: Optional[str] = None,
) -> dict:
    """Assemble create_deployment parameters into a cloud API request body.

    Note: WorkspaceId is injected by ModelServiceApiClient._base_body(), not here.

    Raises:
        ValueError: when config lacks 'instance_type' or model_uri is invalid.
    """
    if not config.get("instance_type"):
        raise ValueError(
            "config must include 'instance_type'. "
            "Use list_instance_types() to see available options."
        )

    # Defaults first, then user config (converted to camelCase) on top.
    body = dict(DEFAULT_CONFIG)
    body.update(snake_to_camel_dict(config))
    body["serviceName"] = name
    body["mlModelInfo"] = parse_model_uri(model_uri)

    if endpoint:
        body["serviceGroupId"] = endpoint

    return body