mlflow-tclake-plugin 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlflow-tclake-plugin might be problematic. Click here for more details.
- mlflow_tclake_plugin/__init__.py +0 -0
- mlflow_tclake_plugin/tclake_store.py +743 -0
- mlflow_tclake_plugin-0.0.1.dist-info/METADATA +8 -0
- mlflow_tclake_plugin-0.0.1.dist-info/RECORD +7 -0
- mlflow_tclake_plugin-0.0.1.dist-info/WHEEL +5 -0
- mlflow_tclake_plugin-0.0.1.dist-info/entry_points.txt +2 -0
- mlflow_tclake_plugin-0.0.1.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,743 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
import urllib.parse
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
|
|
9
|
+
from cachetools import TTLCache
|
|
10
|
+
from tencentcloud.common import credential
|
|
11
|
+
from tencentcloud.common.common_client import CommonClient
|
|
12
|
+
from tencentcloud.common.profile.client_profile import ClientProfile
|
|
13
|
+
|
|
14
|
+
from mlflow.entities.model_registry import (
|
|
15
|
+
RegisteredModel,
|
|
16
|
+
ModelVersion,
|
|
17
|
+
ModelVersionTag,
|
|
18
|
+
)
|
|
19
|
+
from mlflow.exceptions import MlflowException
|
|
20
|
+
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS
|
|
21
|
+
from mlflow.store.model_registry import (
|
|
22
|
+
SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
|
|
23
|
+
SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
|
|
24
|
+
)
|
|
25
|
+
from mlflow.store.model_registry.abstract_store import AbstractStore
|
|
26
|
+
from mlflow.utils.annotations import experimental
|
|
27
|
+
from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils
|
|
28
|
+
from mlflow.store.entities.paged_list import PagedList
|
|
29
|
+
from mlflow.models.model import get_model_info
|
|
30
|
+
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
|
31
|
+
|
|
32
|
+
def to_string(obj):
    """Render ``obj`` as a human-readable string, e.g. for log output.

    Args:
        obj: Any value.

    Returns:
        ``"None"`` for ``None``, the value itself for strings, and an
        indented JSON document for everything JSON can encode. Values that
        JSON cannot encode (including dicts holding such values) fall back
        to ``str(obj)``.
    """
    if obj is None:
        return "None"
    if isinstance(obj, str):
        return obj
    try:
        return json.dumps(obj, indent=2)
    except (TypeError, ValueError, RecursionError):
        # Not JSON-serialisable, circular, or nested too deeply: degrade to
        # str() rather than failing while formatting a message.
        return str(obj)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Debug switch: when TENCENTCLOUD_DEBUG is set to a non-empty value, log_msg() prints its messages.
tencent_cloud_debug = os.getenv("TENCENTCLOUD_DEBUG", None)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def log_msg(msg):
    """Print ``msg`` when the module-level debug flag is truthy; otherwise do nothing."""
    if not tencent_cloud_debug:
        return
    print(msg)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_create_time_from_audit(audit):
    """Return the ``CreatedTime`` entry of ``audit`` as epoch milliseconds.

    The value is parsed with the format ``%Y-%m-%d %H:%M:%S`` into a naive
    ``datetime``, so ``timestamp()`` interprets it in the local timezone of
    the running process.
    """
    created_at = datetime.strptime(audit["CreatedTime"], "%Y-%m-%d %H:%M:%S")
    epoch_seconds = created_at.timestamp()
    return int(epoch_seconds * 1000)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _get_last_updated_time_from_audit(audit):
    """Return the ``LastModifiedTime`` entry of ``audit`` as epoch milliseconds.

    When that entry is ``None`` or an empty string, the creation time of the
    same audit record is returned instead.
    """
    modified = audit["LastModifiedTime"]
    has_modified = modified is not None and len(modified) != 0
    if has_modified:
        parsed = datetime.strptime(modified, "%Y-%m-%d %H:%M:%S")
        return int(parsed.timestamp() * 1000)
    # Never modified: fall back to the creation time.
    return _get_create_time_from_audit(audit)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_description(resp):
    """Return ``resp["Comment"]``, or ``None`` when ``resp`` is ``None`` or has no such key."""
    has_comment = resp is not None and "Comment" in resp
    return resp["Comment"] if has_comment else None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Prefix put in front of an MLflow tag key when the tag is stored as a property;
# properties whose key starts with it are read back as tags.
TCLAKE_MLFLOW_TAG_PREFIX = "tclake.tag."
# Property keys under which individual pieces of MLflow / WeData metadata are stored.
TCLAKE_MLFLOW_RUN_ID_KEY = "tclake.mlflow.run_id"
TCLAKE_MLFLOW_RUN_LINK_KEY = "tclake.mlflow.run_link"
TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY = "tclake.mlflow.deployment_job_id"
TCLAKE_WEDATA_WORKSPACE_ID_KEY = "tclake.wedata.workspace_id"
TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY = "tclake.mlflow.deployment_status"
TCLAKE_MLFLOW_MODEL_ID_KEY = "tclake.mlflow.model_id"
TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY = "tclake.mlflow.model_signature"

