mlflow-tclake-plugin 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlflow-tclake-plugin might be problematic. Click here for more details.

@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlflow-tclake-plugin
3
+ Version: 0.0.1
4
+ Summary: Tclake plugin for MLflow
5
+ Requires-Dist: mlflow>=3.1.4
6
+ Requires-Dist: tencentcloud-sdk-python-common>=3.0.1478
7
+ Dynamic: requires-dist
8
+ Dynamic: summary
@@ -0,0 +1,36 @@
1
+ # MLflow Tclake 插件
2
+
3
+ 本插件将 [Tclake](https://cloud.tencent.com/product/tclake) 集成为 MLflow 的模型注册表存储后端,实现机器学习模型在腾讯云数据湖中的无缝存储与管理。
4
+
5
+ ## 功能特性
6
+
7
+ - ✅ **模型注册表存储**:使用 Tclake 作为 MLflow 模型注册表的后端存储
8
+ - 🌐 **云服务集成**:原生支持腾讯云 SDK
9
+ - 🔒 **安全访问**:基于腾讯云认证机制的安全访问控制
10
+
11
+ ## 安装方法
12
+
13
+ ```bash
14
+ cd myplugin/mlflow-tclake-plugin
15
+ pip install -e .
16
+ ```
17
+ 安装后使用mlflow client时会自动加载该插件
18
+
19
+
20
+ ## 系统要求
21
+
22
+ - Python 3.9+
23
+ - MLflow >= 3.1.4
24
+ - Tencent Cloud SDK >= 3.0.1478
25
+
26
+ ## 开发指南
27
+
28
+ 1. 安装开发依赖:
29
+ ```bash
30
+ pip install -r requirements.txt
31
+ ```
32
+
33
+ 2. 运行测试:
34
+ ```bash
35
+ pytest mlflow_tclake_plugin/test/
36
+ ```
@@ -0,0 +1,743 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ import time
5
+ import uuid
6
+ import urllib.parse
7
+ from datetime import datetime
8
+
9
+ from cachetools import TTLCache
10
+ from tencentcloud.common import credential
11
+ from tencentcloud.common.common_client import CommonClient
12
+ from tencentcloud.common.profile.client_profile import ClientProfile
13
+
14
+ from mlflow.entities.model_registry import (
15
+ RegisteredModel,
16
+ ModelVersion,
17
+ ModelVersionTag,
18
+ )
19
+ from mlflow.exceptions import MlflowException
20
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS
21
+ from mlflow.store.model_registry import (
22
+ SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
23
+ SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
24
+ )
25
+ from mlflow.store.model_registry.abstract_store import AbstractStore
26
+ from mlflow.utils.annotations import experimental
27
+ from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils
28
+ from mlflow.store.entities.paged_list import PagedList
29
+ from mlflow.models.model import get_model_info
30
+ from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
31
+
32
def to_string(obj):
    """Best-effort human-readable rendering of *obj* for debug logging.

    None becomes "None", strings pass through unchanged, JSON-serializable
    objects are pretty-printed, and anything else falls back to ``str(obj)``.
    """
    if obj is None:
        return "None"
    if isinstance(obj, str):
        return obj
    try:
        # Covers the original's explicit dict branch as well: dicts are
        # JSON-serializable and take this path.
        return json.dumps(obj, indent=2)
    except (TypeError, ValueError):
        # Narrowed from a bare ``except:``; json.dumps raises TypeError for
        # non-serializable objects and ValueError for circular references.
        return str(obj)
43
+
44
+
45
# Debug switch: when TENCENTCLOUD_DEBUG is set to any non-empty value,
# log_msg() prints request/response traces to stdout.
tencent_cloud_debug = os.getenv("TENCENTCLOUD_DEBUG", None)
46
+
47
+
48
def log_msg(msg):
    """Print *msg* when debug tracing is enabled via TENCENTCLOUD_DEBUG."""
    if tencent_cloud_debug:
        print(msg)
51
+
52
+
53
+ def _get_create_time_from_audit(audit):
54
+ dt_str = audit["CreatedTime"]
55
+ dt = datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S")
56
+ return int(dt.timestamp() * 1000)
57
+
58
+
59
+ def _get_last_updated_time_from_audit(audit):
60
+ dt_str = audit["LastModifiedTime"]
61
+ if dt_str is None or len(dt_str) == 0:
62
+ return _get_create_time_from_audit(audit)
63
+ dt = datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S")
64
+ return int(dt.timestamp() * 1000)
65
+
66
+
67
+ def _get_description(resp):
68
+ if resp is None or "Comment" not in resp:
69
+ return None
70
+ return resp["Comment"]
71
+
72
+
73
# Reserved keys used to round-trip MLflow metadata through TCLake model /
# model-version "Properties" lists. User-supplied tags are namespaced with
# TCLAKE_MLFLOW_TAG_PREFIX so they can be separated from internal keys.
TCLAKE_MLFLOW_TAG_PREFIX = "tclake.tag."
TCLAKE_MLFLOW_RUN_ID_KEY = "tclake.mlflow.run_id"
TCLAKE_MLFLOW_RUN_LINK_KEY = "tclake.mlflow.run_link"
TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY = "tclake.mlflow.deployment_job_id"
TCLAKE_WEDATA_WORKSPACE_ID_KEY = "tclake.wedata.workspace_id"
TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY = "tclake.mlflow.deployment_status"
TCLAKE_MLFLOW_MODEL_ID_KEY = "tclake.mlflow.model_id"
TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY = "tclake.mlflow.model_signature"

# Deployment status recorded on newly created model versions.
UN_DEPLOYMENT = "UnDeployment"
83
+
84
+ def _set_kv_to_properties(key, value, properties=None):
85
+ if properties is None:
86
+ properties = []
87
+ if value is not None:
88
+ properties.append({"Key": key, "Value": value})
89
+ return properties
90
+
91
+
92
+ def _get_kv_from_properties(properties, key):
93
+ if properties is None:
94
+ return None
95
+ for p in properties:
96
+ if p["Key"] == key:
97
+ return p["Value"]
98
+ return None
99
+
100
+
101
def _set_run_id_to_properties(run_id, properties=None):
    """Store the MLflow run id under its reserved TCLake property key."""
    return _set_kv_to_properties(TCLAKE_MLFLOW_RUN_ID_KEY, run_id, properties)
103
+
104
+
105
def _get_run_id_from_properties(properties):
    """Read the MLflow run id back from a TCLake property list (or None)."""
    return _get_kv_from_properties(properties, TCLAKE_MLFLOW_RUN_ID_KEY)
107
+
108
+
109
def _set_run_link_to_properties(run_link, properties):
    """Store the MLflow run link under its reserved TCLake property key."""
    return _set_kv_to_properties(TCLAKE_MLFLOW_RUN_LINK_KEY, run_link, properties)
111
+
112
+
113
def _get_run_link_from_properties(properties):
    """Read the MLflow run link back from a TCLake property list (or None)."""
    return _get_kv_from_properties(properties, TCLAKE_MLFLOW_RUN_LINK_KEY)
115
+
116
def _get_model_id_from_properties(properties):
    """Read the MLflow model id back from a TCLake property list (or None)."""
    return _get_kv_from_properties(properties, TCLAKE_MLFLOW_MODEL_ID_KEY)
118
+
119
+
120
def _add_tag_to_properties(tag, properties=None):
    """Append a ModelVersionTag to *properties* under the namespaced tag key.

    Tags that are falsy or carry a None value are dropped silently.
    """
    if tag and tag.value is not None:
        properties = _set_kv_to_properties(
            TCLAKE_MLFLOW_TAG_PREFIX + tag.key, tag.value, properties
        )
    return properties
126
+
127
+
128
def _add_tags_to_properties(tags, properties=None):
    """Append every tag in *tags* (may be None) to the property list."""
    if tags:
        for tag in tags:
            properties = _add_tag_to_properties(tag, properties)
    return properties
133
+
134
def _add_deployment_job_id_to_properties(deployment_job_id, properties=None):
    """Store the deployment job id under its reserved TCLake property key."""
    properties = _set_kv_to_properties(TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY, deployment_job_id, properties)
    return properties
137
+
138
def _add_workspace_id_to_properties(workspace_id, properties=None):
    """Store the WeData workspace id (when given) under its reserved key."""
    if workspace_id is not None:
        properties = _set_kv_to_properties(TCLAKE_WEDATA_WORKSPACE_ID_KEY, workspace_id, properties)
    return properties
142
+
143
def _add_deployment_status_to_properties(status, properties=None):
    """Store the deployment status (when given) under its reserved key."""
    if status is not None:
        properties = _set_kv_to_properties(TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY, status, properties)
    return properties
147
+
148
def _add_model_id_to_properties(model_id, properties=None):
    """Store the MLflow model id (when given) under its reserved key."""
    if model_id is not None:
        properties = _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_ID_KEY, model_id, properties)
    return properties
