mlflow-tclake-plugin 2.0.1__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlflow-tclake-plugin might be problematic. Click here for more details.

@@ -3,6 +3,7 @@ import json
3
3
  import os
4
4
  import time
5
5
  import uuid
6
+ import urllib.parse
6
7
  from datetime import datetime
7
8
 
8
9
  from cachetools import TTLCache
@@ -10,16 +11,23 @@ from tencentcloud.common import credential
10
11
  from tencentcloud.common.common_client import CommonClient
11
12
  from tencentcloud.common.profile.client_profile import ClientProfile
12
13
 
13
- from mlflow.entities.model_registry import RegisteredModel, ModelVersion, ModelVersionTag
14
+ from mlflow.entities.model_registry import (
15
+ RegisteredModel,
16
+ ModelVersion,
17
+ ModelVersionTag,
18
+ )
14
19
  from mlflow.exceptions import MlflowException
15
- from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
16
- from mlflow.store.model_registry import SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, \
17
- SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD
20
+ from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS
21
+ from mlflow.store.model_registry import (
22
+ SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
23
+ SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
24
+ )
18
25
  from mlflow.store.model_registry.abstract_store import AbstractStore
19
26
  from mlflow.utils.annotations import experimental
20
27
  from mlflow.utils.search_utils import SearchModelUtils, SearchModelVersionUtils
21
28
  from mlflow.store.entities.paged_list import PagedList
22
-
29
+ from mlflow.models.model import get_model_info
30
+ from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
23
31
 
24
32
  def to_string(obj):
25
33
  if obj is None:
@@ -65,9 +73,13 @@ def _get_description(resp):
65
73
  TCLAKE_MLFLOW_TAG_PREFIX = "tclake.tag."
66
74
  TCLAKE_MLFLOW_RUN_ID_KEY = "tclake.mlflow.run_id"
67
75
  TCLAKE_MLFLOW_RUN_LINK_KEY = "tclake.mlflow.run_link"
68
- TCLAKE_UUID_KEY = "tccatalog.identifier"
76
+ TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY = "tclake.mlflow.deployment_job_id"
77
+ TCLAKE_WEDATA_WORKSPACE_ID_KEY = "tclake.wedata.workspace_id"
78
+ TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY = "tclake.mlflow.deployment_status"
79
+ TCLAKE_MLFLOW_MODEL_ID_KEY = "tclake.mlflow.model_id"
69
80
  TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY = "tclake.mlflow.model_signature"
70
81
 
82
+ UN_DEPLOYMENT = "UnDeployment"
71
83
 
72
84
  def _set_kv_to_properties(key, value, properties=None):
73
85
  if properties is None:
@@ -86,10 +98,6 @@ def _get_kv_from_properties(properties, key):
86
98
  return None
87
99
 
88
100
 
89
- def _get_uuid_from_properties(properties):
90
- return _get_kv_from_properties(properties, TCLAKE_UUID_KEY)
91
-
92
-
93
101
  def _set_run_id_to_properties(run_id, properties=None):
94
102
  return _set_kv_to_properties(TCLAKE_MLFLOW_RUN_ID_KEY, run_id, properties)
95
103
 
@@ -105,10 +113,15 @@ def _set_run_link_to_properties(run_link, properties):
105
113
  def _get_run_link_from_properties(properties):
106
114
  return _get_kv_from_properties(properties, TCLAKE_MLFLOW_RUN_LINK_KEY)
107
115
 
116
+ def _get_model_id_from_properties(properties):
117
+ return _get_kv_from_properties(properties, TCLAKE_MLFLOW_MODEL_ID_KEY)
118
+
108
119
 
109
120
  def _add_tag_to_properties(tag, properties=None):
110
121
  if tag and tag.value is not None:
111
- properties = _set_kv_to_properties(TCLAKE_MLFLOW_TAG_PREFIX + tag.key, tag.value, properties)
122
+ properties = _set_kv_to_properties(
123
+ TCLAKE_MLFLOW_TAG_PREFIX + tag.key, tag.value, properties
124
+ )
112
125
  return properties
113
126
 
114
127
 
@@ -118,84 +131,47 @@ def _add_tags_to_properties(tags, properties=None):
118
131
  properties = _add_tag_to_properties(tag, properties)
119
132
  return properties
120
133
 