# Deployment-status value written to the properties of a newly created model version.
UN_DEPLOYMENT = "UnDeployment"
|
|
83
|
+
|
|
84
|
+
def _set_kv_to_properties(key, value, properties=None):
    """Append a ``{"Key": key, "Value": value}`` entry to ``properties``.

    A new list is created when ``properties`` is ``None``. Nothing is
    appended when ``value`` is ``None``. The list is modified in place and
    also returned.
    """
    result = [] if properties is None else properties
    if value is None:
        return result
    result.append({"Key": key, "Value": value})
    return result
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_kv_from_properties(properties, key):
    """Return the ``Value`` of the first entry whose ``Key`` equals ``key``.

    Returns ``None`` when ``properties`` is ``None`` or no entry matches.
    """
    if properties is None:
        return None
    # Lazy generator: stops scanning at the first matching entry.
    matching_values = (entry["Value"] for entry in properties if entry["Key"] == key)
    return next(matching_values, None)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _set_run_id_to_properties(run_id, properties=None):
    """Store ``run_id`` in ``properties`` under the run-id key and return the list."""
    updated = _set_kv_to_properties(TCLAKE_MLFLOW_RUN_ID_KEY, run_id, properties)
    return updated
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _get_run_id_from_properties(properties):
    """Return the value stored under the run-id key in ``properties``, or ``None``."""
    run_id = _get_kv_from_properties(properties, TCLAKE_MLFLOW_RUN_ID_KEY)
    return run_id
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _set_run_link_to_properties(run_link, properties):
    """Store ``run_link`` in ``properties`` under the run-link key and return the list."""
    updated = _set_kv_to_properties(TCLAKE_MLFLOW_RUN_LINK_KEY, run_link, properties)
    return updated
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _get_run_link_from_properties(properties):
    """Return the value stored under the run-link key in ``properties``, or ``None``."""
    run_link = _get_kv_from_properties(properties, TCLAKE_MLFLOW_RUN_LINK_KEY)
    return run_link
|
|
115
|
+
|
|
116
|
+
def _get_model_id_from_properties(properties):
    """Return the value stored under the model-id key in ``properties``, or ``None``."""
    model_id = _get_kv_from_properties(properties, TCLAKE_MLFLOW_MODEL_ID_KEY)
    return model_id
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _add_tag_to_properties(tag, properties=None):
    """Store ``tag`` in ``properties`` under its prefixed key.

    A falsy ``tag`` or a tag whose ``value`` is ``None`` is ignored, in which
    case ``properties`` is returned exactly as it was passed in.
    """
    if not tag or tag.value is None:
        return properties
    prefixed_key = TCLAKE_MLFLOW_TAG_PREFIX + tag.key
    return _set_kv_to_properties(prefixed_key, tag.value, properties)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _add_tags_to_properties(tags, properties=None):
    """Store every tag of ``tags`` in ``properties`` and return the result.

    A falsy ``tags`` value leaves ``properties`` untouched.
    """
    for current_tag in tags or ():
        properties = _add_tag_to_properties(current_tag, properties)
    return properties
|
|
133
|
+
|
|
134
|
+
def _add_deployment_job_id_to_properties(deployment_job_id, properties=None):
    """Store ``deployment_job_id`` under the deployment-job-id key and return the list."""
    return _set_kv_to_properties(
        TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY, deployment_job_id, properties
    )
|
|
137
|
+
|
|
138
|
+
def _add_workspace_id_to_properties(workspace_id, properties=None):
    """Store ``workspace_id`` under the workspace-id key.

    When ``workspace_id`` is ``None`` the ``properties`` argument is returned
    unchanged (it is not replaced by a new list).
    """
    if workspace_id is None:
        return properties
    return _set_kv_to_properties(TCLAKE_WEDATA_WORKSPACE_ID_KEY, workspace_id, properties)
|
|
142
|
+
|
|
143
|
+
def _add_deployment_status_to_properties(status, properties=None):
    """Store ``status`` under the deployment-status key.

    When ``status`` is ``None`` the ``properties`` argument is returned
    unchanged (it is not replaced by a new list).
    """
    if status is None:
        return properties
    return _set_kv_to_properties(TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY, status, properties)