152
+
153
+
154
def _add_model_signature_to_properties(source, properties):
    """Serialize the model signature found at artifact *source* into properties.

    Loads the model metadata via ``get_model_info`` and stores the signature
    as JSON under the reserved signature key.

    Raises:
        MlflowException: when the model at *source* has no signature — a
            signature is required here (presumably by the TCLake backend;
            confirm against service documentation).
    """
    log_msg("model source is {}".format(source))
    model_info = get_model_info(source)
    signature = model_info.signature
    log_msg("model {} signature is {}".format(source, signature))
    if signature is None:
        raise MlflowException(f"Registered model signature is not found in source artifact location '{source}'")
    sig_json = json.dumps(signature.to_dict())
    log_msg("model {} signature json is {}".format(source, sig_json))
    properties = _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY, sig_json, properties)
    return properties
165
+
166
def _get_tags_from_properties(properties):
    """Extract user tags (prefix-namespaced entries) from a property list.

    Returns None when *properties* is None, otherwise a (possibly empty)
    list of ModelVersionTag objects with the namespace prefix stripped.
    """
    if properties is None:
        return None
    prefix_len = len(TCLAKE_MLFLOW_TAG_PREFIX)
    return [
        ModelVersionTag(entry["Key"][prefix_len:], entry["Value"])
        for entry in properties
        if entry["Key"].startswith(TCLAKE_MLFLOW_TAG_PREFIX)
    ]