134
+ def _add_deployment_job_id_to_properties(deployment_job_id, properties=None):
135
+ properties = _set_kv_to_properties(TCLAKE_MLFLOW_DEPLOYMENT_JOB_ID_KEY, deployment_job_id, properties)
136
+ return properties
137
+
138
+ def _add_workspace_id_to_properties(workspace_id, properties=None):
139
+ if workspace_id is not None:
140
+ properties = _set_kv_to_properties(TCLAKE_WEDATA_WORKSPACE_ID_KEY, workspace_id, properties)
141
+ return properties
121
142
 
143
+ def _add_deployment_status_to_properties(status, properties=None):
144
+ if status is not None:
145
+ properties = _set_kv_to_properties(TCLAKE_MLFLOW_DEPLOYMENT_STATUS_KEY, status, properties)
146
+ return properties
122
147
 
123
- def parse_model_signatures_from_dict(signature_dict):
124
- signatures = []
125
- inputs_str = signature_dict.get("inputs")
126
- if inputs_str:
127
- inputs = json.loads(inputs_str)
128
- for input_item in inputs:
129
- model_signature = {
130
- "name": input_item.get("name", ""),
131
- "type": "INPUT",
132
- "inputFlag": "true"
133
- }
134
- type_val = input_item.get("type")
135
- if type_val is not None:
136
- if type_val == "tensor":
137
- tensor_spec = input_item.get("tensor-spec", {}).get("dtype", "")
138
- if tensor_spec:
139
- model_signature["type"] = str(tensor_spec)
140
- else:
141
- model_signature["type"] = str(type_val)
142
-
143
- signatures.append(model_signature)
144
-
145
- outputs_str = signature_dict.get("outputs")
146
- if outputs_str:
147
- outputs = json.loads(outputs_str)
148
- for output_item in outputs:
149
- name_val = output_item.get("name", "")
150
- if not name_val:
151
- name_val = output_item.get("prediction_column_name", "")
152
- model_signature = {
153
- "name": name_val,
154
- "type": "OUTPUT",
155
- "inputFlag": "false"
156
- }
157
- type_val = output_item.get("type")
158
- if type_val is not None:
159
- if type_val == "tensor":
160
- tensor_spec = output_item.get("tensor-spec", {}).get("dtype", "")
161
- if tensor_spec:
162
- model_signature["type"] = str(tensor_spec)
163
- else:
164
- model_signature["type"] = str(type_val)
165
-
166
- signatures.append(model_signature)
167
-
168
- log_msg(f"parse model version signatures: {signatures}")
169
- return signatures
148
+ def _add_model_id_to_properties(model_id, properties=None):
149
+ if model_id is not None:
150
+ properties = _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_ID_KEY, model_id, properties)
151
+ return properties
170
152
 
171
153
 
172
154
  def _add_model_signature_to_properties(source, properties):
173
155
  log_msg("model source is {}".format(source))
174
- try:
175
- from mlflow.models.model import get_model_info
176
- model_info = get_model_info(source)
177
- signature = model_info.signature
178
- log_msg("model {} signature is {}".format(source, signature))
179
- if signature:
180
- sig_json = json.dumps(parse_model_signatures_from_dict(signature.to_dict()))
181
- log_msg("model {} signature json is {}".format(source, sig_json))
182
- else:
183
- log_msg(f"Registered model signature is not found in source artifact location '{source}'")
184
- sig_json = json.dumps([])
185
- except Exception as e:
186
- log_msg(f"Failed to get model signature from {source}: {e}")
187
- sig_json = json.dumps([])
156
+ model_info = get_model_info(source)
157
+ signature = model_info.signature
158
+ log_msg("model {} signature is {}".format(source, signature))
159
+ if signature is None:
160
+ raise MlflowException(f"Registered model signature is not found in source artifact location '{source}'")
161
+ sig_json = json.dumps(signature.to_dict())
162
+ log_msg("model {} signature json is {}".format(source, sig_json))
188
163
  properties = _set_kv_to_properties(TCLAKE_MLFLOW_MODEL_SIGNATURE_KEY, sig_json, properties)
189
164
  return properties
190
165
 
191
-
192
166
  def _get_tags_from_properties(properties):
193
167
  if properties is None:
194
168
  return None
195
169
  tags = []
196
170
  for p in properties:
197
171
  if p["Key"].startswith(TCLAKE_MLFLOW_TAG_PREFIX):