|
|
147
|
+
|
|
148
|
+
def _add_model_id_to_properties(model_id, properties=None):
    """Store ``model_id`` under the model-id key.

    When ``model_id`` is ``None`` the ``properties`` argument is returned
    unchanged (it is not replaced by a new list).
    """
    if model_id is None:
        return properties
    return _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_ID_KEY, model_id, properties)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _add_model_signature_to_properties(source, properties):
    """Store the signature of the model at ``source`` as JSON in ``properties``.

    The model info is loaded with ``get_model_info(source)``; its signature is
    serialised via ``to_dict()`` and ``json.dumps`` and stored under the
    model-signature key.

    Raises:
        MlflowException: if the model info carries no signature.
    """
    log_msg("model source is {}".format(source))
    signature = get_model_info(source).signature
    log_msg("model {} signature is {}".format(source, signature))
    if signature is None:
        raise MlflowException(f"Registered model signature is not found in source artifact location '{source}'")
    sig_json = json.dumps(signature.to_dict())
    log_msg("model {} signature json is {}".format(source, sig_json))
    return _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY, sig_json, properties)
|
|
165
|
+
|
|
166
|
+
def _get_tags_from_properties(properties):
    """Build ``ModelVersionTag`` objects from the tag entries in ``properties``.

    Only entries whose key starts with the tag prefix are used; the prefix is
    stripped from the tag key. Returns ``None`` when ``properties`` is ``None``.
    """
    if properties is None:
        return None
    prefix_length = len(TCLAKE_MLFLOW_TAG_PREFIX)
    return [
        ModelVersionTag(entry["Key"][prefix_length:], entry["Value"])
        for entry in properties
        if entry["Key"].startswith(TCLAKE_MLFLOW_TAG_PREFIX)
    ]
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _append_tag_key_prefix(key):
    """Return ``key`` with the tag prefix placed in front of it.

    Despite the function name, the prefix is prepended, not appended.
    """
    prefixed_key = TCLAKE_MLFLOW_TAG_PREFIX + key
    return prefixed_key
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _get_model_version(version):
    """Return ``version`` converted to a string, as used for ``ModelVersion.version`` here."""
    version_text = str(version)
    return version_text
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _set_model_version(version):
    """Return ``version`` converted to an integer, as sent in ``ModelVersion`` request fields."""
    version_number = int(version)
    return version_number
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _make_model(resp):
    """Convert a model record from an API response into an MLflow ``RegisteredModel``.

    Reads the ``Name``, ``Properties``, ``Audit`` and (optional) ``Comment``
    entries of ``resp``.
    """
    model_properties = resp["Properties"]
    audit_info = resp["Audit"]
    model_fields = {
        "name": resp["Name"],
        "creation_timestamp": _get_create_time_from_audit(audit_info),
        "last_updated_timestamp": _get_last_updated_time_from_audit(audit_info),
        "description": _get_description(resp),
        "tags": _get_tags_from_properties(model_properties),
    }
    return RegisteredModel(**model_fields)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _get_model_version_name(entity):
    """Return the ``catalog.schema.model`` name built from the entries of ``entity``."""
    name_parts = (entity["CatalogName"], entity["SchemaName"], entity["ModelName"])
    return "{}.{}.{}".format(*name_parts)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _make_model_version(entity, name):
    """Convert a model-version record from an API response into an MLflow ``ModelVersion``.

    ``name`` is used as the model name. Run id, run link, model id and tags
    are read from the record's ``Properties``; the status is always reported
    as ``"READY"``.
    """
    version_properties = entity["Properties"]
    audit_info = entity["Audit"]
    version_aliases = entity["Aliases"]
    version_fields = {
        "name": name,
        "version": _get_model_version(entity["Version"]),
        "creation_timestamp": _get_create_time_from_audit(audit_info),
        "last_updated_timestamp": _get_last_updated_time_from_audit(audit_info),
        "description": _get_description(entity),
        "source": entity["Uri"],
        "run_id": _get_run_id_from_properties(version_properties),
        "tags": _get_tags_from_properties(version_properties),
        "run_link": _get_run_link_from_properties(version_properties),
        "aliases": version_aliases,
        "model_id": _get_model_id_from_properties(version_properties),
        "status": "READY",
    }
    return ModelVersion(**version_fields)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_tencent_cloud_headers():
    """Return the JSON-decoded value of ``TENCENTCLOUD_HEADER_JSON``, or ``None`` if unset."""
    raw_headers = os.getenv("TENCENTCLOUD_HEADER_JSON", None)
    return None if raw_headers is None else json.loads(raw_headers)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_tencent_cloud_client_profile():
    """Return a ``ClientProfile`` whose HTTP endpoint is ``TENCENTCLOUD_ENDPOINT``.

    Returns ``None`` when the environment variable is not set.
    """
    custom_endpoint = os.getenv("TENCENTCLOUD_ENDPOINT", None)
    if custom_endpoint is not None:
        profile = ClientProfile()
        profile.httpProfile.endpoint = custom_endpoint
        return profile
    return None
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _parse_page_token(page_token):
    """Decode a base64-encoded JSON page token.

    Args:
        page_token: The token as ``str`` or ``bytes``.

    Returns:
        The JSON value contained in the token.

    Raises:
        MlflowException: with ``INVALID_PARAMETER_VALUE`` when the token is
            not valid base64 or does not contain valid JSON.
    """
    try:
        return json.loads(base64.b64decode(page_token))
    except (TypeError, ValueError) as e:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses; TypeError covers tokens that are not str/bytes.
        raise MlflowException(
            "Invalid page token, could not base64-decode and parse it",
            INVALID_PARAMETER_VALUE,
        ) from e
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _create_page_token(offset, search_id):
    """Return a page token: base64 (as ``bytes``) of a JSON object holding ``offset`` and ``search_id``."""
    payload = {"offset": offset, "search_id": search_id}
    serialized = json.dumps(payload).encode("utf-8")
    return base64.b64encode(serialized)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@experimental
