viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,561 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from collections.abc import Iterator, Mapping
|
|
19
|
+
from typing import TYPE_CHECKING, Literal
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from model_registry.types import (
|
|
23
|
+
ModelArtifact,
|
|
24
|
+
ModelVersion,
|
|
25
|
+
RegisteredModel,
|
|
26
|
+
SupportedTypes,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ModelRegistryClient:
    """Client for Model Registry operations.

    Supports two modes of operation:

    1. **Direct mode**: Connect directly to a Model Registry server.
       Requires the model-registry package to be installed:
       ``pip install 'viettelcloud-aiplatform[hub]'``

    2. **Proxy mode**: Connect through cmp-backend which handles authentication
       and routes to the correct Model Registry instance based on project.
       Constructing the client does not require the model-registry package,
       but the response converters and the ``update_*`` methods import
       ``model_registry.types``, so calls that build or validate registry
       objects still need it installed.

    Proxy mode is selected whenever all four of ``cmp_backend_url``,
    ``pat_token``, ``project_id`` and ``region`` are available (as arguments
    or environment variables), even if ``base_url`` is also given.

    Examples:
        Direct mode::

            client = ModelRegistryClient(
                base_url="https://registry.example.com",
                user_token="...",
            )

        Proxy mode::

            client = ModelRegistryClient(
                cmp_backend_url="https://api.viettelcloud.vn",
                pat_token="vtp_xxx...",
                project_id="my-ml-project",
                region="HN",  # or "HCM", etc.
            )

        Proxy mode with environment variables::

            # Set these environment variables:
            # VIETTELCLOUD_CMP_URL=https://api.viettelcloud.vn
            # VIETTELCLOUD_PAT_TOKEN=vtp_xxx...
            # VIETTELCLOUD_PROJECT_ID=my-ml-project
            # VIETTELCLOUD_REGION=HN

            client = ModelRegistryClient()  # Uses env vars
    """
|
|
70
|
+
|
|
71
|
+
def __init__(
    self,
    base_url: str | None = None,
    port: int | None = None,
    *,
    author: str | None = None,
    is_secure: bool | None = None,
    user_token: str | None = None,
    custom_ca: str | None = None,
    # Proxy mode parameters
    cmp_backend_url: str | None = None,
    pat_token: str | None = None,
    project_id: str | None = None,
    region: str | None = None,
):
    """Initialize the ModelRegistryClient.

    Two modes of operation:

    1. **Direct mode**: Connect directly to the Model Registry server.

       - Requires: ``base_url``
       - Optional: ``port``, ``author``, ``is_secure``, ``user_token``, ``custom_ca``

    2. **Proxy mode**: Connect through cmp-backend.

       - Requires: ``cmp_backend_url``, ``pat_token``, ``project_id``, ``region``
       - Each of the four may instead come from its environment variable.
       - Takes precedence over direct mode when all four are available.

    Args:
        base_url: Base URL of the model registry server including scheme.
            Examples: "https://registry.example.com", "http://localhost"
            Required for direct mode.

    Keyword Args:
        port: Server port (direct mode). If not provided, inferred from scheme:
            - http:// defaults to 8080
            - anything else defaults to 443
        author: Name of the author. Stored on the client and used as the
            fallback author/owner by proxy-mode registration.
        is_secure: Whether to use TLS (direct mode). If not provided it is
            ``False`` for ``http://`` URLs and ``True`` otherwise.
        user_token: The PEM-encoded user token (direct mode).
        custom_ca: Path to PEM-encoded root certificates (direct mode).
        cmp_backend_url: URL of cmp-backend (proxy mode).
            Can also be set via ``VIETTELCLOUD_CMP_URL`` env var.
        pat_token: Personal Access Token (proxy mode).
            Can also be set via ``VIETTELCLOUD_PAT_TOKEN`` env var.
        project_id: Project UUID or slug (proxy mode).
            Can also be set via ``VIETTELCLOUD_PROJECT_ID`` env var.
        region: Region UUID or name (proxy mode). E.g., "HN", "HCM".
            Can also be set via ``VIETTELCLOUD_REGION`` env var.

    Raises:
        ValueError: If neither a complete proxy configuration nor
            ``base_url`` is available. When only some proxy settings were
            supplied, the message lists the missing ones.
        ImportError: If model-registry is not installed (direct mode only).
    """
    # Explicit arguments win; environment variables only fill the gaps.
    cmp_backend_url = cmp_backend_url or os.environ.get("VIETTELCLOUD_CMP_URL")
    pat_token = pat_token or os.environ.get("VIETTELCLOUD_PAT_TOKEN")
    project_id = project_id or os.environ.get("VIETTELCLOUD_PROJECT_ID")
    region = region or os.environ.get("VIETTELCLOUD_REGION")

    self._author = author

    # Track which proxy settings are absent so that a half-configured
    # proxy setup can be reported precisely instead of generically.
    proxy_settings = {
        "cmp_backend_url": cmp_backend_url,
        "pat_token": pat_token,
        "project_id": project_id,
        "region": region,
    }
    missing_proxy = [key for key, value in proxy_settings.items() if not value]

    if not missing_proxy:
        # Proxy mode
        self._mode: Literal["direct", "proxy"] = "proxy"
        self._registry = None

        # Imported lazily so direct-mode users do not pay for it.
        from ._proxy_client import ProxyHTTPClient

        self._proxy_client = ProxyHTTPClient(
            cmp_backend_url=cmp_backend_url,
            pat_token=pat_token,
            project_id=project_id,
            region=region,
        )

    elif base_url:
        # Direct mode
        self._mode = "direct"
        self._proxy_client = None

        try:
            from model_registry import ModelRegistry
        except ImportError as e:
            raise ImportError(
                "model-registry is not installed. Install it with:\n\n"
                " pip install 'viettelcloud-aiplatform[hub]'\n"
            ) from e

        # URI schemes are case-insensitive, so "HTTP://host" must be
        # recognised as plain HTTP as well.
        is_http = base_url.lower().startswith("http://")
        if is_secure is None:
            is_secure = not is_http
        if port is None:
            port = 8080 if is_http else 443

        self._registry = ModelRegistry(
            server_address=base_url,
            port=port,
            author=author,  # type: ignore[arg-type]
            is_secure=is_secure,
            user_token=user_token,
            custom_ca=custom_ca,
        )

    else:
        message = (
            "Must provide either:\n"
            " - base_url (direct mode), or\n"
            " - cmp_backend_url + pat_token + project_id + region (proxy mode)\n\n"
            "Proxy mode also supports environment variables:\n"
            " VIETTELCLOUD_CMP_URL, VIETTELCLOUD_PAT_TOKEN, VIETTELCLOUD_PROJECT_ID, VIETTELCLOUD_REGION"
        )
        if len(missing_proxy) < len(proxy_settings):
            # Some proxy settings were supplied: name the ones still needed.
            message += (
                "\n\nIncomplete proxy mode configuration; missing: "
                + ", ".join(missing_proxy)
            )
        raise ValueError(message)
|
|
186
|
+
|
|
187
|
+
@property
def mode(self) -> Literal["direct", "proxy"]:
    """Report how this client talks to the registry.

    Returns:
        The value stored in ``self._mode`` by the constructor: ``"direct"``
        or ``"proxy"``.
    """
    active_mode = self._mode
    return active_mode
|
|
191
|
+
|
|
192
|
+
# =========================================================================
|
|
193
|
+
# Helper methods for proxy mode
|
|
194
|
+
# =========================================================================
|
|
195
|
+
|
|
196
|
+
def _dict_to_registered_model(self, data: dict) -> RegisteredModel:
    """Build a ``RegisteredModel`` from a proxy API response dict.

    Args:
        data: Response payload. Keys are read with ``dict.get``, so absent
            keys become ``None`` (``name`` falls back to ``""``).

    Returns:
        A ``RegisteredModel`` populated from the payload.
    """
    # Imported lazily: model_registry is an optional dependency.
    from model_registry.types import RegisteredModel

    # Constructor keyword -> key used in the proxy payload.
    payload_keys = {
        "owner": "owner",
        "description": "description",
        "external_id": "externalId",
        "id": "id",
        "create_time_since_epoch": "createTimeSinceEpoch",
        "last_update_time_since_epoch": "lastUpdateTimeSinceEpoch",
        "custom_properties": "customProperties",
    }
    optional_fields = {
        keyword: data.get(payload_key) for keyword, payload_key in payload_keys.items()
    }
    return RegisteredModel(name=data.get("name", ""), **optional_fields)
|
|
210
|
+
|
|
211
|
+
def _dict_to_model_version(self, data: dict) -> ModelVersion:
    """Build a ``ModelVersion`` from a proxy API response dict.

    Args:
        data: Response payload. Keys are read with ``dict.get``, so absent
            keys become ``None`` (``name`` falls back to ``""``).

    Returns:
        A ``ModelVersion`` populated from the payload.
    """
    # Imported lazily: model_registry is an optional dependency.
    from model_registry.types import ModelVersion

    # Constructor keyword -> key used in the proxy payload.
    payload_keys = {
        "author": "author",
        "description": "description",
        "external_id": "externalId",
        "id": "id",
        "create_time_since_epoch": "createTimeSinceEpoch",
        "last_update_time_since_epoch": "lastUpdateTimeSinceEpoch",
        "custom_properties": "customProperties",
        "registered_model_id": "registeredModelId",
    }
    optional_fields = {
        keyword: data.get(payload_key) for keyword, payload_key in payload_keys.items()
    }
    return ModelVersion(name=data.get("name", ""), **optional_fields)
|
|
226
|
+
|
|
227
|
+
def _dict_to_model_artifact(self, data: dict) -> ModelArtifact:
    """Build a ``ModelArtifact`` from a proxy API response dict.

    Args:
        data: Response payload. Keys are read with ``dict.get``, so absent
            keys become ``None`` (``name`` and ``uri`` fall back to ``""``).

    Returns:
        A ``ModelArtifact`` populated from the payload.
    """
    # Imported lazily: model_registry is an optional dependency.
    from model_registry.types import ModelArtifact

    # Constructor keyword -> key used in the proxy payload.
    payload_keys = {
        "description": "description",
        "external_id": "externalId",
        "id": "id",
        "create_time_since_epoch": "createTimeSinceEpoch",
        "last_update_time_since_epoch": "lastUpdateTimeSinceEpoch",
        "custom_properties": "customProperties",
        "model_format_name": "modelFormatName",
        "model_format_version": "modelFormatVersion",
        "storage_key": "storageKey",
        "storage_path": "storagePath",
        "service_account_name": "serviceAccountName",
    }
    optional_fields = {
        keyword: data.get(payload_key) for keyword, payload_key in payload_keys.items()
    }
    return ModelArtifact(
        name=data.get("name", ""),
        uri=data.get("uri", ""),
        **optional_fields,
    )
|
|
246
|
+
|
|
247
|
+
# =========================================================================
|
|
248
|
+
# Public API
|
|
249
|
+
# =========================================================================
|
|
250
|
+
|
|
251
|
+
def register_model(
    self,
    name: str,
    uri: str,
    *,
    version: str,
    model_format_name: str | None = None,
    model_format_version: str | None = None,
    author: str | None = None,
    owner: str | None = None,
    version_description: str | None = None,
    metadata: Mapping[str, SupportedTypes] | None = None,
) -> RegisteredModel:
    """Register a model, its version and its artifact.

    Nothing is uploaded or downloaded: the model must already be stored at
    ``uri`` before it is registered.

    In direct mode every argument is forwarded unchanged to the underlying
    registry client. In proxy mode three calls are issued in order
    (registered model, model version, model artifact named
    ``"{name}-{version}"``) and the registered-model response is returned.

    Args:
        name: Name of the model.
        uri: URI of the model.

    Keyword Args:
        version: Version of the model. Has to be unique.
        model_format_name: Name of the model format (e.g., "pytorch", "tensorflow", "onnx").
        model_format_version: Version of the model format (e.g., "2.0", "1.15").
        author: Author of the model version. In proxy mode falls back to the
            client author when empty.
        owner: Owner of the model. In proxy mode falls back to the client
            author when empty.
        version_description: Description of the model version.
        metadata: Additional metadata.

    Returns:
        Registered model.
    """
    if self._mode != "proxy":
        return self._registry.register_model(
            name=name,
            uri=uri,
            model_format_name=model_format_name,  # type: ignore[arg-type]
            model_format_version=model_format_version,  # type: ignore[arg-type]
            version=version,
            author=author,
            owner=owner,
            description=version_description,
            metadata=metadata,
        )

    proxy = self._proxy_client

    # NOTE(review): in proxy mode the metadata is attached to the registered
    # model, not to the version as in direct mode -- confirm this is intended.
    model_properties = dict(metadata) if metadata else None
    registered = proxy.create_registered_model(
        name=name,
        owner=owner or self._author,
        custom_properties=model_properties,
    )

    proxy.create_model_version(
        model_name=name,
        version_name=version,
        description=version_description,
        author=author or self._author,
    )

    artifact_name = f"{name}-{version}"
    proxy.create_model_artifact(
        name=artifact_name,
        uri=uri,
        model_format_name=model_format_name,
        model_format_version=model_format_version,
    )

    return self._dict_to_registered_model(registered)
|
|
319
|
+
|
|
320
|
+
def update_model(self, model: RegisteredModel) -> RegisteredModel:
    """Persist changes made to a registered model.

    In direct mode the model is handed to the underlying registry client.
    In proxy mode only ``description`` and ``custom_properties`` are sent,
    addressed by the model ID.

    Args:
        model: The registered model to update. Must have an ID.

    Returns:
        Updated registered model.

    Raises:
        TypeError: If model is not a RegisteredModel instance.
        ValueError: In proxy mode, if model does not have an ID.
        model_registry.exceptions.StoreError: In direct mode, if model does
            not have an ID (raised by the underlying client).
    """
    # Needed for the isinstance check in both modes.
    from model_registry.types import RegisteredModel

    if not isinstance(model, RegisteredModel):
        raise TypeError(f"Expected RegisteredModel, got {type(model).__name__}. ")

    if self._mode != "proxy":
        return self._registry.update(model)

    if not model.id:
        raise ValueError("Model must have an ID to update")

    response = self._proxy_client.update_registered_model(
        model_id=model.id,
        description=model.description,
        custom_properties=model.custom_properties,
    )
    return self._dict_to_registered_model(response)
|
|
349
|
+
|
|
350
|
+
def update_model_version(self, model_version: ModelVersion) -> ModelVersion:
    """Persist changes made to a model version (direct mode only).

    Args:
        model_version: The model version to update. Must have an ID.

    Returns:
        Updated model version.

    Raises:
        TypeError: If model_version is not a ModelVersion instance.
        NotImplementedError: In proxy mode, which needs the model name and
            is therefore not supported through this method.
        model_registry.exceptions.StoreError: In direct mode, if model
            version does not have an ID (raised by the underlying client).
    """
    # Needed for the isinstance check, which runs before the mode check.
    from model_registry.types import ModelVersion

    if not isinstance(model_version, ModelVersion):
        raise TypeError(f"Expected ModelVersion, got {type(model_version).__name__}. ")

    if self._mode != "proxy":
        return self._registry.update(model_version)

    # The proxy endpoint is addressed by model name + version name, which a
    # bare ModelVersion does not carry.
    raise NotImplementedError(
        "update_model_version in proxy mode requires model name. "
        "Use the proxy client directly for this operation."
    )
|
|
377
|
+
|
|
378
|
+
def update_model_artifact(self, model_artifact: ModelArtifact) -> ModelArtifact:
    """Persist changes made to a model artifact.

    In direct mode the artifact is handed to the underlying registry client.
    In proxy mode only ``description`` and ``custom_properties`` are sent,
    addressed by the artifact ID.

    Args:
        model_artifact: The model artifact to update. Must have an ID.

    Returns:
        Updated model artifact.

    Raises:
        TypeError: If model_artifact is not a ModelArtifact instance.
        ValueError: In proxy mode, if model artifact does not have an ID.
        model_registry.exceptions.StoreError: In direct mode, if model
            artifact does not have an ID (raised by the underlying client).
    """
    # Needed for the isinstance check in both modes.
    from model_registry.types import ModelArtifact

    if not isinstance(model_artifact, ModelArtifact):
        raise TypeError(f"Expected ModelArtifact, got {type(model_artifact).__name__}. ")

    if self._mode != "proxy":
        return self._registry.update(model_artifact)

    if not model_artifact.id:
        raise ValueError("ModelArtifact must have an ID to update")

    response = self._proxy_client.update_model_artifact(
        artifact_id=model_artifact.id,
        description=model_artifact.description,
        custom_properties=model_artifact.custom_properties,
    )
    return self._dict_to_model_artifact(response)
|
|
407
|
+
|
|
408
|
+
def get_model(self, name: str) -> RegisteredModel:
    """Get a registered model.

    Args:
        name: Name of the model.

    Returns:
        Registered model.

    Raises:
        ValueError: If the model does not exist.
    """
    if self._mode == "proxy":
        # Only the remote call is guarded: a failure while converting the
        # response must not be misreported as "model not found".
        try:
            data = self._proxy_client.get_registered_model(name)
        except Exception as e:
            # The proxy client's exception types are not visible from this
            # module, so a missing model is recognised by the message text.
            error_text = str(e)
            if "404" in error_text or "not found" in error_text.lower():
                raise ValueError(f"Model {name!r} not found") from e
            raise
        return self._dict_to_registered_model(data)

    model = self._registry.get_registered_model(name)
    if model is None:
        raise ValueError(f"Model {name!r} not found")
    return model
|
|
433
|
+
|
|
434
|
+
def get_model_version(self, name: str, version: str) -> ModelVersion:
    """Get a model version.

    Args:
        name: Name of the model.
        version: Version of the model.

    Returns:
        Model version.

    Raises:
        model_registry.exceptions.StoreError: In direct mode, if the model
            does not exist (raised by the underlying client).
        ValueError: If the version does not exist.
    """
    if self._mode == "proxy":
        # Only the remote call is guarded: a failure while converting the
        # response must not be misreported as "version not found".
        try:
            data = self._proxy_client.get_model_version(name, version)
        except Exception as e:
            # The proxy client's exception types are not visible from this
            # module, so a missing version is recognised by the message text.
            error_text = str(e)
            if "404" in error_text or "not found" in error_text.lower():
                raise ValueError(f"Model version {version!r} not found for model {name!r}") from e
            raise
        return self._dict_to_model_version(data)

    model_version = self._registry.get_model_version(name, version)
    if model_version is None:
        raise ValueError(f"Model version {version!r} not found for model {name!r}")
    return model_version
|
|
461
|
+
|
|
462
|
+
def get_model_artifact(self, name: str, version: str) -> ModelArtifact:
    """Look up the artifact of a model version (direct mode only).

    Args:
        name: Name of the model.
        version: Version of the model.

    Returns:
        Model artifact.

    Raises:
        NotImplementedError: In proxy mode, where artifacts can only be
            fetched by ID.
        model_registry.exceptions.StoreError: In direct mode, if either the
            model or the version don't exist (raised by the underlying client).
        ValueError: If the artifact does not exist.
    """
    if self._mode != "proxy":
        found = self._registry.get_model_artifact(name, version)
        if found is not None:
            return found
        raise ValueError(f"Model artifact not found for model {name!r} version {version!r}")

    # The proxy API is keyed by artifact ID rather than model name/version.
    raise NotImplementedError(
        "get_model_artifact by name/version not supported in proxy mode. "
        "Use get_model_artifact_by_id() instead."
    )
|
|
488
|
+
|
|
489
|
+
def get_model_artifact_by_id(self, artifact_id: str) -> ModelArtifact:
    """Fetch a model artifact by its ID (proxy mode only).

    Args:
        artifact_id: ID of the artifact.

    Returns:
        Model artifact.

    Raises:
        NotImplementedError: In direct mode, where artifacts are looked up
            by model name and version instead.
    """
    if self._mode != "proxy":
        raise NotImplementedError(
            "get_model_artifact_by_id not available in direct mode. "
            "Use get_model_artifact(name, version) instead."
        )

    payload = self._proxy_client.get_model_artifact(artifact_id)
    return self._dict_to_model_artifact(payload)
|
|
509
|
+
|
|
510
|
+
def list_models(self) -> Iterator[RegisteredModel]:
    """Get an iterator for registered models.

    In proxy mode the models are fetched lazily, 100 per request, following
    ``nextPageToken`` until the backend stops returning one. In direct mode
    iteration is delegated to the underlying registry client.

    Yields:
        Registered models.
    """
    if self._mode == "proxy":
        next_page_token = None
        while True:
            result = self._proxy_client.list_registered_models(
                page_size=100,
                next_page_token=next_page_token,
            )
            # A page may carry "items": null instead of an empty list;
            # treat both as "no items" rather than failing on iteration.
            for item in result.get("items") or []:
                yield self._dict_to_registered_model(item)

            next_page_token = result.get("nextPageToken")
            if not next_page_token:
                break
    else:
        yield from self._registry.get_registered_models()
|
|
532
|
+
|
|
533
|
+
def list_model_versions(self, name: str) -> Iterator[ModelVersion]:
    """Get an iterator for model versions.

    In proxy mode the versions are fetched lazily, 100 per request,
    following ``nextPageToken`` until the backend stops returning one. In
    direct mode iteration is delegated to the underlying registry client.

    Args:
        name: Name of the model.

    Yields:
        Model versions.

    Raises:
        model_registry.exceptions.StoreError: In direct mode, if the model
            does not exist (raised by the underlying client).
    """
    if self._mode == "proxy":
        next_page_token = None
        while True:
            result = self._proxy_client.list_model_versions(
                model_name=name,
                page_size=100,
                next_page_token=next_page_token,
            )
            # A page may carry "items": null instead of an empty list;
            # treat both as "no items" rather than failing on iteration.
            for item in result.get("items") or []:
                yield self._dict_to_model_version(item)

            next_page_token = result.get("nextPageToken")
            if not next_page_token:
                break
    else:
        yield from self._registry.get_model_versions(name)
|