198
- tags.append(ModelVersionTag(p["Key"][len(TCLAKE_MLFLOW_TAG_PREFIX):], p["Value"]))
172
+ tags.append(
173
+ ModelVersionTag(p["Key"][len(TCLAKE_MLFLOW_TAG_PREFIX) :], p["Value"])
174
+ )
199
175
  return tags
200
176
 
201
177
 
@@ -219,17 +195,20 @@ def _make_model(resp):
219
195
  creation_timestamp=_get_create_time_from_audit(audit),
220
196
  last_updated_timestamp=_get_last_updated_time_from_audit(audit),
221
197
  description=_get_description(resp),
222
- tags=_get_tags_from_properties(properties),
198
+ tags=_get_tags_from_properties(properties)
223
199
  )
224
200
 
225
201
 
226
202
  def _get_model_version_name(entity):
227
- return "{}.{}.{}".format(entity["CatalogName"], entity["SchemaName"], entity["ModelName"])
203
+ return "{}.{}.{}".format(
204
+ entity["CatalogName"], entity["SchemaName"], entity["ModelName"]
205
+ )
228
206
 
229
207
 
230
208
  def _make_model_version(entity, name):
231
209
  properties = entity["Properties"]
232
210
  audit = entity["Audit"]
211
+ aliases = entity["Aliases"]
233
212
  return ModelVersion(
234
213
  name=name,
235
214
  version=_get_model_version(entity["Version"]),
@@ -240,6 +219,8 @@ def _make_model_version(entity, name):
240
219
  run_id=_get_run_id_from_properties(properties),
241
220
  tags=_get_tags_from_properties(properties),
242
221
  run_link=_get_run_link_from_properties(properties),
222
+ aliases=aliases,
223
+ model_id=_get_model_id_from_properties(properties),
243
224
  status="READY",
244
225
  )
245
226
 
@@ -267,7 +248,9 @@ def _parse_page_token(page_token):
267
248
 
268
249
 
269
250
  def _create_page_token(offset, search_id):
270
- return base64.b64encode(json.dumps({"offset": offset, "search_id": search_id}).encode("utf-8"))
251
+ return base64.b64encode(
252
+ json.dumps({"offset": offset, "search_id": search_id}).encode("utf-8")
253
+ )
271
254
 
272
255
 
273
256
  @experimental
@@ -278,29 +261,50 @@ class TCLakeStore(AbstractStore):
278
261
 
279
262
  def __init__(self, store_uri=None, tracking_uri=None):
280
263
  super().__init__(store_uri, tracking_uri)
281
- log_msg("initializing tencent tclake client {} {}".format(store_uri, tracking_uri))
264
+ log_msg(
265
+ "initializing tencent tclake client {} {}".format(store_uri, tracking_uri)
266
+ )
282
267
  sid = os.getenv("TENCENTCLOUD_SECRET_ID", "")
283
268
  if len(sid) == 0:
284
269
  raise MlflowException("TENCENTCLOUD_SECRET_ID is not set")
285
270
  sk = os.getenv("TENCENTCLOUD_SECRET_KEY", "")
286
271
  if len(sk) == 0:
287
272
  raise MlflowException("TENCENTCLOUD_SECRET_KEY is not set")
273
+ sk = os.getenv("TENCENTCLOUD_SECRET_KEY", "")
274
+ if len(sk) == 0:
275
+ raise MlflowException("TENCENTCLOUD_SECRET_KEY is not set")
276
+
288
277
  token = os.getenv("TENCENTCLOUD_TOKEN", None)
278
+
279
+ self.workspace_id = os.getenv("WEDATA_WORKSPACE_ID", "")
280
+ if len(self.workspace_id) == 0:
281
+ raise MlflowException("WEDATA_WORKSPACE_ID is not set")
282
+ log_msg(str.format("wedata workspace id: {}", self.workspace_id))
283
+
289
284
  client_profile = get_tencent_cloud_client_profile()
290
285
  cred = credential.Credential(sid, sk, token)
291
286
  parts = store_uri.split(":")
292
287
  if len(parts) < 2:
293
288
  raise MlflowException("set store_uri tclake:{region}")
294
289
  region = parts[1]
295
- self.client = CommonClient("tccatalog", "2024-10-24", cred, region, client_profile)
290
+ self.client = CommonClient(
291
+ "tccatalog", "2024-10-24", cred, region, client_profile
292
+ )
296
293
  self.headers = get_tencent_cloud_headers()