176
+
177
+
178
def _append_tag_key_prefix(key):
    """Namespace a user tag key with the reserved TCLake tag prefix."""
    return TCLAKE_MLFLOW_TAG_PREFIX + key
180
+
181
+
182
def _get_model_version(version):
    """Convert a TCLake integer version to the string MLflow expects."""
    return str(version)
184
+
185
+
186
def _set_model_version(version):
    """Convert an MLflow string version to the integer TCLake expects."""
    return int(version)
188
+
189
+
190
def _make_model(resp):
    """Convert a TCLake ``Model`` API payload into an MLflow RegisteredModel."""
    properties = resp["Properties"]
    audit = resp["Audit"]
    return RegisteredModel(
        name=resp["Name"],
        creation_timestamp=_get_create_time_from_audit(audit),
        last_updated_timestamp=_get_last_updated_time_from_audit(audit),
        description=_get_description(resp),
        tags=_get_tags_from_properties(properties)
    )
200
+
201
+
202
+ def _get_model_version_name(entity):
203
+ return "{}.{}.{}".format(
204
+ entity["CatalogName"], entity["SchemaName"], entity["ModelName"]
205
+ )
206
+
207
+
208
def _make_model_version(entity, name):
    """Convert a TCLake ``ModelVersion`` payload into an MLflow ModelVersion.

    *name* is the fully-qualified model name; run id / run link / model id
    are recovered from the reserved property keys. Status is always "READY"
    because TCLake only returns fully-created versions (presumably — confirm
    against the DescribeModelVersions contract).
    """
    properties = entity["Properties"]
    audit = entity["Audit"]
    aliases = entity["Aliases"]
    return ModelVersion(
        name=name,
        version=_get_model_version(entity["Version"]),
        creation_timestamp=_get_create_time_from_audit(audit),
        last_updated_timestamp=_get_last_updated_time_from_audit(audit),
        description=_get_description(entity),
        source=entity["Uri"],
        run_id=_get_run_id_from_properties(properties),
        tags=_get_tags_from_properties(properties),
        run_link=_get_run_link_from_properties(properties),
        aliases=aliases,
        model_id=_get_model_id_from_properties(properties),
        status="READY",
    )
226
+
227
+
228
def get_tencent_cloud_headers():
    """Return extra request headers parsed from TENCENTCLOUD_HEADER_JSON.

    Returns None when the environment variable is unset; raises if the
    variable contains invalid JSON (unchanged from the original behavior).
    """
    raw = os.getenv("TENCENTCLOUD_HEADER_JSON")
    return None if raw is None else json.loads(raw)
