viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tests for ModelRegistryClient."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from unittest.mock import MagicMock, Mock
|
|
20
|
+
|
|
21
|
+
import pytest
|
|
22
|
+
|
|
23
|
+
from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture(autouse=True)
def skip_if_no_model_registry():
    """Autouse guard: skip the requesting test when ``model_registry`` cannot be imported."""
    optional_dependency = "model_registry"
    pytest.importorskip(optional_dependency)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def mock_registry():
    """Provide a MagicMock standing in for ModelRegistry, with list methods pre-configured."""
    fake_registry = MagicMock()
    # The list-style methods must return something iterable; give each its own empty iterator.
    for list_method in ("get_registered_models", "get_model_versions"):
        getattr(fake_registry, list_method).return_value = iter([])
    return fake_registry
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.fixture
def client(mock_registry, monkeypatch):
    """Build a ModelRegistryClient whose underlying registry is the ``mock_registry`` fixture."""
    from viettelcloud.aiplatform.hub.api.model_registry_client import ModelRegistryClient

    def _registry_factory(**kwargs):
        # Replacement for the ModelRegistry constructor: ignore the keyword
        # arguments and always hand back the shared mock.
        return mock_registry

    # Patch ModelRegistry so that ModelRegistryClient.__init__ picks up the mock.
    monkeypatch.setattr("model_registry.ModelRegistry", _registry_factory)

    return ModelRegistryClient(base_url="http://localhost", port=8080, author="test")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="raises helpful ImportError when model-registry not installed",
            expected_status=FAILED,
            config={
                "base_url": "http://localhost",
                "port": 9080,
                "author": "test-author",
            },
            expected_error=ImportError,
        ),
    ],
)
def test_init_import_error(test_case, monkeypatch):
    """Test that __init__ raises helpful ImportError when model-registry missing.

    The import machinery is patched so that importing ``model_registry`` raises
    ImportError; every other import is forwarded to the real import hook. The
    test then checks that the error surfaced by ModelRegistryClient contains the
    install hint.
    """

    from viettelcloud.aiplatform.hub.api.model_registry_client import ModelRegistryClient

    # Capture the genuine import hook BEFORE patching. Once builtins.__import__
    # is replaced, the bare name ``__import__`` resolves to ``mock_import``
    # itself, so delegating through it would recurse without end for any module
    # other than model_registry.
    real_import = __import__

    # Simulate missing model_registry by making import fail
    def mock_import(name, *args, **kwargs):
        if name == "model_registry":
            raise ImportError("No module named 'model_registry'")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr("builtins.__import__", mock_import)

    try:
        ModelRegistryClient(**test_case.config)
        assert test_case.expected_status == SUCCESS
    except ImportError as e:
        assert test_case.expected_status == FAILED
        assert "pip install 'viettelcloud-aiplatform[hub]'" in str(e)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="http URL infers port=8080 and is_secure=False",
            expected_status=SUCCESS,
            config={
                "base_url": "http://localhost",
                "author": "test",
            },
            expected_output={
                "port": 8080,
                "is_secure": False,
            },
        ),
        TestCase(
            name="https URL infers port=443 and is_secure=True",
            expected_status=SUCCESS,
            config={
                "base_url": "https://registry.example.com",
                "author": "test",
            },
            expected_output={
                "port": 443,
                "is_secure": True,
            },
        ),
        TestCase(
            name="explicit port overrides inference",
            expected_status=SUCCESS,
            config={
                "base_url": "http://localhost",
                "port": 9080,
                "author": "test-author",
            },
            expected_output={
                "port": 9080,
                "is_secure": False,
            },
        ),
    ],
)
def test_init(test_case, monkeypatch):
    """Check the port / is_secure values ModelRegistryClient passes to ModelRegistry."""

    from viettelcloud.aiplatform.hub.api.model_registry_client import ModelRegistryClient

    # A mock "class" whose call returns a known instance, so both the constructor
    # arguments and the stored registry object can be inspected.
    registry_instance = MagicMock()
    registry_cls = MagicMock(return_value=registry_instance)
    monkeypatch.setattr("model_registry.ModelRegistry", registry_cls)

    try:
        client = ModelRegistryClient(**test_case.config)

        assert test_case.expected_status == SUCCESS
        registry_cls.assert_called_once()
        passed_kwargs = registry_cls.call_args.kwargs
        assert passed_kwargs["port"] == test_case.expected_output["port"]
        assert passed_kwargs["is_secure"] is test_case.expected_output["is_secure"]
        assert client._registry == registry_instance

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="register_model delegates to ModelRegistry.register_model",
            expected_status=SUCCESS,
            config={
                "name": "test",
                "uri": "s3://test",
                "model_format_name": "pytorch",
                "model_format_version": "1.0",
                "version": "v1",
            },
        ),
    ],
)
def test_register_model(test_case, client, mock_registry):
    """Check that client.register_model calls register_model on the wrapped registry."""

    try:
        client.register_model(**test_case.config)

        assert test_case.expected_status == SUCCESS
        delegate = mock_registry.register_model
        assert delegate.called

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="get_model delegates to get_registered_model",
            expected_status=SUCCESS,
            config={
                "name": "test-model",
            },
        ),
    ],
)
def test_get_model(test_case, client, mock_registry):
    """Check that client.get_model forwards the name to registry.get_registered_model."""

    try:
        model_name = test_case.config["name"]
        client.get_model(model_name)

        assert test_case.expected_status == SUCCESS
        mock_registry.get_registered_model.assert_called_once_with(model_name)

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="get_model_version delegates to get_model_version",
            expected_status=SUCCESS,
            config={
                "name": "test-model",
                "version": "v1",
            },
        ),
    ],
)
def test_get_model_version(test_case, client, mock_registry):
    """Check that client.get_model_version forwards (name, version) to the registry."""

    try:
        lookup_args = (test_case.config["name"], test_case.config["version"])
        client.get_model_version(*lookup_args)

        assert test_case.expected_status == SUCCESS
        mock_registry.get_model_version.assert_called_once_with(*lookup_args)

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="get_model_artifact delegates to get_model_artifact",
            expected_status=SUCCESS,
            config={
                "name": "test-model",
                "version": "v1",
            },
        ),
    ],
)
def test_get_model_artifact(test_case, client, mock_registry):
    """Check that client.get_model_artifact forwards (name, version) to the registry."""

    try:
        lookup_args = (test_case.config["name"], test_case.config["version"])
        client.get_model_artifact(*lookup_args)

        assert test_case.expected_status == SUCCESS
        mock_registry.get_model_artifact.assert_called_once_with(*lookup_args)

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="list_models returns iterator that yields from pager",
            expected_status=SUCCESS,
            config={
                "mock_models_count": 2,
            },
            expected_output=2,
        ),
        TestCase(
            name="list_models returns empty iterator when no models",
            expected_status=SUCCESS,
            config={
                "mock_models_count": 0,
            },
            expected_output=0,
        ),
    ],
)
def test_list_models(test_case, client, mock_registry):
    """Check that list_models yields every item produced by get_registered_models."""

    # Make the registry's pager produce the configured number of placeholder models.
    model_count = test_case.config["mock_models_count"]
    fake_models = [Mock() for _ in range(model_count)]
    mock_registry.get_registered_models.return_value = iter(fake_models)

    try:
        listed = list(client.list_models())

        assert test_case.expected_status == SUCCESS
        assert len(listed) == test_case.expected_output
        assert mock_registry.get_registered_models.called

    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="update_model delegates to ModelRegistry.update",
            expected_status=SUCCESS,
            config={
                "model_name": "test-model",
            },
        ),
        TestCase(
            name="update_model raises TypeError for ModelVersion",
            expected_status=FAILED,
            config={
                "wrong_type": "ModelVersion",
                "name": "v1",
            },
            expected_error=TypeError,
        ),
        TestCase(
            name="update_model raises TypeError for ModelArtifact",
            expected_status=FAILED,
            config={
                "wrong_type": "ModelArtifact",
                "name": "artifact",
            },
            expected_error=TypeError,
        ),
    ],
)
def test_update_model(test_case, client, mock_registry):
    """Check update_model forwards a RegisteredModel to registry.update and rejects other types."""

    from model_registry.types import ModelArtifact, ModelVersion, RegisteredModel

    try:
        if test_case.expected_status == SUCCESS:
            registered = RegisteredModel(name=test_case.config["model_name"])
            client.update_model(registered)
            mock_registry.update.assert_called_once_with(registered)
        elif test_case.config["wrong_type"] == "ModelVersion":
            # Failure case: pass a ModelVersion where a RegisteredModel is required.
            client.update_model(ModelVersion(name=test_case.config["name"]))
        else:
            # Failure case: pass a ModelArtifact where a RegisteredModel is required.
            client.update_model(
                ModelArtifact(name=test_case.config["name"], uri="s3://bucket/model")
            )

        # Reached only when no exception was raised above.
        assert test_case.expected_status == SUCCESS

    except TypeError as err:
        assert test_case.expected_status == FAILED
        assert "Expected RegisteredModel" in str(err)
    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="update_model_version delegates to ModelRegistry.update",
            expected_status=SUCCESS,
            config={
                "version_name": "v1.0",
            },
        ),
        TestCase(
            name="update_model_version raises TypeError for RegisteredModel",
            expected_status=FAILED,
            config={
                "wrong_type": "RegisteredModel",
                "name": "model",
            },
            expected_error=TypeError,
        ),
    ],
)
def test_update_model_version(test_case, client, mock_registry):
    """Check update_model_version forwards a ModelVersion to registry.update and rejects others."""

    from model_registry.types import ModelVersion, RegisteredModel

    try:
        if test_case.expected_status == SUCCESS:
            model_version = ModelVersion(name=test_case.config["version_name"])
            client.update_model_version(model_version)
            mock_registry.update.assert_called_once_with(model_version)
        else:
            # Failure case: pass a RegisteredModel where a ModelVersion is required.
            client.update_model_version(RegisteredModel(name=test_case.config["name"]))

        # Reached only when no exception was raised above.
        assert test_case.expected_status == SUCCESS

    except TypeError as err:
        assert test_case.expected_status == FAILED
        assert "Expected ModelVersion" in str(err)
    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="update_model_artifact delegates to ModelRegistry.update",
            expected_status=SUCCESS,
            config={
                "artifact_name": "model-artifact",
                "uri": "s3://bucket/model",
            },
        ),
        TestCase(
            name="update_model_artifact raises TypeError for RegisteredModel",
            expected_status=FAILED,
            config={
                "wrong_type": "RegisteredModel",
                "name": "model",
            },
            expected_error=TypeError,
        ),
    ],
)
def test_update_model_artifact(test_case, client, mock_registry):
    """Check update_model_artifact forwards a ModelArtifact to registry.update and rejects others."""

    from model_registry.types import ModelArtifact, RegisteredModel

    try:
        if test_case.expected_status == SUCCESS:
            model_artifact = ModelArtifact(
                name=test_case.config["artifact_name"],
                uri=test_case.config["uri"],
            )
            client.update_model_artifact(model_artifact)
            mock_registry.update.assert_called_once_with(model_artifact)
        else:
            # Failure case: pass a RegisteredModel where a ModelArtifact is required.
            client.update_model_artifact(RegisteredModel(name=test_case.config["name"]))

        # Reached only when no exception was raised above.
        assert test_case.expected_status == SUCCESS

    except TypeError as err:
        assert test_case.expected_status == FAILED
        assert "Expected ModelArtifact" in str(err)
    except Exception as err:
        assert test_case.expected_status == FAILED
        assert type(err) is test_case.expected_error
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Import common types.
|
|
16
|
+
from viettelcloud.aiplatform.common.types import KubernetesBackendConfig
|
|
17
|
+
|
|
18
|
+
# Import the Kubeflow Optimizer client.
|
|
19
|
+
from viettelcloud.aiplatform.optimizer.api.optimizer_client import OptimizerClient
|
|
20
|
+
|
|
21
|
+
# Import the Kubeflow Optimizer types.
|
|
22
|
+
from viettelcloud.aiplatform.optimizer.types.algorithm_types import GridSearch, RandomSearch
|
|
23
|
+
from viettelcloud.aiplatform.optimizer.types.optimization_types import (
|
|
24
|
+
Objective,
|
|
25
|
+
OptimizationJob,
|
|
26
|
+
Result,
|
|
27
|
+
TrialConfig,
|
|
28
|
+
)
|
|
29
|
+
from viettelcloud.aiplatform.optimizer.types.search_types import Search
|
|
30
|
+
|
|
31
|
+
# Import the Kubeflow Trainer types.
|
|
32
|
+
from viettelcloud.aiplatform.trainer.types.types import TrainJobTemplate
|
|
33
|
+
|
|
34
|
+
# Names exported by ``from <this package> import *``; each one is imported above.
__all__ = [
    "GridSearch",
    "KubernetesBackendConfig",
    "Objective",
    "OptimizationJob",
    "OptimizerClient",
    "RandomSearch",
    "Result",
    "Search",
    "TrainJobTemplate",
    "TrialConfig",
]
|
|
File without changes
|