297
- self.default_catalog_name = os.getenv("TENCENTCLOUD_DEFAULT_CATALOG_NAME", "default")
298
- self.default_schema_name = os.getenv("TENCENTCLOUD_DEFAULT_SCHEMA_NAME", "default")
294
+ self.default_catalog_name = os.getenv(
295
+ "TENCENTCLOUD_DEFAULT_CATALOG_NAME", "default"
296
+ )
297
+ self.default_schema_name = os.getenv(
298
+ "TENCENTCLOUD_DEFAULT_SCHEMA_NAME", "default"
299
+ )
299
300
  cache_size = int(os.getenv("TCLAKE_CACHE_SIZE", "100"))
300
301
  cache_ttl = int(os.getenv("TCLAKE_CACHE_TTL_SECS", "300"))
301
302
  self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl)
302
- log_msg("initialized tencent tclake client successfully {} {} {} {} {}".format(
303
- region, client_profile, self.headers, cache_size, cache_ttl))
303
+ log_msg(
304
+ "initialized tencent tclake client successfully {} {} {} {} {}".format(
305
+ region, client_profile, self.headers, cache_size, cache_ttl
306
+ )
307
+ )
304
308
 
305
309
  def _split_model_name(self, name):
306
310
  parts = name.split(".")
@@ -310,7 +314,9 @@ class TCLakeStore(AbstractStore):
310
314
  return [self.default_catalog_name, parts[0], parts[1]]
311
315
  if len(parts) == 3:
312
316
  return parts
313
- raise MlflowException("invalid model name: {}, must be catalog.schema.model".format(name))
317
+ raise MlflowException(
318
+ "invalid model name: {}, must be catalog.schema.model".format(name)
319
+ )
314
320
 
315
321
  def _call(self, action, req):
316
322
  log_msg("req: {}\n{}".format(action, json.dumps(req, indent=2)))
@@ -336,7 +342,7 @@ class TCLakeStore(AbstractStore):
336
342
  "CatalogName": catalog_name,
337
343
  "SchemaName": schema_name,
338
344
  "ModelName": model_name,
339
- "ModelVersions": self._get_model_version_numbers(name)
345
+ "ModelVersions": self._get_model_version_numbers(name),
340
346
  }
341
347
  resp = self._call("DescribeModelVersions", req_body)
342
348
  model_version = None
@@ -348,21 +354,32 @@ class TCLakeStore(AbstractStore):
348
354
  return None
349
355
  return _make_model_version(model_version, name)
350
356
 
351
- def create_registered_model(self, name, tags=None, description=None):
357
+ def create_registered_model(self, name, tags=None, description=None, deployment_job_id=None):
352
358
  log_msg("create_registered_model {} {} {}".format(name, tags, description))
353
359
  [catalog_name, schema_name, model_name] = self._split_model_name(name)
354
360
  properties = []
355
361
  _add_tags_to_properties(tags, properties)
362
+ _add_deployment_job_id_to_properties(deployment_job_id, properties)
363
+ _add_workspace_id_to_properties(self.workspace_id, properties)
356
364
  req_body = {
357
365
  "CatalogName": catalog_name,
358
366
  "SchemaName": schema_name,
359
367
  "ModelName": model_name,
360
368
  "Comment": description if description else "",
361
- "Properties": properties
369
+ "Properties": properties,
362
370
  }
363
- resp = self._call("CreateModel", req_body)
371
+ try:
372
+ resp = self._call("CreateModel", req_body)
373
+ except Exception as e:
374
+ if isinstance(e, TencentCloudSDKException):
375
+ if e.code == "FailedOperation.MetalakeAlreadyExistsError":
376
+ raise MlflowException(
377
+ f"Registered Model (name={name}) already exists.",
378
+ RESOURCE_ALREADY_EXISTS,
379
+ )
380
+ raise
364
381
  return _make_model(resp["Model"])
365
-
382
+
366
383
  def update_registered_model(self, name, description):
367
384
  log_msg("update_register_model {} {}".format(name, description))
368
385
  [catalog_name, schema_name, model_name] = self._split_model_name(name)
@@ -370,7 +387,7 @@ class TCLakeStore(AbstractStore):
370
387
  "CatalogName": catalog_name,
371
388
  "SchemaName": schema_name,
372
389
  "ModelName": model_name,
373
- "NewComment": description if description else ""
390
+ "NewComment": description if description else "",
374
391
  }
375
392
  resp = self._call("ModifyModelComment", req_body)
376
393
  return _make_model(resp["Model"])