233
+
234
+
235
def get_tencent_cloud_client_profile():
    """Build a ClientProfile pointing at TENCENTCLOUD_ENDPOINT, or None.

    Returning None lets the SDK fall back to its default endpoint.
    """
    endpoint = os.getenv("TENCENTCLOUD_ENDPOINT", None)
    if endpoint is None:
        return None
    client_profile = ClientProfile()
    client_profile.httpProfile.endpoint = endpoint
    return client_profile
242
+
243
+
244
+ def _parse_page_token(page_token):
245
+ decoded_token = base64.b64decode(page_token)
246
+ parsed_token = json.loads(decoded_token)
247
+ return parsed_token
248
+
249
+
250
+ def _create_page_token(offset, search_id):
251
+ return base64.b64encode(
252
+ json.dumps({"offset": offset, "search_id": search_id}).encode("utf-8")
253
+ )
254
+
255
+
256
@experimental
class TCLakeStore(AbstractStore):
    """MLflow model-registry store backed by Tencent Cloud TCLake.

    Registered for ``tclake:{region}`` store URIs; talks to the
    ``tccatalog`` (2024-10-24) API through the Tencent Cloud CommonClient.
    Search results are paginated client-side via a TTL cache keyed by an
    opaque search id embedded in the page token.
    """

    def __init__(self, store_uri=None, tracking_uri=None):
        """Initialize the TCLake client from environment configuration.

        Args:
            store_uri: Store URI of the form ``tclake:{region}``.
            tracking_uri: Tracking URI, forwarded to ``AbstractStore``.

        Raises:
            MlflowException: when TENCENTCLOUD_SECRET_ID,
                TENCENTCLOUD_SECRET_KEY or WEDATA_WORKSPACE_ID is unset,
                or the store URI carries no region.
        """
        super().__init__(store_uri, tracking_uri)
        log_msg(
            "initializing tencent tclake client {} {}".format(store_uri, tracking_uri)
        )
        sid = os.getenv("TENCENTCLOUD_SECRET_ID", "")
        if len(sid) == 0:
            raise MlflowException("TENCENTCLOUD_SECRET_ID is not set")
        # Fix: the original read and validated TENCENTCLOUD_SECRET_KEY twice
        # (copy-paste duplicate); a single check is sufficient.
        sk = os.getenv("TENCENTCLOUD_SECRET_KEY", "")
        if len(sk) == 0:
            raise MlflowException("TENCENTCLOUD_SECRET_KEY is not set")

        token = os.getenv("TENCENTCLOUD_TOKEN", None)

        self.workspace_id = os.getenv("WEDATA_WORKSPACE_ID", "")
        if len(self.workspace_id) == 0:
            raise MlflowException("WEDATA_WORKSPACE_ID is not set")
        log_msg("wedata workspace id: {}".format(self.workspace_id))

        client_profile = get_tencent_cloud_client_profile()
        cred = credential.Credential(sid, sk, token)
        parts = store_uri.split(":")
        if len(parts) < 2:
            raise MlflowException("set store_uri tclake:{region}")
        region = parts[1]
        self.client = CommonClient(
            "tccatalog", "2024-10-24", cred, region, client_profile
        )
        self.headers = get_tencent_cloud_headers()
        # Single-part model names are resolved against these defaults.
        self.default_catalog_name = os.getenv(
            "TENCENTCLOUD_DEFAULT_CATALOG_NAME", "default"
        )
        self.default_schema_name = os.getenv(
            "TENCENTCLOUD_DEFAULT_SCHEMA_NAME", "default"
        )
        # TTL cache keeps sorted search results alive between page requests.
        cache_size = int(os.getenv("TCLAKE_CACHE_SIZE", "100"))
        cache_ttl = int(os.getenv("TCLAKE_CACHE_TTL_SECS", "300"))
        self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl)
        log_msg(
            "initialized tencent tclake client successfully {} {} {} {} {}".format(
                region, client_profile, self.headers, cache_size, cache_ttl
            )
        )

    def _split_model_name(self, name):
        """Split *name* into [catalog, schema, model], filling in defaults.

        Accepts "model", "schema.model" or "catalog.schema.model".
        """
        parts = name.split(".")
        if len(parts) == 1:
            return [self.default_catalog_name, self.default_schema_name, parts[0]]
        if len(parts) == 2:
            return [self.default_catalog_name, parts[0], parts[1]]
        if len(parts) == 3:
            return parts
        raise MlflowException(
            "invalid model name: {}, must be catalog.schema.model".format(name)
        )

    def _call(self, action, req):
        """Invoke a tccatalog API *action* and return the parsed "Response"."""
        log_msg("req: {}\n{}".format(action, json.dumps(req, indent=2)))
        body = self.client.call(action, req, headers=self.headers)
        body_obj = json.loads(body)
        log_msg("body: {}\n{}".format(action, json.dumps(body_obj, indent=2)))
        resp = body_obj["Response"]
        return resp

    def _get_model_version_numbers(self, name):
        """Return all version numbers registered for model *name*."""
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
        }
        resp = self._call("DescribeModelVersionNumbers", req_body)
        return resp["Versions"]

    def _get_model_version_by_alias(self, name, alias):
        """Return the ModelVersion of *name* carrying *alias*, or None."""
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersions": self._get_model_version_numbers(name),
        }
        resp = self._call("DescribeModelVersions", req_body)
        model_version = None
        for mv in resp["ModelVersions"]:
            if alias in mv["Aliases"]:
                model_version = mv
                break
        if model_version is None:
            return None
        return _make_model_version(model_version, name)

    def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
        """Create a registered model in TCLake.

        Raises:
            MlflowException: with RESOURCE_ALREADY_EXISTS when a model of
                the same name already exists.
        """
        log_msg("create_registered_model {} {} {}".format(name, tags, description))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        properties = []
        _add_tags_to_properties(tags, properties)
        _add_deployment_job_id_to_properties(deployment_job_id, properties)
        _add_workspace_id_to_properties(self.workspace_id, properties)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "Comment": description if description else "",
            "Properties": properties,
        }
        try:
            resp = self._call("CreateModel", req_body)
        # Fix: catch the SDK exception directly instead of catching
        # Exception and testing isinstance; non-SDK errors propagate as before.
        except TencentCloudSDKException as e:
            if e.code == "FailedOperation.MetalakeAlreadyExistsError":
                raise MlflowException(
                    f"Registered Model (name={name}) already exists.",
                    RESOURCE_ALREADY_EXISTS,
                )
            raise
        return _make_model(resp["Model"])

    def update_registered_model(self, name, description):
        """Update the description (comment) of a registered model."""
        log_msg("update_register_model {} {}".format(name, description))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "NewComment": description if description else "",
        }
        resp = self._call("ModifyModelComment", req_body)
        return _make_model(resp["Model"])

    def rename_registered_model(self, name, new_name):
        """Rename a registered model (catalog/schema parts stay fixed)."""
        log_msg("rename_register_model {} {}".format(name, new_name))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "NewName": new_name,
        }
        resp = self._call("ModifyModelName", req_body)
        return _make_model(resp["Model"])

    def delete_registered_model(self, name):
        """Delete a registered model; raises MlflowException on failure."""
        log_msg("delete_register_model {}".format(name))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
        }
        resp = self._call("DropModel", req_body)
        if not resp["Dropped"]:
            raise MlflowException("Failed to delete model {}".format(name))

    def _fetch_all_models(self):
        """Page through SearchModels and return every model as RegisteredModel.

        Uses snapshot-based pagination so the listing stays consistent while
        iterating; a short sleep throttles successive API calls.
        """
        req_body = {"Offset": 0, "Limit": 200, "SnapshotBased": True, "SnapshotId": ""}
        resp = self._call("SearchModels", req_body)
        total = resp["TotalCount"]
        model_list = resp["Models"]
        while len(model_list) < total:
            time.sleep(0.05)
            req_body["Offset"] += req_body["Limit"]
            req_body["SnapshotId"] = resp["SnapshotId"]
            resp = self._call("SearchModels", req_body)
            model_list.extend(resp["Models"])
            # A short page means the server has no more results even if
            # TotalCount suggested otherwise.
            if len(resp["Models"]) < req_body["Limit"]:
                break
        return [_make_model(model) for model in model_list]

    def search_registered_models(
        self, filter_string=None, max_results=None, order_by=None, page_token=None
    ):
        """Search registered models with client-side filter/sort/pagination.

        On the first call (no page_token) all models are fetched, filtered
        and sorted, then cached under a fresh search id; subsequent pages
        are served from the cache via the returned token.
        """
        log_msg(
            "search_registered_models {} {} {} {}".format(
                filter_string, max_results, order_by, page_token
            )
        )
        if not isinstance(max_results, int) or max_results < 1:
            raise MlflowException(
                "Invalid value for max_results. It must be a positive integer,"
                f" but got {max_results}",
                INVALID_PARAMETER_VALUE,
            )

        if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. It must be at most "
                f"{SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
                INVALID_PARAMETER_VALUE,
            )
        if page_token is None:
            registered_models = self._fetch_all_models()
            filtered_rms = SearchModelUtils.filter(registered_models, filter_string)
            sorted_rms = SearchModelUtils.sort(filtered_rms, order_by)
            if len(sorted_rms) == 0:
                return PagedList([], None)
            search_id = "model_" + str(uuid.uuid4())
            page_token = _create_page_token(0, search_id)
            self.cache[search_id] = sorted_rms
            log_msg(
                "search_registered_models add cache {} {} {} {} {} {}".format(
                    filter_string,
                    max_results,
                    order_by,
                    page_token,
                    search_id,
                    len(sorted_rms),
                )
            )
        return self._get_page_list(page_token, max_results)

    def get_registered_model(self, name):
        """Fetch a single registered model by (possibly short) name."""
        log_msg("get_registered_model {}".format(name))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
        }
        resp = self._call("DescribeModel", req)
        return _make_model(resp["Model"])

    def get_latest_versions(self, name, stages=None):
        """Return all versions of *name*; stage filtering is unsupported."""
        log_msg("get_latest_versions {} {}".format(name, stages))
        if stages is not None:
            raise NotImplementedError("Method not support stages")
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
        }
        resp = self._call("DescribeModelVersions", req)
        return [_make_model_version(mv, name) for mv in resp["ModelVersions"]]

    def set_registered_model_tag(self, name, tag):
        """Set a tag on a registered model (stored as a namespaced property)."""
        log_msg("set_registered_model_tag {} {}".format(name, tag))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "Properties": _add_tag_to_properties(tag),
        }
        self._call("ModifyModelProperties", req)

    def delete_registered_model_tag(self, name, key):
        """Delete a tag from a registered model by its user-visible key."""
        log_msg("delete_registered_model_tag {} {}".format(name, key))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "RemovedKeys": [_append_tag_key_prefix(key)],
        }
        self._call("ModifyModelProperties", req)

    def create_model_version(
        self,
        name,
        source,
        run_id=None,
        tags=None,
        run_link=None,
        description=None,
        local_model_path=None,
        model_id=None,
    ):
        """Create a model version pointing at artifact *source*.

        A one-off UUID alias is attached at creation time and then used to
        look the freshly created version back up, since the create call does
        not return the assigned version number directly.
        """
        log_msg(
            "create_model_version {} {} {} {} {} {} {}".format(
                name, source, run_id, tags, run_link, description, local_model_path
            )
        )
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        version_alias = str(uuid.uuid4())
        properties = []
        _add_tags_to_properties(tags, properties)
        _set_run_id_to_properties(run_id, properties)
        _set_run_link_to_properties(run_link, properties)
        _add_workspace_id_to_properties(self.workspace_id, properties)
        _add_deployment_status_to_properties(UN_DEPLOYMENT, properties)
        _add_model_id_to_properties(model_id, properties)
        _add_model_signature_to_properties(source, properties)

        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "Uri": source,
            "Comment": description if description else "",
            "Properties": properties,
            "Aliases": [version_alias],
        }
        self._call("CreateModelVersion", req_body)
        return self._get_model_version_by_alias(name, version_alias)

    def update_model_version(self, name, version, description):
        """Update the description (comment) of a model version."""
        log_msg("update_model_version {} {} {}".format(name, version, description))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersion": _set_model_version(version),
            "NewComment": description if description else "",
        }
        resp = self._call("ModifyModelVersionComment", req_body)
        return _make_model_version(resp["ModelVersion"], name)

    def transition_model_version_stage(
        self, name, version, stage, archive_existing_versions
    ):
        """Stages are not supported by the TCLake backend."""
        raise NotImplementedError("Method not implemented")

    def delete_model_version(self, name, version):
        """Delete a single model version; raises MlflowException on failure."""
        log_msg("delete_model_version {} {}".format(name, version))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersion": _set_model_version(version),
        }
        resp = self._call("DropModelVersion", req_body)
        if not resp["Dropped"]:
            # Consistency fix: raise MlflowException like
            # delete_registered_model does (was a bare Exception).
            raise MlflowException(
                "Failed to delete model version {} {}".format(name, version)
            )

    def get_model_version(self, name, version):
        """Fetch a single model version by name and version number."""
        log_msg("get_model_version {} {}".format(name, version))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req_body = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersion": _set_model_version(version),
        }
        resp = self._call("DescribeModelVersion", req_body)
        return _make_model_version(resp["ModelVersion"], name)

    def _fetch_all_model_versions(self):
        """Page through SearchModelVersions and return MLflow ModelVersions.

        Mirrors ``_fetch_all_models``: snapshot-based pagination with a
        short throttle between pages.
        """
        req_body = {"Offset": 0, "Limit": 200, "SnapshotBased": True, "SnapshotId": ""}
        resp = self._call("SearchModelVersions", req_body)
        total = resp["TotalCount"]
        model_version_list = resp["ModelVersions"]
        while len(model_version_list) < total:
            time.sleep(0.05)
            req_body["Offset"] += req_body["Limit"]
            req_body["SnapshotId"] = resp["SnapshotId"]
            resp = self._call("SearchModelVersions", req_body)
            model_version_list.extend(resp["ModelVersions"])
            if len(resp["ModelVersions"]) < req_body["Limit"]:
                break
        return [
            _make_model_version(model_version, _get_model_version_name(model_version))
            for model_version in model_version_list
        ]

    def search_model_versions(
        self, filter_string=None, max_results=None, order_by=None, page_token=None
    ):
        """Search model versions with client-side filter/sort/pagination.

        Same caching scheme as ``search_registered_models``; the default
        ordering matches MLflow's conventional version search order.
        """
        log_msg(
            "search_model_versions {} {} {} {}".format(
                filter_string, max_results, order_by, page_token
            )
        )
        if not isinstance(max_results, int) or max_results < 1:
            raise MlflowException(
                "Invalid value for max_results. It must be a positive integer,"
                f" but got {max_results}",
                INVALID_PARAMETER_VALUE,
            )

        if max_results > SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. It must be at most "
                f"{SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
                INVALID_PARAMETER_VALUE,
            )
        if page_token is None:
            model_versions = self._fetch_all_model_versions()
            filtered_mvs = SearchModelVersionUtils.filter(model_versions, filter_string)
            sorted_mvs = SearchModelVersionUtils.sort(
                filtered_mvs,
                order_by
                or ["last_updated_timestamp DESC", "name ASC", "version_number DESC"],
            )
            if len(sorted_mvs) == 0:
                return PagedList([], None)
            search_id = "model_version_" + str(uuid.uuid4())
            page_token = _create_page_token(0, search_id)
            self.cache[search_id] = sorted_mvs
            log_msg(
                "search_model_versions add cache {} {} {} {} {} {}".format(
                    filter_string,
                    max_results,
                    order_by,
                    page_token,
                    search_id,
                    len(sorted_mvs),
                )
            )

        return self._get_page_list(page_token, max_results)

    def _get_page_list(self, page_token, max_results):
        """Serve one page of a cached search result identified by *page_token*.

        The cache entry is dropped once the final page has been returned;
        an unknown/expired search id raises INVALID_PARAMETER_VALUE.
        """
        token_info = _parse_page_token(page_token)
        log_msg(
            "_get_page_list token_info {} {} {}".format(
                page_token, max_results, token_info
            )
        )
        sorted_mvs = self.cache.get(token_info["search_id"])
        if sorted_mvs is None:
            raise MlflowException(
                "Invalid page token: search id not found or expired",
                INVALID_PARAMETER_VALUE,
            )
        start_offset = token_info["offset"]
        final_offset = start_offset + max_results

        paginated_rms = sorted_mvs[start_offset : min(len(sorted_mvs), final_offset)]
        next_page_token = None
        if final_offset < len(sorted_mvs):
            next_page_token = _create_page_token(final_offset, token_info["search_id"])
        else:
            self.cache.pop(token_info["search_id"], None)
            log_msg(
                "pop cache {} {} {} {} {} {}".format(
                    token_info["search_id"],
                    start_offset,
                    final_offset,
                    len(sorted_mvs),
                    page_token,
                    next_page_token,
                )
            )
        return PagedList(paginated_rms, next_page_token)

    def set_model_version_tag(self, name, version, tag):
        """Set a tag on a model version (stored as a namespaced property)."""
        log_msg("set_model_version_tag {} {} {}".format(name, version, tag))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersion": _set_model_version(version),
            "Properties": _add_tag_to_properties(tag),
        }
        self._call("ModifyModelVersionProperties", req)

    def delete_model_version_tag(self, name, version, key):
        """Delete a tag from a model version by its user-visible key."""
        log_msg("delete_model_version_tag {} {} {}".format(name, version, key))
        [catalog_name, schema_name, model_name] = self._split_model_name(name)
        req = {
            "CatalogName": catalog_name,
            "SchemaName": schema_name,
            "ModelName": model_name,
            "ModelVersion": _set_model_version(version),
            "RemovedKeys": [_append_tag_key_prefix(key)],
        }
        self._call("ModifyModelVersionProperties", req)

    def set_registered_model_alias(self, name, alias, version):
        """Alias management is not implemented for this backend."""
        raise NotImplementedError("Method not implemented")

    def delete_registered_model_alias(self, name, alias):
        """Alias management is not implemented for this backend."""
        raise NotImplementedError("Method not implemented")

    def get_model_version_by_alias(self, name, alias):
        """Return the model version of *name* carrying *alias* (or None)."""
        log_msg("get_model_version_by_alias {} {}".format(name, alias))
        return self._get_model_version_by_alias(name, alias)

    def get_model_version_download_uri(self, name, version):
        """Return the artifact URI (source) of a model version."""
        log_msg("get_model_version_download_uri {} {}".format(name, version))
        model_version = self.get_model_version(name, version)
        return model_version.source