class TCLakeStore(AbstractStore):
    """
    MLflow model-registry store backed by the Tencent Cloud TCLake catalog.

    Every registry operation is translated into an action of the
    ``tccatalog`` API (version ``2024-10-24``) and sent through the Tencent
    Cloud SDK ``CommonClient``. Model names have the form
    ``catalog.schema.model``; missing leading parts are filled in from the
    configured default catalog and schema.

    Configuration read from environment variables by ``__init__``:

    * ``TENCENTCLOUD_SECRET_ID`` / ``TENCENTCLOUD_SECRET_KEY`` (required) and
      ``TENCENTCLOUD_TOKEN`` (optional): credentials.
    * ``WEDATA_WORKSPACE_ID`` (required): stored as a property on every model
      and model version created through this store.
    * ``TENCENTCLOUD_DEFAULT_CATALOG_NAME`` / ``TENCENTCLOUD_DEFAULT_SCHEMA_NAME``:
      defaults for short model names (both default to ``"default"``).
    * ``TCLAKE_CACHE_SIZE`` / ``TCLAKE_CACHE_TTL_SECS``: size (default 100) and
      lifetime in seconds (default 300) of the in-memory cache that keeps
      search results between pages.

    Search results are filtered, sorted and paginated client-side: the first
    page of a search fetches everything, caches the sorted list under a
    random search id, and later pages are served from that cache.
    """

    # Number of items requested per page from the search actions.
    _SEARCH_PAGE_SIZE = 200
    # Pause between two consecutive search page requests, in seconds.
    _SEARCH_PAGE_DELAY_SECS = 0.05

    def __init__(self, store_uri=None, tracking_uri=None):
        """Create the store and its Tencent Cloud client.

        Args:
            store_uri: URI of the form ``tclake:{region}``.
            tracking_uri: Passed through to ``AbstractStore``.

        Raises:
            MlflowException: if a required environment variable is missing or
                ``store_uri`` does not contain a region part.
        """
        super().__init__(store_uri, tracking_uri)
        log_msg(
            "initializing tencent tclake client {} {}".format(store_uri, tracking_uri)
        )
        sid = os.getenv("TENCENTCLOUD_SECRET_ID", "")
        if not sid:
            raise MlflowException("TENCENTCLOUD_SECRET_ID is not set")
        sk = os.getenv("TENCENTCLOUD_SECRET_KEY", "")
        if not sk:
            raise MlflowException("TENCENTCLOUD_SECRET_KEY is not set")

        token = os.getenv("TENCENTCLOUD_TOKEN", None)

        self.workspace_id = os.getenv("WEDATA_WORKSPACE_ID", "")
        if not self.workspace_id:
            raise MlflowException("WEDATA_WORKSPACE_ID is not set")
        log_msg("wedata workspace id: {}".format(self.workspace_id))

        client_profile = get_tencent_cloud_client_profile()
        cred = credential.Credential(sid, sk, token)
        # A missing store_uri is reported like a malformed one instead of
        # failing with AttributeError on None.
        parts = (store_uri or "").split(":")
        if len(parts) < 2:
            raise MlflowException("set store_uri tclake:{region}")
        region = parts[1]
        self.client = CommonClient(
            "tccatalog", "2024-10-24", cred, region, client_profile
        )
        self.headers = get_tencent_cloud_headers()
        self.default_catalog_name = os.getenv(
            "TENCENTCLOUD_DEFAULT_CATALOG_NAME", "default"
        )
        self.default_schema_name = os.getenv(
            "TENCENTCLOUD_DEFAULT_SCHEMA_NAME", "default"
        )
        cache_size = int(os.getenv("TCLAKE_CACHE_SIZE", "100"))
        cache_ttl = int(os.getenv("TCLAKE_CACHE_TTL_SECS", "300"))
        # Holds the full sorted result list of a search between page requests.
        self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl)
        log_msg(
            "initialized tencent tclake client successfully {} {} {} {} {}".format(
                region, client_profile, self.headers, cache_size, cache_ttl
            )
        )

    def _split_model_name(self, name):
        """Split ``name`` into ``[catalog, schema, model]``.

        One-part names use the default catalog and schema, two-part names
        (``schema.model``) use the default catalog.

        Raises:
            MlflowException: if ``name`` has more than three dot-separated parts.
        """
        parts = name.split(".")
        if len(parts) == 1:
            return [self.default_catalog_name, self.default_schema_name, parts[0]]
        if len(parts) == 2:
            return [self.default_catalog_name, parts[0], parts[1]]
        if len(parts) == 3:
            return parts
        raise MlflowException(
            "invalid model name: {}, must be catalog.schema.model".format(name)
        )

    def _model_ref(self, name):
        """Return the request fields that identify the model called ``name``."""
        catalog_name, schema_name, model_name = self._split_model_name(name)
        return {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
        }

    def _call(self, action, req):
        """Send ``req`` to API action ``action`` and return the ``Response`` object of the reply."""
        log_msg("req: {}\n{}".format(action, json.dumps(req, indent=2)))
        body = self.client.call(action, req, headers=self.headers)
        body_obj = json.loads(body)
        log_msg("body: {}\n{}".format(action, json.dumps(body_obj, indent=2)))
        return body_obj["Response"]

    def _search_all(self, action, result_key):
        """Page through search action ``action`` and return all raw result items.

        The first request opens a snapshot; follow-up requests pass the
        returned ``SnapshotId`` back. Paging stops once ``TotalCount`` items
        were collected or a page comes back shorter than the page size.
        """
        req_body = {
            "Offset": 0,
            "Limit": self._SEARCH_PAGE_SIZE,
            "SnapshotBased": True,
            "SnapshotId": "",
        }
        resp = self._call(action, req_body)
        total = resp["TotalCount"]
        items = list(resp[result_key])
        while len(items) < total:
            time.sleep(self._SEARCH_PAGE_DELAY_SECS)
            req_body["Offset"] += req_body["Limit"]
            req_body["SnapshotId"] = resp["SnapshotId"]
            resp = self._call(action, req_body)
            page = resp[result_key]
            items.extend(page)
            if len(page) < req_body["Limit"]:
                # Short page: nothing more to fetch even if TotalCount says otherwise.
                break
        return items

    @staticmethod
    def _validate_max_results(max_results, threshold):
        """Raise ``MlflowException`` unless ``max_results`` is an int in ``[1, threshold]``."""
        if not isinstance(max_results, int) or max_results < 1:
            raise MlflowException(
                "Invalid value for max_results. It must be a positive integer,"
                f" but got {max_results}",
                INVALID_PARAMETER_VALUE,
            )
        if max_results > threshold:
            raise MlflowException(
                "Invalid value for request parameter max_results. It must be at most "
                f"{threshold}, but got value {max_results}",
                INVALID_PARAMETER_VALUE,
            )

    def _get_model_version_numbers(self, name):
        """Return the ``Versions`` list reported by ``DescribeModelVersionNumbers`` for ``name``."""
        resp = self._call("DescribeModelVersionNumbers", self._model_ref(name))
        return resp["Versions"]

    def _get_model_version_by_alias(self, name, alias):
        """Return the version of model ``name`` that carries ``alias``, or ``None``.

        All versions of the model are described and scanned for the alias.
        """
        req_body = {
            **self._model_ref(name),
            "ModelVersions": self._get_model_version_numbers(name),
        }
        resp = self._call("DescribeModelVersions", req_body)
        for mv in resp["ModelVersions"]:
            if alias in mv["Aliases"]:
                return _make_model_version(mv, name)
        return None

    def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
        """Create a registered model and return it as ``RegisteredModel``.

        Tags, the deployment job id and the workspace id are stored as
        properties of the model.

        Raises:
            MlflowException: with ``RESOURCE_ALREADY_EXISTS`` if the service
                reports that the model already exists.
        """
        log_msg("create_registered_model {} {} {}".format(name, tags, description))
        # The helpers below append to this list in place.
        properties = []
        _add_tags_to_properties(tags, properties)
        _add_deployment_job_id_to_properties(deployment_job_id, properties)
        _add_workspace_id_to_properties(self.workspace_id, properties)
        req_body = {
            **self._model_ref(name),
            "Comment": description if description else "",
            "Properties": properties,
        }
        try:
            resp = self._call("CreateModel", req_body)
        except TencentCloudSDKException as e:
            if e.code == "FailedOperation.MetalakeAlreadyExistsError":
                raise MlflowException(
                    f"Registered Model (name={name}) already exists.",
                    RESOURCE_ALREADY_EXISTS,
                ) from e
            raise
        return _make_model(resp["Model"])

    def update_registered_model(self, name, description):
        """Replace the description of model ``name`` and return the updated model."""
        log_msg("update_register_model {} {}".format(name, description))
        req_body = {
            **self._model_ref(name),
            "NewComment": description if description else "",
        }
        resp = self._call("ModifyModelComment", req_body)
        return _make_model(resp["Model"])

    def rename_registered_model(self, name, new_name):
        """Rename model ``name`` to ``new_name`` and return the updated model."""
        log_msg("rename_register_model {} {}".format(name, new_name))
        req_body = {
            **self._model_ref(name),
            "NewName": new_name,
        }
        resp = self._call("ModifyModelName", req_body)
        return _make_model(resp["Model"])

    def delete_registered_model(self, name):
        """Delete model ``name``.

        Raises:
            MlflowException: if the service reports the model was not dropped.
        """
        log_msg("delete_register_model {}".format(name))
        resp = self._call("DropModel", self._model_ref(name))
        if not resp["Dropped"]:
            raise MlflowException("Failed to delete model {}".format(name))

    def _fetch_all_models(self):
        """Return every model reported by ``SearchModels`` as ``RegisteredModel`` objects."""
        return [_make_model(model) for model in self._search_all("SearchModels", "Models")]

    def search_registered_models(
        self, filter_string=None, max_results=None, order_by=None, page_token=None
    ):
        """Search registered models and return one page of results.

        Without ``page_token`` all models are fetched, filtered and sorted,
        and the result is cached for the following pages. With a
        ``page_token`` the cached result is used; ``filter_string`` and
        ``order_by`` are then ignored.

        Returns:
            A ``PagedList`` of ``RegisteredModel``; its token is ``None`` on
            the last page.

        Raises:
            MlflowException: if ``max_results`` is not a positive integer
                within the allowed threshold, or the page token is invalid
                or expired.
        """
        log_msg(
            "search_registered_models {} {} {} {}".format(
                filter_string, max_results, order_by, page_token
            )
        )
        self._validate_max_results(
            max_results, SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD
        )
        if page_token is None:
            registered_models = self._fetch_all_models()
            filtered_rms = SearchModelUtils.filter(registered_models, filter_string)
            sorted_rms = SearchModelUtils.sort(filtered_rms, order_by)
            if len(sorted_rms) == 0:
                return PagedList([], None)
            search_id = "model_" + str(uuid.uuid4())
            page_token = _create_page_token(0, search_id)
            self.cache[search_id] = sorted_rms
            log_msg(
                "search_registered_models add cache {} {} {} {} {} {}".format(
                    filter_string,
                    max_results,
                    order_by,
                    page_token,
                    search_id,
                    len(sorted_rms),
                )
            )
        return self._get_page_list(page_token, max_results)

    def get_registered_model(self, name):
        """Return the registered model called ``name``."""
        log_msg("get_registered_model {}".format(name))
        resp = self._call("DescribeModel", self._model_ref(name))
        return _make_model(resp["Model"])

    def get_latest_versions(self, name, stages=None):
        """Return the model versions reported by ``DescribeModelVersions`` for ``name``.

        Raises:
            NotImplementedError: if ``stages`` is given; stages are not supported.
        """
        log_msg("get_latest_versions {} {}".format(name, stages))
        if stages is not None:
            raise NotImplementedError("Method not support stages")
        # NOTE(review): unlike _get_model_version_by_alias, no "ModelVersions"
        # list is sent here, and every returned version is passed on rather
        # than only the newest one - confirm this is intended.
        resp = self._call("DescribeModelVersions", self._model_ref(name))
        return [_make_model_version(mv, name) for mv in resp["ModelVersions"]]

    def set_registered_model_tag(self, name, tag):
        """Store ``tag`` as a property of model ``name``."""
        log_msg("set_registered_model_tag {} {}".format(name, tag))
        req = {
            **self._model_ref(name),
            "Properties": _add_tag_to_properties(tag),
        }
        self._call("ModifyModelProperties", req)

    def delete_registered_model_tag(self, name, key):
        """Remove the tag ``key`` from the properties of model ``name``."""
        log_msg("delete_registered_model_tag {} {}".format(name, key))
        req = {
            **self._model_ref(name),
            "RemovedKeys": [_append_tag_key_prefix(key)],
        }
        self._call("ModifyModelProperties", req)

    def create_model_version(
        self,
        name,
        source,
        run_id=None,
        tags=None,
        run_link=None,
        description=None,
        local_model_path=None,
        model_id=None,
    ):
        """Create a new version of model ``name`` pointing at ``source``.

        Tags, run id, run link, workspace id, an initial deployment status,
        the model id and the model signature (read from ``source``) are
        stored as properties. ``local_model_path`` is only logged.

        Returns:
            The created ``ModelVersion``.

        Raises:
            MlflowException: if the model at ``source`` has no signature, or
                the created version cannot be found afterwards.
        """
        log_msg(
            "create_model_version {} {} {} {} {} {} {}".format(
                name, source, run_id, tags, run_link, description, local_model_path
            )
        )
        # A random alias is attached so the new version can be looked up
        # after creation; the create response itself is not used.
        version_alias = str(uuid.uuid4())
        # The helpers below append to this list in place.
        properties = []
        _add_tags_to_properties(tags, properties)
        _set_run_id_to_properties(run_id, properties)
        _set_run_link_to_properties(run_link, properties)
        _add_workspace_id_to_properties(self.workspace_id, properties)
        _add_deployment_status_to_properties(UN_DEPLOYMENT, properties)
        _add_model_id_to_properties(model_id, properties)
        _add_model_signature_to_properties(source, properties)

        req_body = {
            **self._model_ref(name),
            "Uri": source,
            "Comment": description if description else "",
            "Properties": properties,
            "Aliases": [version_alias],
        }
        self._call("CreateModelVersion", req_body)
        model_version = self._get_model_version_by_alias(name, version_alias)
        if model_version is None:
            # Fail loudly instead of handing None back to the caller.
            raise MlflowException(
                "Model version created for {} could not be found by alias {}".format(
                    name, version_alias
                )
            )
        return model_version

    def update_model_version(self, name, version, description):
        """Replace the description of ``version`` of model ``name`` and return the updated version."""
        log_msg("update_model_version {} {} {}".format(name, version, description))
        req_body = {
            **self._model_ref(name),
            "ModelVersion": _set_model_version(version),
            "NewComment": description if description else "",
        }
        resp = self._call("ModifyModelVersionComment", req_body)
        return _make_model_version(resp["ModelVersion"], name)

    def transition_model_version_stage(
        self, name, version, stage, archive_existing_versions
    ):
        """Not supported by this store; always raises ``NotImplementedError``."""
        raise NotImplementedError("Method not implemented")

    def delete_model_version(self, name, version):
        """Delete ``version`` of model ``name``.

        Raises:
            MlflowException: if the service reports the version was not dropped.
        """
        log_msg("delete_model_version {} {}".format(name, version))
        req_body = {
            **self._model_ref(name),
            "ModelVersion": _set_model_version(version),
        }
        resp = self._call("DropModelVersion", req_body)
        if not resp["Dropped"]:
            # MlflowException (an Exception subclass) matches delete_registered_model.
            raise MlflowException(
                "Failed to delete model version {} {}".format(name, version)
            )

    def get_model_version(self, name, version):
        """Return ``version`` of model ``name`` as ``ModelVersion``."""
        log_msg("get_model_version {} {}".format(name, version))
        req_body = {
            **self._model_ref(name),
            "ModelVersion": _set_model_version(version),
        }
        resp = self._call("DescribeModelVersion", req_body)
        return _make_model_version(resp["ModelVersion"], name)

    def _fetch_all_model_versions(self):
        """Return every version reported by ``SearchModelVersions`` as ``ModelVersion`` objects."""
        return [
            _make_model_version(model_version, _get_model_version_name(model_version))
            for model_version in self._search_all("SearchModelVersions", "ModelVersions")
        ]

    def search_model_versions(
        self, filter_string=None, max_results=None, order_by=None, page_token=None
    ):
        """Search model versions and return one page of results.

        Without ``page_token`` all versions are fetched, filtered and sorted
        (by default: last update descending, name ascending, version number
        descending), and the result is cached for the following pages. With
        a ``page_token`` the cached result is used; ``filter_string`` and
        ``order_by`` are then ignored.

        Returns:
            A ``PagedList`` of ``ModelVersion``; its token is ``None`` on the
            last page.

        Raises:
            MlflowException: if ``max_results`` is not a positive integer
                within the allowed threshold, or the page token is invalid
                or expired.
        """
        log_msg(
            "search_model_versions {} {} {} {}".format(
                filter_string, max_results, order_by, page_token
            )
        )
        self._validate_max_results(
            max_results, SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD
        )
        if page_token is None:
            model_versions = self._fetch_all_model_versions()
            filtered_mvs = SearchModelVersionUtils.filter(model_versions, filter_string)
            sorted_mvs = SearchModelVersionUtils.sort(
                filtered_mvs,
                order_by
                or ["last_updated_timestamp DESC", "name ASC", "version_number DESC"],
            )
            if len(sorted_mvs) == 0:
                return PagedList([], None)
            search_id = "model_version_" + str(uuid.uuid4())
            page_token = _create_page_token(0, search_id)
            self.cache[search_id] = sorted_mvs
            log_msg(
                "search_model_versions add cache {} {} {} {} {} {}".format(
                    filter_string,
                    max_results,
                    order_by,
                    page_token,
                    search_id,
                    len(sorted_mvs),
                )
            )

        return self._get_page_list(page_token, max_results)

    def _get_page_list(self, page_token, max_results):
        """Return the page of a cached search result that ``page_token`` points at.

        The token carries the search id (cache key) and the start offset.
        When the returned page is the last one, the cached result is dropped.

        Raises:
            MlflowException: with ``INVALID_PARAMETER_VALUE`` if the token
                lacks its fields, the offset is not a non-negative integer,
                or the search id is unknown or expired.
        """
        token_info = _parse_page_token(page_token)
        log_msg(
            "_get_page_list token_info {} {} {}".format(
                page_token, max_results, token_info
            )
        )
        try:
            search_id = token_info["search_id"]
            start_offset = token_info["offset"]
            cached_results = self.cache.get(search_id)
        except (KeyError, TypeError) as e:
            # Token decoded, but is not an object with usable fields.
            raise MlflowException(
                "Invalid page token: missing offset or search id",
                INVALID_PARAMETER_VALUE,
            ) from e
        if cached_results is None:
            raise MlflowException(
                "Invalid page token: search id not found or expired",
                INVALID_PARAMETER_VALUE,
            )
        if not isinstance(start_offset, int) or start_offset < 0:
            raise MlflowException(
                "Invalid page token: offset must be a non-negative integer",
                INVALID_PARAMETER_VALUE,
            )
        final_offset = start_offset + max_results

        page_items = cached_results[start_offset:final_offset]
        next_page_token = None
        if final_offset < len(cached_results):
            next_page_token = _create_page_token(final_offset, search_id)
        else:
            # Last page served: the cached result is no longer needed.
            self.cache.pop(search_id, None)
        log_msg(
            "pop cache {} {} {} {} {} {}".format(
                search_id,
                start_offset,
                final_offset,
                len(cached_results),
                page_token,
                next_page_token,
            )
        )
        return PagedList(page_items, next_page_token)

    def set_model_version_tag(self, name, version, tag):
        """Store ``tag`` as a property of ``version`` of model ``name``."""
        log_msg("set_model_version_tag {} {} {}".format(name, version, tag))
        req = {
            **self._model_ref(name),
            "ModelVersion": _set_model_version(version),
            "Properties": _add_tag_to_properties(tag),
        }
        self._call("ModifyModelVersionProperties", req)

    def delete_model_version_tag(self, name, version, key):
        """Remove the tag ``key`` from the properties of ``version`` of model ``name``."""
        log_msg("delete_model_version_tag {} {} {}".format(name, version, key))
        req = {
            **self._model_ref(name),
            "ModelVersion": _set_model_version(version),
            "RemovedKeys": [_append_tag_key_prefix(key)],
        }
        self._call("ModifyModelVersionProperties", req)

    def set_registered_model_alias(self, name, alias, version):
        """Not supported by this store; always raises ``NotImplementedError``."""
        raise NotImplementedError("Method not implemented")

    def delete_registered_model_alias(self, name, alias):
        """Not supported by this store; always raises ``NotImplementedError``."""
        raise NotImplementedError("Method not implemented")

    def get_model_version_by_alias(self, name, alias):
        """Return the version of model ``name`` carrying ``alias``, or ``None`` if there is none."""
        log_msg("get_model_version_by_alias {} {}".format(name, alias))
        return self._get_model_version_by_alias(name, alias)

    def get_model_version_download_uri(self, name, version):
        """Return the ``source`` URI of ``version`` of model ``name``."""
        log_msg("get_model_version_download_uri {} {}".format(name, version))
        model_version = self.get_model_version(name, version)
        return model_version.source
|
|
743
|
+
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
mlflow_tclake_plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
mlflow_tclake_plugin/tclake_store.py,sha256=Ys3giqvpdth7gwZ7O3TeaQH1OJH08XUpqjRkOVIM9JM,28301
|
|
3
|
+
mlflow_tclake_plugin-0.0.1.dist-info/METADATA,sha256=My3rL7JEx5v_HsA42LimQvwHJU1ms0Qa2mvHlYGZ70s,223
|
|
4
|
+
mlflow_tclake_plugin-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
5
|
+
mlflow_tclake_plugin-0.0.1.dist-info/entry_points.txt,sha256=xKXky9-NyJWEG1SgBknrqskN2rsZG3t4UXiSzkTcsuE,85
|
|
6
|
+
mlflow_tclake_plugin-0.0.1.dist-info/top_level.txt,sha256=zrA5UNyfF3skRmxPNsvrJI3yf-um6CJX_xO5KvWc2o0,21
|
|
7
|
+
mlflow_tclake_plugin-0.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
mlflow_tclake_plugin
|