@@ -382,7 +399,7 @@ class TCLakeStore(AbstractStore):
382
399
  "CatalogName": catalog_name,
383
400
  "SchemaName": schema_name,
384
401
  "ModelName": model_name,
385
- "NewName": new_name
402
+ "NewName": new_name,
386
403
  }
387
404
  resp = self._call("ModifyModelName", req_body)
388
405
  return _make_model(resp["Model"])
@@ -400,30 +417,28 @@ class TCLakeStore(AbstractStore):
400
417
  raise MlflowException("Failed to delete model {}".format(name))
401
418
 
402
419
  def _fetch_all_models(self):
403
- req_body = {
404
- "Offset": 0,
405
- "Limit": 200,
406
- "SnapshotBased": True,
407
- "SnapshotId": ""
408
- }
420
+ req_body = {"Offset": 0, "Limit": 200, "SnapshotBased": True, "SnapshotId": ""}
409
421
  resp = self._call("SearchModels", req_body)
410
- total = resp['TotalCount']
411
- model_list = resp['Models']
422
+ total = resp["TotalCount"]
423
+ model_list = resp["Models"]
412
424
  while len(model_list) < total:
413
425
  time.sleep(0.05)
414
426
  req_body["Offset"] += req_body["Limit"]
415
427
  req_body["SnapshotId"] = resp["SnapshotId"]
416
428
  resp = self._call("SearchModels", req_body)
417
- model_list.extend(resp['Models'])
418
- if len(resp['Models']) < req_body["Limit"]:
429
+ model_list.extend(resp["Models"])
430
+ if len(resp["Models"]) < req_body["Limit"]:
419
431
  break
420
432
  return [_make_model(model) for model in model_list]
421
433
 
422
434
  def search_registered_models(
423
- self, filter_string=None, max_results=None, order_by=None, page_token=None
435
+ self, filter_string=None, max_results=None, order_by=None, page_token=None
424
436
  ):
425
- log_msg("search_registered_models {} {} {} {}".format(
426
- filter_string, max_results, order_by, page_token))
437
+ log_msg(
438
+ "search_registered_models {} {} {} {}".format(
439
+ filter_string, max_results, order_by, page_token
440
+ )
441
+ )
427
442
  if not isinstance(max_results, int) or max_results < 1:
428
443
  raise MlflowException(
429
444
  "Invalid value for max_results. It must be a positive integer,"
@@ -446,8 +461,16 @@ class TCLakeStore(AbstractStore):
446
461
  search_id = "model_" + str(uuid.uuid4())
447
462
  page_token = _create_page_token(0, search_id)
448
463
  self.cache[search_id] = sorted_rms
449
- log_msg("search_registered_models add cache {} {} {} {} {} {}".format(
450
- filter_string, max_results, order_by, page_token, search_id, len(sorted_rms)))
464
+ log_msg(
465
+ "search_registered_models add cache {} {} {} {} {} {}".format(
466
+ filter_string,
467
+ max_results,
468
+ order_by,
469
+ page_token,
470
+ search_id,
471
+ len(sorted_rms),
472
+ )
473
+ )
451
474
  return self._get_page_list(page_token, max_results)
452
475
 
453
476
  def get_registered_model(self, name):
@@ -472,10 +495,7 @@ class TCLakeStore(AbstractStore):
472
495
  "ModelName": model_name,
473
496
  }
474
497
  resp = self._call("DescribeModelVersions", req)
475
- return [
476
- _make_model_version(mv, name)
477
- for mv in resp["ModelVersions"]
478
- ]
498
+ return [_make_model_version(mv, name) for mv in resp["ModelVersions"]]
479
499
 
480
500
  def set_registered_model_tag(self, name, tag):
481
501
  log_msg("set_registered_model_tag {} {}".format(name, tag))
@@ -484,7 +504,7 @@ class TCLakeStore(AbstractStore):
484
504
  "CatalogName": catalog_name,
485
505
  "SchemaName": schema_name,
486
506
  "ModelName": model_name,
487
- "Properties": _add_tag_to_properties(tag)
507
+ "Properties": _add_tag_to_properties(tag),
488
508
  }
489
509
  self._call("ModifyModelProperties", req)
490
510
 
@@ -495,29 +515,37 @@ class TCLakeStore(AbstractStore):
495
515
  "CatalogName": catalog_name,
496
516
  "SchemaName": schema_name,
497
517
  "ModelName": model_name,