743
+
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlflow-tclake-plugin
3
+ Version: 0.0.1
4
+ Summary: Tclake plugin for MLflow
5
+ Requires-Dist: mlflow>=3.1.4
6
+ Requires-Dist: tencentcloud-sdk-python-common>=3.0.1478
7
+ Dynamic: requires-dist
8
+ Dynamic: summary
@@ -0,0 +1,10 @@
1
+ README.md
2
+ setup.py
3
+ mlflow_tclake_plugin/__init__.py
4
+ mlflow_tclake_plugin/tclake_store.py
5
+ mlflow_tclake_plugin.egg-info/PKG-INFO
6
+ mlflow_tclake_plugin.egg-info/SOURCES.txt
7
+ mlflow_tclake_plugin.egg-info/dependency_links.txt
8
+ mlflow_tclake_plugin.egg-info/entry_points.txt
9
+ mlflow_tclake_plugin.egg-info/requires.txt
10
+ mlflow_tclake_plugin.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ [mlflow.model_registry_store]
2
+ tclake = mlflow_tclake_plugin.tclake_store:TCLakeStore
@@ -0,0 +1,2 @@
1
+ mlflow>=3.1.4
2
+ tencentcloud-sdk-python-common>=3.0.1478
@@ -0,0 +1 @@
1
+ mlflow_tclake_plugin
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,16 @@
1
from setuptools import find_packages, setup

# Packaging for the TCLake MLflow plugin: registers TCLakeStore under the
# "tclake" scheme via the mlflow.model_registry_store entry-point group so
# MLflow auto-discovers it on install.
setup(
    name="mlflow-tclake-plugin",
    version="0.0.1",
    description="Tclake plugin for MLflow",
    packages=find_packages(),
    # Require MLflow as a dependency of the plugin, so that plugin users can simply install
    # the plugin & then immediately use it with MLflow
    install_requires=["mlflow>=3.1.4", "tencentcloud-sdk-python-common>=3.0.1478"],
    entry_points={
        "mlflow.model_registry_store": (
            "tclake=mlflow_tclake_plugin.tclake_store:TCLakeStore"
        ),
    },
)