viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,637 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Unit tests for runtime_loader module.
|
|
17
|
+
|
|
18
|
+
Tests runtime loading from various sources including GitHub, HTTP, and filesystem.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import copy
from unittest.mock import MagicMock, patch

import pytest

from viettelcloud.aiplatform.trainer.backends.container import runtime_loader
from viettelcloud.aiplatform.trainer.constants import constants
from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
from viettelcloud.aiplatform.trainer.types import types as base_types
|
|
29
|
+
|
|
30
|
+
# Sample runtime YAML data for testing.
# Mirrors a Kubeflow ClusterTrainingRuntime manifest: the trainer container
# lives at spec.template.spec.replicatedJobs[0].template.spec.template.spec
# .containers[0], which is the path the runtime loader reads the image from.
# NOTE: tests that tweak this dict must deep-copy it first — dict.copy() is
# shallow, so mutating the nested "metadata" would leak between test cases.
SAMPLE_RUNTIME_YAML = {
    "apiVersion": "trainer.kubeflow.org/v1alpha1",
    "kind": "ClusterTrainingRuntime",
    "metadata": {
        "name": constants.DEFAULT_TRAINING_RUNTIME,
        "labels": {"trainer.kubeflow.org/framework": "torch"},
    },
    "spec": {
        "mlPolicy": {"numNodes": 1},
        "template": {
            "spec": {
                "replicatedJobs": [
                    {
                        "name": "node",
                        "template": {
                            "spec": {
                                "template": {
                                    "spec": {
                                        "containers": [
                                            {
                                                "name": "trainer",
                                                "image": "pytorch/pytorch:2.0.0",
                                            }
                                        ]
                                    }
                                }
                            }
                        },
                    }
                ]
            }
        },
    },
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="parse github url",
            expected_status=SUCCESS,
            config={
                "url": "github://kubeflow/trainer",
                "expected_type": "github",
                "expected_path": "kubeflow/trainer",
            },
        ),
        TestCase(
            name="parse github url with path",
            expected_status=SUCCESS,
            config={
                "url": "github://myorg/myrepo/custom/path",
                "expected_type": "github",
                "expected_path": "myorg/myrepo/custom/path",
            },
        ),
        TestCase(
            name="parse https url",
            expected_status=SUCCESS,
            config={
                "url": "https://example.com/runtime.yaml",
                "expected_type": "https",
                "expected_path": "https://example.com/runtime.yaml",
            },
        ),
        TestCase(
            name="parse http url",
            expected_status=SUCCESS,
            config={
                "url": "http://example.com/runtime.yaml",
                "expected_type": "http",
                "expected_path": "http://example.com/runtime.yaml",
            },
        ),
        TestCase(
            name="parse file url",
            expected_status=SUCCESS,
            config={
                "url": "file:///path/to/runtime.yaml",
                "expected_type": "file",
                "expected_path": "/path/to/runtime.yaml",
            },
        ),
        TestCase(
            name="parse absolute path",
            expected_status=SUCCESS,
            config={
                "url": "/absolute/path/to/runtime.yaml",
                "expected_type": "file",
                "expected_path": "/absolute/path/to/runtime.yaml",
            },
        ),
        TestCase(
            name="parse unsupported scheme",
            expected_status=FAILED,
            config={"url": "ftp://example.com/runtime.yaml"},
            expected_error=ValueError,
        ),
    ],
)
def test_parse_source_url(test_case):
    """Test parsing various source URL formats.

    Feeds one URL into ``runtime_loader._parse_source_url`` and checks the
    ``(source_type, path)`` pair it returns; the unsupported-scheme case is
    expected to raise ``ValueError`` instead.
    """
    print("Executing test:", test_case.name)
    try:
        source_type, path = runtime_loader._parse_source_url(test_case.config["url"])

        # No exception was raised, so this must be a SUCCESS case.
        assert test_case.expected_status == SUCCESS
        expected = (test_case.config["expected_type"], test_case.config["expected_path"])
        assert (source_type, path) == expected

    except Exception as e:
        # FAILED cases land here; the raised exception type must match exactly.
        assert type(e) is test_case.expected_error
    print("test execution complete")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="load from default github",
            expected_status=SUCCESS,
            config={
                "github_path": "kubeflow/trainer",
                "discovered_files": ["torch_distributed.yaml"],
                "expected_runtime_name": constants.DEFAULT_TRAINING_RUNTIME,
                "expected_framework": "torch",
            },
        ),
        TestCase(
            name="load from custom github",
            expected_status=SUCCESS,
            config={
                "github_path": "myorg/myrepo",
                "discovered_files": ["custom_runtime.yaml"],
                "expected_runtime_name": "custom-runtime",
                "expected_framework": "custom",
            },
        ),
        TestCase(
            name="load from github no files",
            expected_status=SUCCESS,
            config={
                "github_path": "kubeflow/trainer",
                "discovered_files": [],
                "expected_count": 0,
            },
        ),
        TestCase(
            name="load from github invalid path",
            expected_status=SUCCESS,
            config={
                "github_path": "invalid",
                "expected_count": 0,
            },
        ),
    ],
)
def test_load_from_github_url(test_case):
    """Test loading runtimes from GitHub URLs.

    Mocks file discovery and YAML fetching so no network access happens, then
    verifies ``runtime_loader._load_from_github_url`` turns the discovered
    files into Runtime objects (or returns an empty list for no files / an
    invalid "owner/repo" path).
    """
    print("Executing test:", test_case.name)
    try:
        with (
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._discover_github_runtime_files"
            ) as mock_discover,
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._fetch_runtime_from_github"
            ) as mock_fetch,
        ):
            if test_case.name == "load from github invalid path":
                # Don't set up mocks for invalid path test
                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
                assert len(runtimes) == test_case.config["expected_count"]
            else:
                mock_discover.return_value = test_case.config.get("discovered_files", [])

                # Create runtime YAML with custom name/framework if specified.
                # Deep-copy is required: dict.copy() is shallow, so mutating
                # runtime_yaml["metadata"] below would corrupt the shared
                # module-level SAMPLE_RUNTIME_YAML for later parametrized cases.
                runtime_yaml = copy.deepcopy(SAMPLE_RUNTIME_YAML)
                if "expected_runtime_name" in test_case.config:
                    runtime_yaml["metadata"]["name"] = test_case.config["expected_runtime_name"]
                    runtime_yaml["metadata"]["labels"]["trainer.kubeflow.org/framework"] = (
                        test_case.config["expected_framework"]
                    )
                mock_fetch.return_value = runtime_yaml

                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])

                if "expected_count" in test_case.config:
                    assert len(runtimes) == test_case.config["expected_count"]
                else:
                    # One discovered file should yield exactly one runtime.
                    assert len(runtimes) == 1
                    assert runtimes[0].name == test_case.config["expected_runtime_name"]
                    assert runtimes[0].trainer.framework == test_case.config["expected_framework"]

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _make_test_runtime(name, framework="torch", num_nodes=1):
    """Build a minimal custom-trainer Runtime fixture for source-listing tests."""
    return base_types.Runtime(
        name=name,
        trainer=base_types.RuntimeTrainer(
            trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
            framework=framework,
            num_nodes=num_nodes,
            image="example.com/container",
        ),
    )


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="priority order github sources",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
                "expected_count": 2,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME, "deepspeed-distributed"],
            },
        ),
        TestCase(
            name="duplicate runtime names skipped",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
                "duplicate_names": True,
                "expected_count": 1,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME],
            },
        ),
        TestCase(
            name="fallback to defaults",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo"],
                "no_github_runtimes": True,
                "expected_count": 1,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME],
            },
        ),
    ],
)
def test_list_training_runtimes_from_sources(test_case):
    """Test listing runtimes from multiple sources.

    Covers three behaviors of ``list_training_runtimes_from_sources``:
    merging runtimes from multiple GitHub sources in order, skipping a
    second runtime with a duplicate name, and falling back to the built-in
    defaults when no GitHub source yields any runtime.
    """
    print("Executing test:", test_case.name)
    try:
        with (
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._load_from_github_url"
            ) as mock_github,
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._create_default_runtimes"
            ) as mock_defaults,
        ):
            if test_case.name == "priority order github sources":
                # Two sources, two distinct runtimes: both should be kept.
                mock_github.side_effect = [
                    [_make_test_runtime(constants.DEFAULT_TRAINING_RUNTIME)],
                    [_make_test_runtime("deepspeed-distributed", framework="deepspeed")],
                ]
                mock_defaults.return_value = []

            elif test_case.name == "duplicate runtime names skipped":
                # Same runtime name from both sources; the second (with
                # num_nodes=2) must be discarded in favor of the first.
                mock_github.side_effect = [
                    [_make_test_runtime(constants.DEFAULT_TRAINING_RUNTIME)],
                    [_make_test_runtime(constants.DEFAULT_TRAINING_RUNTIME, num_nodes=2)],
                ]
                mock_defaults.return_value = []

            elif test_case.name == "fallback to defaults":
                # GitHub yields nothing, so the default runtimes are used.
                mock_github.return_value = []
                mock_defaults.return_value = [
                    _make_test_runtime(constants.DEFAULT_TRAINING_RUNTIME)
                ]

            runtimes = runtime_loader.list_training_runtimes_from_sources(
                test_case.config["sources"]
            )

            assert len(runtimes) == test_case.config["expected_count"]
            runtime_names = [r.name for r in runtimes]
            for expected_name in test_case.config["expected_names"]:
                assert expected_name in runtime_names

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def test_create_default_runtimes():
    """Test creating default runtimes from constants.

    ``_create_default_runtimes`` should emit one runtime per entry in
    ``constants.DEFAULT_FRAMEWORK_IMAGES``; the torch runtime is checked
    in detail (name, trainer type, node count, default image).
    """
    print("Executing test: create default runtimes")
    runtimes = runtime_loader._create_default_runtimes()

    # One runtime per configured default framework image.
    assert len(runtimes) == len(constants.DEFAULT_FRAMEWORK_IMAGES)

    # Exactly one torch runtime should exist; inspect it.
    torch_runtimes = [r for r in runtimes if r.trainer.framework == "torch"]
    assert len(torch_runtimes) == 1
    torch_runtime = torch_runtimes[0]
    assert torch_runtime.name == constants.DEFAULT_TRAINING_RUNTIME
    assert torch_runtime.trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
    assert torch_runtime.trainer.num_nodes == 1
    # The default image for the framework must be carried onto the runtime.
    assert torch_runtime.trainer.image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
    print("test execution complete")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="discover runtime files",
            expected_status=SUCCESS,
            config={
                "html_content": """
<html>
<a>torch_distributed.yaml</a>
<a>deepspeed_distributed.yaml</a>
<a>kustomization.yaml</a>
</html>
""",
                "expected_files": ["torch_distributed.yaml", "deepspeed_distributed.yaml"],
                "excluded_files": ["kustomization.yaml"],
            },
        ),
        TestCase(
            name="discover runtime files custom repo",
            expected_status=SUCCESS,
            config={
                "html_content": """
<html>
<a>custom_runtime.yaml</a>
</html>
""",
                "expected_files": ["custom_runtime.yaml"],
                "owner": "myorg",
                "repo": "myrepo",
                "path": "custom/path",
            },
        ),
        TestCase(
            name="discover runtime files network error",
            expected_status=SUCCESS,
            config={
                "network_error": True,
                "expected_files": [],
            },
        ),
    ],
)
def test_discover_github_runtime_files(test_case):
    """Test discovering runtime files from GitHub.

    ``urllib.request.urlopen`` is patched so no network call happens; the
    mocked response serves a small HTML listing. Verifies runtime YAML files
    are found, non-runtime files (kustomization.yaml) are excluded, custom
    owner/repo/path values reach the request URL, and network errors degrade
    to an empty result.
    """
    print("Executing test:", test_case.name)
    cfg = test_case.config
    try:
        with patch("urllib.request.urlopen") as mock_urlopen:
            if cfg.get("network_error"):
                mock_urlopen.side_effect = Exception("Network error")
            else:
                mock_response = MagicMock()
                mock_response.read.return_value = cfg["html_content"].encode("utf-8")
                # Make the mock usable as a context manager.
                mock_response.__enter__.return_value = mock_response
                mock_urlopen.return_value = mock_response

            kwargs = {}
            if "owner" in cfg:
                kwargs = {"owner": cfg["owner"], "repo": cfg["repo"], "path": cfg["path"]}

            files = runtime_loader._discover_github_runtime_files(**kwargs)

            for wanted in cfg["expected_files"]:
                assert wanted in files
            for unwanted in cfg.get("excluded_files", []):
                assert unwanted not in files

            # When a custom repo is given, its coordinates must appear in the URL.
            if "owner" in cfg and not cfg.get("network_error"):
                called_url = mock_urlopen.call_args[0][0]
                assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
                assert kwargs["path"] in called_url

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="fetch runtime success",
            expected_status=SUCCESS,
            config={
                "yaml_content": """
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: torch-distributed
""",
                "expected_name": constants.DEFAULT_TRAINING_RUNTIME,
            },
        ),
        TestCase(
            name="fetch runtime custom repo",
            expected_status=SUCCESS,
            config={
                "yaml_content": """
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: custom-runtime
""",
                "expected_name": "custom-runtime",
                "runtime_file": "custom.yaml",
                "owner": "myorg",
                "repo": "myrepo",
                "path": "custom/path",
            },
        ),
        TestCase(
            name="fetch runtime network error",
            expected_status=SUCCESS,
            config={
                "network_error": True,
                "expected_none": True,
            },
        ),
    ],
)
def test_fetch_runtime_from_github(test_case):
    """Test fetching runtime YAML from GitHub.

    Patches ``urllib.request.urlopen`` to serve canned YAML. Checks that the
    parsed document carries the expected metadata name, that custom repo
    coordinates are reflected in the raw.githubusercontent.com URL, and that
    a network failure yields ``None`` rather than raising.
    """
    print("Executing test:", test_case.name)
    cfg = test_case.config
    try:
        with patch("urllib.request.urlopen") as mock_urlopen:
            if cfg.get("network_error"):
                mock_urlopen.side_effect = Exception("Network error")
            else:
                mock_response = MagicMock()
                mock_response.read.return_value = cfg["yaml_content"].encode("utf-8")
                # Make the mock usable as a context manager.
                mock_response.__enter__.return_value = mock_response
                mock_urlopen.return_value = mock_response

            kwargs = {"runtime_file": cfg.get("runtime_file", "torch_distributed.yaml")}
            if "owner" in cfg:
                kwargs.update(owner=cfg["owner"], repo=cfg["repo"], path=cfg["path"])

            data = runtime_loader._fetch_runtime_from_github(**kwargs)

            if cfg.get("expected_none"):
                assert data is None
            else:
                assert data is not None
                assert data["metadata"]["name"] == cfg["expected_name"]

            # Custom repo coordinates must be reflected in the fetched URL.
            if "owner" in cfg:
                called_url = mock_urlopen.call_args[0][0]
                assert "raw.githubusercontent.com" in called_url
                assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
                assert f"{kwargs['path']}/{kwargs['runtime_file']}" in called_url

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="parse runtime yaml with custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "quay.io/custom/pytorch-arm:v1.0",
                "runtime_name": "torch-arm",
                "framework": "torch",
                "num_nodes": 2,
            },
        ),
        TestCase(
            name="parse runtime yaml with different custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "my-registry.io/pytorch:gpu-arm64",
                "runtime_name": "torch-gpu-arm",
                "framework": "torch",
                "num_nodes": 4,
            },
        ),
        TestCase(
            name="parse runtime yaml prefers container named node",
            expected_status=SUCCESS,
            config={
                "custom_image": "correct-node-image:v1.0",
                "runtime_name": "multi-container-runtime",
                "framework": "torch",
                "num_nodes": 1,
                "multiple_containers": True,
            },
        ),
    ],
)
def test_parse_runtime_yaml_extracts_image(test_case):
    """
    Test that _parse_runtime_yaml correctly extracts and stores the container image.
    This prevents regression of bugs where custom images are ignored.
    """
    print("Executing test:", test_case.name)
    cfg = test_case.config
    try:
        image = cfg["custom_image"]
        if cfg.get("multiple_containers"):
            # With several containers, the one named 'node' must be preferred
            # over sidecars when extracting the trainer image.
            containers = [
                {"name": "sidecar", "image": "wrong-sidecar-image:v1.0"},
                {"name": "node", "image": image},
            ]
        else:
            # Single container carrying the custom image.
            containers = [{"name": "trainer", "image": image}]

        # Minimal ClusterTrainingRuntime document carrying the custom image.
        runtime_yaml = {
            "kind": "ClusterTrainingRuntime",
            "metadata": {
                "name": cfg["runtime_name"],
                "labels": {"trainer.kubeflow.org/framework": cfg["framework"]},
            },
            "spec": {
                "mlPolicy": {"numNodes": cfg["num_nodes"]},
                "template": {
                    "spec": {
                        "replicatedJobs": [
                            {
                                "name": "node",
                                "template": {
                                    "spec": {"template": {"spec": {"containers": containers}}}
                                },
                            }
                        ]
                    }
                },
            },
        }

        runtime = runtime_loader._parse_runtime_yaml(runtime_yaml, "test")

        # The parsed runtime must carry the metadata and the custom image.
        assert runtime.name == cfg["runtime_name"]
        assert runtime.trainer.framework == cfg["framework"]
        assert runtime.trainer.num_nodes == cfg["num_nodes"]
        assert runtime.trainer.image == image

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
|