498
- "RemovedKeys": [_append_tag_key_prefix(key)]
518
+ "RemovedKeys": [_append_tag_key_prefix(key)],
499
519
  }
500
520
  self._call("ModifyModelProperties", req)
501
521
 
502
522
  def create_model_version(
503
- self,
504
- name,
505
- source,
506
- run_id=None,
507
- tags=None,
508
- run_link=None,
509
- description=None,
510
- local_model_path=None,
523
+ self,
524
+ name,
525
+ source,
526
+ run_id=None,
527
+ tags=None,
528
+ run_link=None,
529
+ description=None,
530
+ local_model_path=None,
531
+ model_id=None,
511
532
  ):
512
- log_msg("create_model_version {} {} {} {} {} {} {}".format(
513
- name, source, run_id, tags, run_link, description, local_model_path))
533
+ log_msg(
534
+ "create_model_version {} {} {} {} {} {} {}".format(
535
+ name, source, run_id, tags, run_link, description, local_model_path
536
+ )
537
+ )
514
538
  [catalog_name, schema_name, model_name] = self._split_model_name(name)
515
539
  version_alias = str(uuid.uuid4())
516
540
  properties = []
517
541
  _add_tags_to_properties(tags, properties)
518
542
  _set_run_id_to_properties(run_id, properties)
519
543
  _set_run_link_to_properties(run_link, properties)
544
+ _add_workspace_id_to_properties(self.workspace_id, properties)
545
+ _add_deployment_status_to_properties(UN_DEPLOYMENT, properties)
546
+ _add_model_id_to_properties(model_id, properties)
520
547
  _add_model_signature_to_properties(source, properties)
548
+
521
549
  req_body = {
522
550
  "CatalogName": catalog_name,
523
551
  "SchemaName": schema_name,
@@ -525,7 +553,7 @@ class TCLakeStore(AbstractStore):
525
553
  "Uri": source,
526
554
  "Comment": description if description else "",
527
555
  "Properties": properties,
528
- "Aliases": [version_alias]
556
+ "Aliases": [version_alias],
529
557
  }
530
558
  self._call("CreateModelVersion", req_body)
531
559
  return self._get_model_version_by_alias(name, version_alias)
@@ -538,12 +566,14 @@ class TCLakeStore(AbstractStore):
538
566
  "SchemaName": schema_name,
539
567
  "ModelName": model_name,
540
568
  "ModelVersion": _set_model_version(version),
541
- "NewComment": description if description else ""
569
+ "NewComment": description if description else "",
542
570
  }
543
571
  resp = self._call("ModifyModelVersionComment", req_body)
544
572
  return _make_model_version(resp["ModelVersion"], name)
545
573
 
546
- def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
574
+ def transition_model_version_stage(
575
+ self, name, version, stage, archive_existing_versions
576
+ ):
547
577
  raise NotImplementedError("Method not implemented")
548
578
 
549
579
  def delete_model_version(self, name, version):
@@ -557,7 +587,9 @@ class TCLakeStore(AbstractStore):
557
587
  }
558
588
  resp = self._call("DropModelVersion", req_body)
559
589
  if not resp["Dropped"]:
560
- raise Exception("Failed to delete model version {} {}".format(name, version))
590
+ raise Exception(
591
+ "Failed to delete model version {} {}".format(name, version)
592
+ )
561
593
 
562
594
  def get_model_version(self, name, version):
563
595
  log_msg("get_model_version {} {}".format(name, version))
@@ -572,31 +604,31 @@ class TCLakeStore(AbstractStore):
572
604
  return _make_model_version(resp["ModelVersion"], name)
573
605
 
574
606
  def _fetch_all_model_versions(self):
575
- req_body = {
576
- "Offset": 0,
577
- "Limit": 200,
578
- "SnapshotBased": True,
579
- "SnapshotId": ""
580
- }
607
+ req_body = {"Offset": 0, "Limit": 200, "SnapshotBased": True, "SnapshotId": ""}
581
608
  resp = self._call("SearchModelVersions", req_body)
582
- total = resp['TotalCount']
583
- model_version_list = resp['ModelVersions']
609
+ total = resp["TotalCount"]
610
+ model_version_list = resp["ModelVersions"]
584
611
  while len(model_version_list) < total:
585
612
  time.sleep(0.05)
586
613
  req_body["Offset"] += req_body["Limit"]
587
614
  req_body["SnapshotId"] = resp["SnapshotId"]
588
615
  resp = self._call("SearchModelVersions", req_body)
589
- model_version_list.extend(resp['ModelVersions'])
590
- if len(resp['ModelVersions']) < req_body["Limit"]:
616
+ model_version_list.extend(resp["ModelVersions"])
617
+ if len(resp["ModelVersions"]) < req_body["Limit"]:
591
618
  break
592
- return [_make_model_version(model_version, _get_model_version_name(model_version)) for model_version in
593
- model_version_list]
619
+ return [
620
+ _make_model_version(model_version, _get_model_version_name(model_version))
621
+ for model_version in model_version_list
622
+ ]
594
623
 
595
624
  def search_model_versions(
596
- self, filter_string=None, max_results=None, order_by=None, page_token=None
625
+ self, filter_string=None, max_results=None, order_by=None, page_token=None
597
626
  ):
598
- log_msg("search_model_versions {} {} {} {}".format(
599
- filter_string, max_results, order_by, page_token))
627
+ log_msg(
628
+ "search_model_versions {} {} {} {}".format(
629
+ filter_string, max_results, order_by, page_token
630
+ )
631
+ )
600
632
  if not isinstance(max_results, int) or max_results < 1:
601
633
  raise MlflowException(
602
634
  "Invalid value for max_results. It must be a positive integer,"
@@ -615,39 +647,59 @@ class TCLakeStore(AbstractStore):
615
647
  filtered_mvs = SearchModelVersionUtils.filter(model_versions, filter_string)
616
648
  sorted_mvs = SearchModelVersionUtils.sort(
617
649
  filtered_mvs,
618
- order_by or ["last_updated_timestamp DESC", "name ASC", "version_number DESC"],
650
+ order_by
651
+ or ["last_updated_timestamp DESC", "name ASC", "version_number DESC"],
619
652
  )
620
653
  if len(sorted_mvs) == 0:
621
654
  return PagedList([], None)
622
655
  search_id = "model_version_" + str(uuid.uuid4())
623
656
  page_token = _create_page_token(0, search_id)
624
657
  self.cache[search_id] = sorted_mvs
625
- log_msg("search_model_versions add cache {} {} {} {} {} {}".format(
626
- filter_string, max_results, order_by, page_token, search_id, len(sorted_mvs)))
658
+ log_msg(
659
+ "search_model_versions add cache {} {} {} {} {} {}".format(
660
+ filter_string,
661
+ max_results,
662
+ order_by,
663
+ page_token,
664
+ search_id,
665
+ len(sorted_mvs),
666
+ )
667
+ )
627
668
 
628
669
  return self._get_page_list(page_token, max_results)
629
670
 
630
671
  def _get_page_list(self, page_token, max_results):
631
672
  token_info = _parse_page_token(page_token)
632
- log_msg("_get_page_list token_info {} {} {}".format(
633
- page_token, max_results, token_info))
634
- sorted_mvs = self.cache.get(token_info['search_id'])
673
+ log_msg(
674
+ "_get_page_list token_info {} {} {}".format(
675
+ page_token, max_results, token_info
676
+ )
677
+ )
678
+ sorted_mvs = self.cache.get(token_info["search_id"])
635
679
  if sorted_mvs is None:
636
680
  raise MlflowException(
637
681
  "Invalid page token: search id not found or expired",
638
682
  INVALID_PARAMETER_VALUE,
639
683
  )
640
- start_offset = token_info['offset']
684
+ start_offset = token_info["offset"]
641
685
  final_offset = start_offset + max_results
642
686
 
643
- paginated_rms = sorted_mvs[start_offset: min(len(sorted_mvs), final_offset)]
687
+ paginated_rms = sorted_mvs[start_offset : min(len(sorted_mvs), final_offset)]
644
688
  next_page_token = None
645
689
  if final_offset < len(sorted_mvs):
646
- next_page_token = _create_page_token(final_offset, token_info['search_id'])
690
+ next_page_token = _create_page_token(final_offset, token_info["search_id"])
647
691
  else:
648
- self.cache.pop(token_info['search_id'], None)
649
- log_msg("pop cache {} {} {} {} {} {}".format(
650
- token_info['search_id'], start_offset, final_offset, len(sorted_mvs), page_token, next_page_token))
692
+ self.cache.pop(token_info["search_id"], None)
693
+ log_msg(
694
+ "pop cache {} {} {} {} {} {}".format(
695
+ token_info["search_id"],
696
+ start_offset,
697
+ final_offset,
698
+ len(sorted_mvs),
699
+ page_token,
700
+ next_page_token,
701
+ )
702
+ )
651
703
  return PagedList(paginated_rms, next_page_token)
652
704
 
653
705
  def set_model_version_tag(self, name, version, tag):
@@ -658,7 +710,7 @@ class TCLakeStore(AbstractStore):
658
710
  "SchemaName": schema_name,
659
711
  "ModelName": model_name,
660
712
  "ModelVersion": _set_model_version(version),
661
- "Properties": _add_tag_to_properties(tag)
713
+ "Properties": _add_tag_to_properties(tag),
662
714
  }
663
715
  self._call("ModifyModelVersionProperties", req)
664
716
 
@@ -670,7 +722,7 @@ class TCLakeStore(AbstractStore):
670
722
  "SchemaName": schema_name,
671
723
  "ModelName": model_name,
672
724
  "ModelVersion": _set_model_version(version),
673
- "RemovedKeys": [_append_tag_key_prefix(key)]
725
+ "RemovedKeys": [_append_tag_key_prefix(key)],
674
726
  }
675
727
  self._call("ModifyModelVersionProperties", req)
676
728
 
@@ -687,4 +739,5 @@ class TCLakeStore(AbstractStore):
687
739
  def get_model_version_download_uri(self, name, version):
688
740
  log_msg("get_model_version_download_uri {} {}".format(name, version))
689
741
  model_version = self.get_model_version(name, version)
690
- return model_version.source
742
+ return model_version.source
743
+
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlflow-tclake-plugin
3
- Version: 2.0.1
3
+ Version: 3.0.1
4
4
  Summary: Tclake plugin for MLflow
5
5
  License-File: LICENSE.txt
6
- Requires-Dist: mlflow>=2.7.2
6
+ Requires-Dist: mlflow>=3.1.4
7
7
  Requires-Dist: tencentcloud-sdk-python-common>=3.0.1478
8
8
  Dynamic: license-file
9
9
  Dynamic: requires-dist
@@ -0,0 +1,8 @@
1
+ mlflow_tclake_plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ mlflow_tclake_plugin/tclake_store.py,sha256=Ys3giqvpdth7gwZ7O3TeaQH1OJH08XUpqjRkOVIM9JM,28301
3
+ mlflow_tclake_plugin-3.0.1.dist-info/licenses/LICENSE.txt,sha256=X60Z7_gpe--AyXGUWeFWpOpFl5m-yeemfodWGJn1rUA,1067
4
+ mlflow_tclake_plugin-3.0.1.dist-info/METADATA,sha256=tz2ZxO4mTJTZR6WlANcobb91uTAvyWCS6Ff0RDuR1Cs,271
5
+ mlflow_tclake_plugin-3.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ mlflow_tclake_plugin-3.0.1.dist-info/entry_points.txt,sha256=xKXky9-NyJWEG1SgBknrqskN2rsZG3t4UXiSzkTcsuE,85
7
+ mlflow_tclake_plugin-3.0.1.dist-info/top_level.txt,sha256=zrA5UNyfF3skRmxPNsvrJI3yf-um6CJX_xO5KvWc2o0,21
8
+ mlflow_tclake_plugin-3.0.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- mlflow_tclake_plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- mlflow_tclake_plugin/tclake_store.py,sha256=djAE2-C_2BKkguy8WN-uETIOcC9qLiiCSDu7teXFRQc,27054
3
- mlflow_tclake_plugin-2.0.1.dist-info/licenses/LICENSE.txt,sha256=X60Z7_gpe--AyXGUWeFWpOpFl5m-yeemfodWGJn1rUA,1067
4
- mlflow_tclake_plugin-2.0.1.dist-info/METADATA,sha256=rVelcV3Jsjdi0RVsbdGPkYGONGiONW63dL6iJb1rzFc,271
5
- mlflow_tclake_plugin-2.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- mlflow_tclake_plugin-2.0.1.dist-info/entry_points.txt,sha256=xKXky9-NyJWEG1SgBknrqskN2rsZG3t4UXiSzkTcsuE,85
7
- mlflow_tclake_plugin-2.0.1.dist-info/top_level.txt,sha256=zrA5UNyfF3skRmxPNsvrJI3yf-um6CJX_xO5KvWc2o0,21
8
- mlflow_tclake_plugin-2.0.1.dist-info/RECORD,,